Slicing Pre-Trained Models in Keras. Part (II.B)
Mohamed Ibrahim
Doctoral student @ The University of Bonn | Uncertainty Quantification and Explainable Machine Learning for Crop Monitoring
Hi again everyone!
Today we will walk through a practical example of slicing a pre-trained model, focusing on DenseNet169.
1. Import the libraries

import tensorflow as tf
from tensorflow.keras import backend as K
2. Choose the DenseNet169 model
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.experimental.preprocessing.Resizing(256, 256),
        tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical", seed=123),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.5, seed=123),
        # tf.keras.layers.experimental.preprocessing.RandomZoom(0.2, 0.3, seed=123)
    ]
)

inputs = tf.keras.Input(shape=(480, 480, 3))
x = data_augmentation(inputs)

base_model = tf.keras.applications.DenseNet169(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    # input_shape=(480, 480, 3)
    input_shape=(256, 256, 3),
    include_top=False,
    input_tensor=x)
First of all, the data_augmentation layer is just a Keras Sequential block that augments the data on the fly. We don't actually need it to test our idea, but I kept it since it was part of my MSc code. We create an input layer with the image size (480, 480, 3) and pass it through the augmentation layer, which resizes it to (256, 256, 3). The lines that follow load DenseNet169 initialized with ImageNet weights, with the original classification head removed (include_top=False).
3. Check the model's details
base_model.summary()
We also need to keep the main architectural details in mind. Before we continue, let's understand the architecture and how to read it in Keras:
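One practical way to read the architecture is to print the layer names and look for the pooling layers that close each transition block; those mark the natural slice boundaries. Here is a quick sketch (weights=None skips the ImageNet download, since we only inspect names here; the article itself uses weights='imagenet'):

```python
import tensorflow as tf

# Build DenseNet169 just for inspection; weights=None avoids downloading
# the ImageNet weights.
model = tf.keras.applications.DenseNet169(
    weights=None, include_top=False, input_shape=(256, 256, 3))

# Each transition block ends in an average-pooling layer named 'poolN_pool';
# these (plus the final 'bn' layer) are the cut points we will use.
boundaries = [l.name for l in model.layers if l.name.endswith('_pool')]
print(boundaries)
```

If you print all layer names instead of filtering, you can see how Keras groups the dense blocks (`convN_blockM_...`) between these pooling layers.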
This is the first part of the network, before the data is passed to the first DenseBlock (DB), which consists of six convolutional blocks. So we need to create our first slice, running from the main model's input - base_model.input - through the first dense block.
sliced_1 = tf.keras.Model(inputs=base_model.input, outputs=base_model.get_layer('pool2_pool').output)
sliced_1.trainable = False
Here we slice the model so that it contains only the input, the first DenseBlock, and the first TransitionBlock; pool2_pool is the name of the layer at the end of the first transition block. This part of the model will not be trained and stays frozen with its ImageNet weights. Let's continue:
sliced_2 = tf.keras.Model(inputs=base_model.get_layer('pool2_pool').output, outputs=base_model.get_layer('pool3_pool').output)
sliced_2.trainable = False
sliced_3 = tf.keras.Model(inputs=base_model.get_layer('pool3_pool').output, outputs=base_model.get_layer('pool4_pool').output)
sliced_3.trainable = False
sliced_4 = tf.keras.Model(inputs=base_model.get_layer('pool4_pool').output, outputs=base_model.get_layer('bn').output)
sliced_4.trainable = False
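As a quick sanity check that the cut points behave as expected, we can build the first slice on a standalone DenseNet169 and confirm the downsampling. This is only a sketch: weights=None avoids the ImageNet download, and the (256, 256, 3) input matches the size after resizing.

```python
import tensorflow as tf

# Standalone DenseNet169 at the post-resize resolution; weights=None
# skips the ImageNet download for this check.
base = tf.keras.applications.DenseNet169(
    weights=None, include_top=False, input_shape=(256, 256, 3))

# Slice from the model input to the end of the first transition block.
first = tf.keras.Model(inputs=base.input,
                       outputs=base.get_layer('pool2_pool').output)
first.trainable = False  # keep the slice frozen, as in the article

# The stem (stride-2 conv, then stride-2 pool) and the transition pool
# each halve the spatial resolution: 256 -> 128 -> 64 -> 32.
print(first.output_shape)
```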
We pass the output from each slice to the next until we have rebuilt the same model. But now we have much more flexibility to play with the model and test a lot of ideas.
x = sliced_1(inputs)
x = sliced_2(x)
x = sliced_3(x)
x = sliced_4(x)

sliced_model = tf.keras.Model(inputs=inputs, outputs=x)
sliced_model.summary()
Now we check whether we have recovered the original model:
We get the same original model!
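Beyond comparing summaries, we can verify numerically that the chained slices reproduce the original network, since the slices share the base model's layers. A sketch with weights=None and a random batch (the final cut is the 'bn' layer, so we compare against the base model's 'bn' output):

```python
import numpy as np
import tensorflow as tf

base = tf.keras.applications.DenseNet169(
    weights=None, include_top=False, input_shape=(256, 256, 3))

# Cut at the transition-block pools and the final batch norm, as above.
s1 = tf.keras.Model(base.input, base.get_layer('pool2_pool').output)
s2 = tf.keras.Model(base.get_layer('pool2_pool').output,
                    base.get_layer('pool3_pool').output)
s3 = tf.keras.Model(base.get_layer('pool3_pool').output,
                    base.get_layer('pool4_pool').output)
s4 = tf.keras.Model(base.get_layer('pool4_pool').output,
                    base.get_layer('bn').output)

batch = np.random.rand(1, 256, 256, 3).astype('float32')

# Reference: the full model up to the same final layer ('bn').
full_to_bn = tf.keras.Model(base.input, base.get_layer('bn').output)
y_full = full_to_bn(batch)
y_sliced = s4(s3(s2(s1(batch))))

# The slices share the base model's layers, so the outputs agree.
print(np.allclose(y_full.numpy(), y_sliced.numpy(), atol=1e-5))
```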
To recap, the main idea is to understand your architecture's design and slice parts of the network to test your ideas. The crucial thing (otherwise you may get errors) is to understand the underlying computational graph so that you create proper slices. Finally, I hope you enjoyed it, and I would love to hear from anyone who tries this with different architectures such as Inception or ResNet!
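For anyone who wants to try, here is a minimal sketch of the same recipe on ResNet50, where the residual stages end in layers named 'convN_blockM_out' (names as in tf.keras.applications; weights=None to skip the download):

```python
import tensorflow as tf

base = tf.keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(224, 224, 3))

# Residual stages end in 'convN_blockM_out' add layers: clean cut points,
# since every skip connection is contained within its own stage.
head = tf.keras.Model(base.input, base.get_layer('conv3_block4_out').output)
tail = tf.keras.Model(base.get_layer('conv3_block4_out').output,
                      base.get_layer('conv5_block3_out').output)
head.trainable = False  # e.g. freeze the early stages, fine-tune the rest

print(head.output_shape, tail.output_shape)
```

The same caution applies: pick cut points where the computational graph is clean, or Keras will complain about disconnected graphs.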
Thank you a lot, and if you have any questions, please don't hesitate to ask!