Transfer Learning for CIFAR-10 Classification Using ResNet50

Abstract:

In this article, I apply transfer learning to classify images in the CIFAR-10 dataset using a pre-trained ResNet50 model, with a target validation accuracy of at least 87%. I resize the CIFAR-10 images to the input size ResNet50 expects and freeze its pre-trained layers, which lets the model adapt to the new dataset efficiently; the final model reaches 89.94% validation accuracy. The process and results demonstrate the effectiveness of transfer learning for image classification tasks.

Introduction:

The CIFAR-10 dataset is a widely used benchmark in computer vision, containing 60,000 32x32 color images across 10 classes. Training a deep neural network from scratch on CIFAR-10 requires significant computational resources and time. Transfer learning, which reuses models pre-trained on large datasets, offers a way to cut training time while improving performance. Here, I use ResNet50, a deep convolutional neural network pre-trained on ImageNet, to classify CIFAR-10 images effectively.
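
To make the setup concrete, the dataset ships with Keras and can be inspected in a few lines. A minimal sketch, using the same Keras API as the script in the Appendices:

import tensorflow.keras as K

# CIFAR-10 comes pre-split into 50,000 training and 10,000 test images.
(x_train, y_train), (x_test, y_test) = K.datasets.cifar10.load_data()
print(x_train.shape)  # (50000, 32, 32, 3) -- 32x32 RGB images
print(y_train.shape)  # (50000, 1) -- integer class labels 0-9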

Materials and Methods:

To achieve the goal, I followed these steps:

1. Data Preprocessing: Apply ResNet50's preprocess_input normalization to the images and convert the labels to one-hot encoding.

2. Model Selection: Choose ResNet50 from Keras Applications for its robust performance in image classification.

3. Model Architecture: Add a Lambda layer that resizes CIFAR-10 images to 224x224, the input size ResNet50 expects, and freeze the pre-trained layers to retain their learned features (see the sketch after this list).

4. Custom Layers: Add a dense softmax layer to tailor the model to CIFAR-10's 10 classes.

5. Training: Train the model in two phases: a short pass with all layers trainable and a low learning rate (Adam, 1e-5), followed by additional epochs with the ResNet50 base frozen.

6. Evaluation and Saving: Evaluate the model's performance on the validation set and save the trained model as cifar10.h5.
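
To make steps 3 and 4 concrete, here is a minimal sketch of the resize-and-freeze idea. It deliberately differs from the full script in the Appendices in one respect: it passes include_top=False, so the pooling='max' argument takes effect and the new dense head sits directly on pooled convolutional features rather than on the 1000-way ImageNet output.

import tensorflow as tf
import tensorflow.keras as K

# Pre-trained backbone without the ImageNet head; with include_top=False
# the pooling='max' argument is honored and each image yields a
# 2048-dimensional feature vector.
base = K.applications.ResNet50(weights='imagenet', include_top=False,
                               pooling='max')
base.trainable = False  # freeze the pre-trained layers (step 3)

inputs = K.Input((32, 32, 3))
# Upsample 32x32 CIFAR-10 images to the 224x224 input ResNet50 expects.
resized = K.layers.Lambda(lambda x: tf.image.resize(x, [224, 224]))(inputs)
features = base(resized)
# Custom softmax head for the 10 CIFAR-10 classes (step 4).
outputs = K.layers.Dense(10, activation='softmax')(features)
model = K.Model(inputs, outputs)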

Results:

After training, the model achieved a validation accuracy of 89.94%, surpassing the 87% target. Transfer learning also reduced training time substantially compared to training a model from scratch. The results confirm the potential of transfer learning for efficiently solving image classification tasks.
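
For reference, the saved model can be reloaded and re-scored on the validation split. A minimal sketch, assuming the cifar10.h5 file produced by the script in the Appendices is in the working directory (newer Keras versions may additionally need safe_mode=False to deserialize the Lambda layer):

import tensorflow as tf
import tensorflow.keras as K

model = K.models.load_model('cifar10.h5')

# Re-create the validation split with the same preprocessing as training.
(_, _), (X_valid, Y_valid) = K.datasets.cifar10.load_data()
X_valid = K.applications.resnet50.preprocess_input(X_valid)
Y_valid = tf.reshape(tf.one_hot(Y_valid, 10), [Y_valid.shape[0], 10])

loss, acc = model.evaluate(X_valid, Y_valid, batch_size=64)
print('validation accuracy:', acc)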

Discussion:

The results demonstrate that transfer learning with ResNet50 is highly effective for CIFAR-10 classification. By resizing the input images and freezing the pre-trained layers, I leveraged ResNet50's robust feature extraction capabilities, and using ResNet50's own preprocess_input function kept the input distribution consistent with the network's ImageNet pre-training. Future work could explore fine-tuning the later layers of ResNet50 for potentially higher accuracy, or applying this approach to other datasets.
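
To illustrate the fine-tuning direction, here is a minimal sketch that reuses the model and base_model returned by create_model in the Appendices. It unfreezes only the final ResNet50 stage (whose layers carry the 'conv5' prefix in the Keras implementation) and recompiles with a low learning rate; the stage choice and learning rate are illustrative assumptions, not tested settings from this project.

# Hypothetical fine-tuning pass: unfreeze only the last ResNet50 stage.
base_model.trainable = True  # re-enable training before selective freezing
for layer in base_model.layers:
    # Final-stage layers are named 'conv5_...' in Keras's ResNet50.
    layer.trainable = layer.name.startswith('conv5')

# Recompile with a low learning rate so the pre-trained weights are only
# gently adjusted, then continue training for a few epochs.
model.compile(optimizer=K.optimizers.Adam(1e-5),
              loss=K.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])
model.fit(x=X_train, y=Y_train, validation_data=(X_valid, Y_valid),
          batch_size=64, epochs=2)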

Acknowledgments:

Thanks to the developers of Keras for providing accessible tools for deep learning research. Their comprehensive library and documentation were instrumental to this work.

Literature Cited

- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

- Krizhevsky, A., & Hinton, G. (2009). Learning Multiple Layers of Features from Tiny Images. Technical Report, University of Toronto.

- Chollet, F., & others. (2015). Keras. Retrieved from https://keras.io

Appendices:

#!/usr/bin/env python3
"""implement transfer learning using cifar10"""
import tensorflow.keras as K
import tensorflow as tf


def preprocess_data(X, Y):
    """Scale images for ResNet50 and one-hot encode the labels."""
    # load_data returns labels of shape (N, 1), so tf.one_hot yields
    # (N, 1, 10); flatten that to (N, 10) below.
    Y = tf.one_hot(Y, 10)
    # resnet50.preprocess_input converts RGB to BGR and subtracts the
    # ImageNet channel means, matching the network's pre-training.
    return (K.applications.resnet50.preprocess_input(X),
            tf.reshape(Y, [Y.shape[0], 10]))


def create_model():
    """Build a CIFAR-10 classifier on top of a pre-trained ResNet50."""
    # With include_top=True the full ImageNet head is kept and the
    # pooling argument is ignored, so the new Dense layer sits on top
    # of the 1000-way ImageNet output.
    base_model = K.applications.ResNet50(
            weights='imagenet',
            include_top=True,
            pooling='max')

    old_input = K.Input((32, 32, 3))
    # Upsample 32x32 CIFAR-10 images to the 224x224 input ResNet50 expects.
    pre = K.layers.Lambda(lambda x: tf.image.resize(x, [224, 224]))(old_input)
    features = base_model(pre)
    outputs = K.layers.Dense(10, activation='softmax')(features)
    model = K.Model(old_input, outputs)

    return model, base_model


def train_and_save():
    """Train the transfer-learning model and save it as cifar10.h5."""
    (X_train, Y_train), (X_valid, Y_valid) = K.datasets.cifar10.load_data()

    X_train, Y_train = preprocess_data(X_train, Y_train)
    X_valid, Y_valid = preprocess_data(X_valid, Y_valid)

    model, base_model = create_model()
    # Phase 1: briefly train the whole network with a low learning rate
    # so the pre-trained weights are only gently adjusted.
    model.compile(optimizer=K.optimizers.Adam(1e-5),
                  loss=K.losses.CategoricalCrossentropy(),
                  metrics=['accuracy'])
    model.summary()
    model.fit(x=X_train, y=Y_train, validation_data=(X_valid, Y_valid),
              batch_size=64, epochs=2)
    # Phase 2: freeze the ResNet50 base and train only the new head.
    # Recompiling is required for the trainable change to take effect.
    base_model.trainable = False
    model.compile(optimizer=K.optimizers.Adam(),
                  loss=K.losses.CategoricalCrossentropy(),
                  metrics=['accuracy'])
    model.summary()
    model.fit(x=X_train, y=Y_train, validation_data=(X_valid, Y_valid),
              batch_size=64, epochs=4)
    model.save('cifar10.h5')


if __name__ == "__main__":
    train_and_save()

