What is Transfer Learning?

Transfer learning is a machine learning technique in which a model developed for one task is reused as the starting point for a model on a second task. It carries the knowledge gained while solving one problem over to a different but related problem. Transfer learning is especially useful when the target dataset is too small to train a model from scratch effectively.

Key Concepts

  1. Pre-trained Models: These are models that have been previously trained on a large dataset, typically for a similar task. For example, models like VGG, ResNet, and Inception are pre-trained on ImageNet, a large dataset used for image classification.
  2. Fine-tuning: This involves taking a pre-trained model and adapting it to a new, specific task. Fine-tuning usually unfreezes some or all of the pre-trained layers (often just the last few) and continues training them on the new data with a small learning rate, so the learned features are adjusted rather than overwritten.
  3. Feature Extraction: Using the pre-trained model's layers as a fixed feature extractor. This involves freezing the layers of the pre-trained model and training only the newly added classification layers on top of them.
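The difference between feature extraction and training from scratch can be shown without any deep learning framework. The following is a deliberately tiny NumPy sketch, not the article's Keras pipeline: a one-hidden-layer network is "pre-trained" on a large synthetic source task, then its hidden layer is frozen and reused as a feature extractor while only a new logistic-regression head is trained on a related target task with just 40 examples. The data, layer sizes, learning rates, and function names (`pretrain`, `extract_features`, `train_head`) are all hypothetical choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def log_loss(p, y, eps=1e-12):
    return float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))


def pretrain(X, y, hidden=16, lr=0.5, steps=2000):
    """Train a small tanh network on the large source task; return its hidden layer."""
    W1 = rng.normal(scale=0.5, size=(X.shape[1], hidden))
    b1 = np.zeros(hidden)
    w2 = np.zeros(hidden)
    b2 = 0.0
    for _ in range(steps):
        H = np.tanh(X @ W1 + b1)
        p = sigmoid(H @ w2 + b2)
        g = (p - y) / len(y)                # gradient of the mean log loss w.r.t. the logits
        gH = np.outer(g, w2) * (1 - H**2)   # backpropagate through the tanh layer
        w2 -= lr * (H.T @ g)
        b2 -= lr * g.sum()
        W1 -= lr * (X.T @ gH)
        b1 -= lr * gH.sum(axis=0)
    return W1, b1


def extract_features(X, W1, b1):
    """Feature extraction: run the pre-trained hidden layer without updating it."""
    return np.tanh(X @ W1 + b1)


def train_head(F, y, lr=0.1, steps=500):
    """Train only a new logistic-regression head on the frozen features."""
    w = np.zeros(F.shape[1])
    b = 0.0
    for _ in range(steps):
        p = sigmoid(F @ w + b)
        g = (p - y) / len(y)
        w -= lr * (F.T @ g)
        b -= lr * g.sum()
    return w, b


# Source task: plenty of labelled data (plays the role of ImageNet).
X_src = rng.normal(size=(2000, 2))
y_src = (X_src[:, 0] + X_src[:, 1] > 0).astype(float)

# Target task: a related decision boundary, but only 40 labelled examples.
X_tgt = rng.normal(size=(40, 2))
y_tgt = (X_tgt[:, 0] + X_tgt[:, 1] > 0.5).astype(float)

W1, b1 = pretrain(X_src, y_src)
W1_before = W1.copy()                       # snapshot to confirm the backbone stays frozen

F_tgt = extract_features(X_tgt, W1, b1)
loss_before = log_loss(sigmoid(np.zeros(len(y_tgt))), y_tgt)  # untrained head predicts 0.5
w, b = train_head(F_tgt, y_tgt)
loss_after = log_loss(sigmoid(F_tgt @ w + b), y_tgt)
accuracy = float(np.mean((sigmoid(F_tgt @ w + b) > 0.5) == y_tgt))

print(f"target loss: {loss_before:.3f} -> {loss_after:.3f}, accuracy: {accuracy:.2f}")
```

Fine-tuning would differ only in that `W1` and `b1` would also receive gradient updates on the target data, usually with a smaller learning rate than the head.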

Example in Python using Keras

Let's use transfer learning with the popular pre-trained model VGG16 for an image classification task. The example assumes you have a dataset of images that you want to classify into different categories.
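`flow_from_directory`, used in Step 6 below, infers the class labels from the subdirectory names under each root folder, so the dataset is assumed to be laid out as `data/train/<class_name>/<image files>` (and likewise for `data/validation`). Because Step 4 hard-codes 10 output units, checking beforehand that the folder count matches can save a confusing shape error during training. The helpers below are a hypothetical sketch using only the Python standard library; `list_class_dirs` and `check_dataset` are not part of Keras.

```python
from pathlib import Path


def list_class_dirs(root):
    """Return the sorted class-folder names under `root`, one folder per class."""
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} is not a directory")
    # Every subdirectory counts, including hidden ones such as .ipynb_checkpoints,
    # which are a common source of an unexpected extra class.
    return sorted(p.name for p in root.iterdir() if p.is_dir())


def check_dataset(train_dir, val_dir, expected_classes):
    """Raise ValueError if the train/validation folders disagree with the model head."""
    train_classes = list_class_dirs(train_dir)
    val_classes = list_class_dirs(val_dir)
    if train_classes != val_classes:
        raise ValueError(f"class folders differ: {train_classes} vs {val_classes}")
    if len(train_classes) != expected_classes:
        raise ValueError(
            f"found {len(train_classes)} classes, model expects {expected_classes}"
        )
    return train_classes
```

For the example below, `check_dataset('data/train', 'data/validation', 10)` would be called before creating the generators.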

Code - Transfer Learning for Image Classification

# Step 1: Import Libraries
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam

# Step 2: Load Pre-trained Model
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Step 3: Freeze the Layers
for layer in base_model.layers:
    layer.trainable = False

# Step 4: Add Custom Layers
x = base_model.output
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)  # Assuming 10 classes in the new dataset
model = Model(inputs=base_model.input, outputs=predictions)

# Step 5: Compile the Model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

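# Step 6: Prepare Data - use ImageDataGenerator to load and preprocess the training and
# validation datasets. VGG16's ImageNet weights expect inputs prepared by its own
# preprocess_input (RGB-to-BGR conversion and ImageNet mean subtraction), so it is
# used here instead of a plain 1/255 rescale.
train_datagen = ImageDataGenerator(preprocessing_function=tf.keras.applications.vgg16.preprocess_input, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
test_datagen = ImageDataGenerator(preprocessing_function=tf.keras.applications.vgg16.preprocess_input)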

train_generator = train_datagen.flow_from_directory('data/train', target_size=(224, 224), batch_size=32, class_mode='categorical')
validation_generator = test_datagen.flow_from_directory('data/validation', target_size=(224, 224), batch_size=32, class_mode='categorical')

# Step 7: Train the model using the data generators.
model.fit(train_generator, epochs=10, validation_data=validation_generator)

# Step 8: Evaluate the performance of the model on the validation dataset.
loss, accuracy = model.evaluate(validation_generator)
print(f'Validation accuracy: {accuracy * 100:.2f}%')
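The size of the new head follows from the shapes in Steps 2 and 4. With a 224×224×3 input, VGG16's five 2×2 max-pooling layers shrink each spatial dimension by a factor of 32, so `base_model.output` has shape 7×7×512 and `Flatten()` produces 7 × 7 × 512 = 25,088 values per image. Counting weights plus biases for the two new `Dense` layers:

$$
\begin{aligned}
\text{Dense}(512) &: 25088 \times 512 + 512 = 12{,}845{,}568 \\
\text{Dense}(10) &: 512 \times 10 + 10 = 5{,}130 \\
\text{trainable total} &: 12{,}845{,}568 + 5{,}130 = 12{,}850{,}698
\end{aligned}
$$

These are the only trainable parameters, since the roughly 14.7 million weights of the convolutional base are frozen in Step 3; `model.summary()` reports both figures. Almost all of them sit in the first dense layer because of `Flatten`. Swapping it for Keras's `GlobalAveragePooling2D` layer would reduce that layer to 512 × 512 + 512 = 262,656 parameters, which may help when data is scarce.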
        

Summary

In this example, we used transfer learning to adapt the pre-trained VGG16 model to a new image classification task. We froze the layers of VGG16 to retain the learned features and trained only the new dense layers added for our classification task. This approach can achieve good performance even with a small dataset because it reuses features learned from a much larger one.
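Since all VGG16 layers stay frozen, the example corresponds to the feature-extraction approach from Key Concepts. A common second stage is fine-tuning: once the new head has trained, unfreeze the last convolutional block, recompile with a much smaller learning rate, and train for a few more epochs. The helper below is a hypothetical sketch (`unfreeze_from` is not a Keras API); it relies only on Keras layers exposing `.name` and `.trainable`. `block5_conv1` is the name Keras gives the first convolution in VGG16's last block, but it is worth confirming with `model.summary()`.

```python
def unfreeze_from(model, first_trainable):
    """Make `first_trainable` and every later layer trainable; freeze all earlier layers.

    Works on any object with a `layers` list whose items have `name` and `trainable`
    attributes, such as a Keras functional model. Returns the trainable layer names.
    """
    names = [layer.name for layer in model.layers]
    if first_trainable not in names:
        raise ValueError(f"no layer named {first_trainable!r}")
    start = names.index(first_trainable)
    for i, layer in enumerate(model.layers):
        # Layers before `start` stay frozen; the chosen layer and all later ones train.
        layer.trainable = i >= start
    return names[start:]
```

With the article's model, this would be `unfreeze_from(model, 'block5_conv1')`, followed by `model.compile(optimizer=Adam(learning_rate=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])` and another `model.fit(...)` call. Recompiling matters because Keras only takes changes to `trainable` into account when the model is compiled again. The 1e-5 learning rate is an illustrative choice, not a tuned value.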

Feel free to use this approach for other image classification tasks with relatively small datasets.

Hala Mohammad

PhD student in Optics- EUV mask inspection

4 months ago

Hello. Can you give me some guidance on what strategies I can use to achieve good performance when using transfer learning with fine-tuning of a pre-trained network? I am working on a research project and am not getting good results because I have a small dataset. I appreciate any help.
