What is Transfer Learning?
Julian Kaljuvee
Data Science / AI / ML Engineering @PredictiveLabs │ Ex-quant (Goldman, JPMorgan, LSEG, UBS) │ Alternative Data and Bio AI
Transfer learning is a machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. This approach leverages the knowledge gained while solving one problem and applies it to a different but related problem. Transfer learning is especially useful when the target dataset is smaller and doesn't have enough data to train a model from scratch effectively.
Key Concepts
- Pre-trained model: a network (such as VGG16 or ResNet) already trained on a large dataset like ImageNet.
- Feature extraction: freezing the pre-trained layers and reusing their learned representations, training only new layers added on top.
- Fine-tuning: optionally unfreezing some of the top pre-trained layers and retraining them at a low learning rate to adapt them to the new task.
All three are illustrated in the example below.
Example in Python using Keras
Let's use transfer learning with the popular pre-trained VGG16 model for an image classification task. The example assumes you have a dataset of images, organized into one sub-directory per class, that you want to classify into different categories.
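Keras' flow_from_directory (used in Step 6 below) infers class labels from sub-directory names, so the code assumes a layout along these lines (the data/train and data/validation paths are simply the ones the example uses, not a requirement):

data/
  train/
    class_0/  img001.jpg, img002.jpg, ...
    ...
    class_9/  ...
  validation/
    class_0/  ...
    ...
    class_9/  ...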
Code - Transfer Learning for Image Classification
# Step 1: Import Libraries
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
# Step 2: Load Pre-trained Model (ImageNet weights, without the original classifier head)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
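# With include_top=False, the output is the last convolutional feature map, which
# for 224x224 inputs has shape (None, 7, 7, 512); base_model.summary() shows the details.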
# Step 3: Freeze the Layers so the pre-trained ImageNet weights are not updated during training
for layer in base_model.layers:
    layer.trainable = False
# Step 4: Add Custom Layers
x = base_model.output
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x) # Assuming 10 classes in the new dataset
model = Model(inputs=base_model.input, outputs=predictions)
# Step 5: Compile the Model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
# Step 6: Prepare Data - use ImageDataGenerator to load and preprocess the training and
# validation datasets.
train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory('data/train', target_size=(224, 224), batch_size=32, class_mode='categorical')
validation_generator = test_datagen.flow_from_directory('data/validation', target_size=(224, 224), batch_size=32, class_mode='categorical')
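# Optional sanity check (not part of the original steps): the number of classes
# discovered on disk should match the size of the final Dense layer (10 here).
assert train_generator.num_classes == 10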
# Step 7: Train the model using the data generators.
model.fit(train_generator, epochs=10, validation_data=validation_generator)
# Step 8: Evaluate the performance of the model on the validation dataset.
loss, accuracy = model.evaluate(validation_generator)
print(f'Validation accuracy: {accuracy * 100:.2f}%')
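Once trained, the model can be used on a single new image. Here is a minimal sketch; the path 'data/test/some_image.jpg' is just a placeholder, and note that the same 1./255 rescaling used by the generators must be applied at inference time.

import numpy as np
from tensorflow.keras.preprocessing import image

# Load and preprocess one image the same way the generators did
img = image.load_img('data/test/some_image.jpg', target_size=(224, 224))
x = image.img_to_array(img) / 255.0  # same rescale=1./255 as training
x = np.expand_dims(x, axis=0)        # add a batch dimension -> (1, 224, 224, 3)

# Predict and map the argmax back to a class name
probs = model.predict(x)[0]
idx_to_class = {v: k for k, v in train_generator.class_indices.items()}
print('Predicted class:', idx_to_class[int(np.argmax(probs))])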
Summary
In this example, we used transfer learning to adapt the pre-trained VGG16 model to a new image classification task. We froze the layers of VGG16 to retain the learned features and added new dense layers for our specific classification task. This approach helps in achieving good performance even with a smaller dataset by leveraging the knowledge from a large pre-trained model.
Feel free to use this approach for any classification task that has a relatively small dataset. If the frozen-base model plateaus, a common next step is fine-tuning, sketched below.
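Fine-tuning means unfreezing the top convolutional block of VGG16 and continuing training with a much lower learning rate, so the pre-trained weights are only gently adjusted. A minimal sketch; the choice of block5 and the learning rate are typical starting points rather than tuned values:

# Unfreeze only the last convolutional block (block5) of VGG16
for layer in base_model.layers:
    layer.trainable = layer.name.startswith('block5')

# Re-compile with a much lower learning rate; compile() must be called again
# after changing layer.trainable for the change to take effect
model.compile(optimizer=Adam(learning_rate=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])

# Continue training for a few more epochs
model.fit(train_generator, epochs=5, validation_data=validation_generator)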