Flower classifier with Convolutional Neural Networks using PyTorch
Boris Kushnarev
Business Analyst at Optus | Data Analytics Consultant at The Data School | Neo4j Certified Professional | Tableau Data Analyst Certification | Alteryx Advanced Certified | [email protected]
Nowadays, Convolutional Neural Networks play an important role in solving many computer vision problems, including image classification. Image recognition is used in areas such as health care, where skin cancer can be detected with higher accuracy than doctors achieve; self-driving car systems identify cars, pedestrians and other obstacles on the road; and face recognition has become common among Facebook users. The field has improved significantly since the AlexNet model was introduced in 2012. Since then, other models have been released, for instance VGG and DenseNet, which solve complex problems today. These and other models are implemented in PyTorch, an open source deep learning platform developed by Facebook that provides a seamless path from research prototyping to production deployment (https://pytorch.org/).
The flower classifier project was part of the PyTorch challenge at Udacity, sponsored by Facebook, and I am proud to have been accepted to this program. After two months of hard work I was able to apply the knowledge I had gained and complete this project.
Flower samples from the database:
Introduction
Fully connected neural networks are capable of solving image classification problems; however, they have limitations: representing a picture as a vector loses information about each pixel's location, and fully connected layers lead to extremely large matrices. Hence, huge computing resources are needed to solve such problems. Fully connected neural networks work well for recognizing 28 x 28 black-and-white handwritten labeled digits from the MNIST database, or 28 x 28 labeled fashion images from the Fashion-MNIST database. However, problems where a color object can be located anywhere in the picture are best handled with convolutional neural networks, which we will consider in this article.
The developed image classifier can distinguish flower species across 102 categories from this dataset. It uses pretrained Convolutional Neural Networks provided in PyTorch.
Model
To solve this problem, three different pretrained CNN architectures (densenet121, vgg19 and alexnet) were used and compared with each other. A fully connected neural network with 102 outputs was attached to each pretrained CNN as a classifier. After that, only the fully connected part of the architecture was trained, with its weights updated, while the weights of the CNN part were frozen.
Implementation
To implement this model, a few steps are needed:
- Load the dataset. In addition, we can augment it using different transformations such as random rotation, random crop and random flips. Initially, the dataset contains two sets, training and testing, for each flower species. However, to avoid bias, the training dataset was additionally split at random into two parts: 20% for validation and the remaining 80% for training.
- Load one of the pretrained models mentioned above with frozen parameters to avoid backpropagation through the pretrained CNN. Then attach the classifier, a fully connected neural network, defined via class Classifier(nn.Module).
- Specify a loss function, for instance CrossEntropyLoss, and an optimizer, such as Adam.
- Train and validate the model.
- Test the model.
A fully implemented model can be found at this link: https://github.com/taglitis/flower_classifier
Overfitting
One of the problems which may occur with neural networks is overfitting, and we try to avoid it using techniques such as early stopping, where we identify the point at which the validation loss stops declining and starts rising, and dropout, where we randomly turn off some of the nodes during training.
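Early stopping as described above can be sketched as a small helper that watches the validation-loss history. The patience value is an illustrative choice, not the one used in the project.

```python
def should_stop(val_losses, patience=5):
    """Return True when the most recent `patience` epochs all failed to
    improve on the best validation loss seen before them."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return all(loss >= best_before for loss in val_losses[-patience:])
```

In a training loop one would call `should_stop` after each validation pass and break out (restoring the best saved checkpoint) once it returns True.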
Summary
Three models were used for flower classification (alexnet, densenet121 and vgg19) and compared with each other. The evolution of the training loss and the validation loss can be seen in the graph below:
It is easy to see that the training and validation losses are lower for densenet121 than for vgg19 and alexnet.
The plot below shows how the accuracy of all three models increases with the number of epochs. Again, densenet121 performs better than the other two models.
The last saved best checkpoints for each model are:
- densenet121 - epoch 125 with validation accuracy 90.152%
- vgg19 - epoch 65 with validation accuracy 83.864%
- alexnet - epoch 89 with validation accuracy 74.091%
As we can see, the best validation accuracy, 90.152%, was achieved by densenet121.
It is interesting to note that during testing, the accuracy for each model is usually 2-3% higher than the validation accuracy. This can be explained by the fact that the validation dataset consists of 20% of the pictures from the training dataset, which were randomly rotated, cropped and resized, whereas for testing only one transformation was performed: a center crop.
Below is a screenshot of the testing results for densenet121. You can see that the overall accuracy of the densenet121 model at epoch 125 reaches 94.25%, while during validation this number is only 90.152%.
Model improvement
Even though the densenet121 model shows good performance, we can still try to adjust hyperparameters such as the learning rate, or introduce momentum into the model. One could also try different normalization parameters, such as the mean and standard deviation. Moreover, the classifier architecture can be changed as well.
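A couple of the tuning ideas above can be sketched like this. The classifier head here is a stand-in for the real one, and all values are illustrative starting points to experiment with, not tuned results.

```python
from torch import nn, optim

# A stand-in classifier head (not the project's exact architecture).
classifier = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 102))

# One option: swap Adam for SGD with momentum and a lower learning rate.
optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

# Optionally decay the learning rate by 10x every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
```

Calling `scheduler.step()` once per epoch after the optimizer updates applies the decay schedule.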
Thank you Udacity and Facebook for such opportunity!
#datascience #dataanalyst #machinelearning #deeplearning