Fine-tuning AlexNet using unbiased Gaza & Ukraine war images
Abeer Albashiti
Fueling Empathy through AI | Founder at Larimar | Emotion AI prime mover | Royal Academy of Engineering Fellow | Mentor | TechWomen Fellow | Mental wellbeing | People Management
Below is a two-class image classifier built by fine-tuning AlexNet in MATLAB. I used transfer learning to adapt AlexNet to distinguish two categories of images: the Gaza war and the Ukraine war. The goal is to use the fine-tuned network to predict which of the two classes a new war image belongs to.
The steps for creating the image recognizer are laid out in the MATLAB code below, which is followed by the supporting function:
clear
% To get started you need two things:
% 1- Training images of the different classes (here, 110 images per class)
% 2- A pre-trained deep neural network (AlexNet)
% You can substitute these categories with any of your own based on what
% image data you have available.
%% Load Training Images
% For imageDatastore to parse the folder names as category labels,
% store each image category in its own sub-folder.
% The data is saved in a folder named 'TrainingData' in the same directory
% as TransferLearningDetail.m
allImages = imageDatastore('TrainingData', 'IncludeSubfolders', true,...
'LabelSource', 'foldernames');
%% Split data into training and test sets
% 80% of the images are training data and the remaining 20% is testing
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
%% Load Pre-trained Network (AlexNet)
% AlexNet is a convolutional neural network (CNN) pre-trained on 1000
% object categories. It is available as a support package on File Exchange.
alex = alexnet;
%% Review Network Architecture
% Check alexnet architecture to change any specific feature as per the need
layers = alex.Layers;
%% Modify Pre-trained Network
% AlexNet was trained to recognize 1000 classes; we need to modify it to
% recognize just 2 classes: Gaza and Ukraine
layers(23) = fullyConnectedLayer(2); % change this based on # of classes
layers(25) = classificationLayer; % new classification output layer
%% Perform Transfer Learning
% For transfer learning we want to change the weights of
% the network ever so slightly. How much a network is
% changed during training is controlled by the learning rates.
% epochs = 20, batch size = 64, learning rate = 0.001.
opts = trainingOptions('sgdm', 'InitialLearnRate', 0.001,...
'MaxEpochs', 20, 'MiniBatchSize', 64);
%% Set custom read function
% One of the great things about imageDatastore is that it lets you specify
% a "custom" read function. In this case it simply resizes the input
% images to 227x227 pixels, which is what AlexNet expects. You do this by
% specifying a handle to a function that reads and pre-processes the image.
trainingImages.ReadFcn = @readFunctionTrain;
%% Train the Network
% This process usually takes about 5-20 minutes on a desktop GPU.
myNet = trainNetwork(trainingImages, layers, opts);
%% Test Network Performance
% Now let's test the performance of our new
%"image recognizer" on the test set.
testImages.ReadFcn = @readFunctionTrain;
predictedLabels = classify(myNet, testImages);
accuracy = mean(predictedLabels == testImages.Labels);
% Convert predictedLabels from the test set to numeric class indices
% (1=Gaza), (2=Ukraine):
NumericalPredictedLabels = grp2idx(predictedLabels);
%% Compare the similarities between Gaza, Ukraine and a real set
% 1- Use 44 real catastrophic images
% Now let's test the performance of our new "image recognizer"
% on the real set.
folderPath = '/Users/abeeralbashiti/Desktop/file name'; % replace 'file name' with your folder
% Adjust the file extension if needed.
imageFiles = dir(fullfile(folderPath, '*.jpg'));
sz = myNet.Layers(1).InputSize;
for i = 1:numel(imageFiles)
% Load the image
imagePath = fullfile(folderPath, imageFiles(i).name);
% Read the image to classify
img = imread(imagePath);
% Show the image:
imshow(img)
% If the image is grayscale, create a 3-channel image with the
% grayscale image copied into each channel
if size(img, 3) == 1
img = cat(3, img, img, img);
end
% Resize to the network input size (227x227 for AlexNet), matching
% the pre-processing used during training
img = imresize(img, sz(1:2));
% Use the model to make predictions
predictions = predict(myNet, img);
% Process the predictions as needed (e.g., interpret class labels)
[maxScore, predictedClass] = max(predictions);
% Print and store the results
fprintf('Image: %s, Predicted Class: %d\n', imageFiles(i).name, predictedClass);
globalVariable(i,:) = [i, maxScore, predictedClass, predictions];
end
% Save the results once all images have been processed
save('TransferLearningResults.mat','globalVariable');
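The loop above prints the predicted class as a number. As a minimal sketch of turning the scores returned by `predict` into a readable label: the class names and the score vector below are hypothetical, and the order assumes the training sub-folders are named Gaza and Ukraine and sorted alphabetically. With the real datastore, `categories(trainingImages.Labels)` should give the actual order.

```matlab
% Hypothetical class names, in the assumed alphabetical folder order
classNames = {'Gaza', 'Ukraine'};
% Hypothetical score vector standing in for the output of predict(myNet, img)
scores = [0.27 0.73];
% The index of the highest score is the predicted class index
[maxScore, idx] = max(scores);
predictedName = classNames{idx};
fprintf('Predicted class: %s (score %.2f)\n', predictedName, maxScore);
```

Alternatively, calling `classify(myNet, img)`, as the script already does for the test set, returns the categorical label directly.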
You can use a custom input for a single prediction as well:
% 2- Use a custom input for a new image.
NewImagePath = '/Users/abeeralbashiti/Desktop/warImage.jpg';
sz = myNet.Layers(1).InputSize;
% Read the image to classify
NewImage = imread(NewImagePath);
% Show the image:
imshow(NewImage)
% If the image is grayscale, create a 3-channel image with
% the grayscale image copied into each channel
if size(NewImage, 3) == 1
NewImage = cat(3, NewImage, NewImage, NewImage);
end
% Resize to the network input size (227x227 for AlexNet)
NewImage = imresize(NewImage, sz(1:2));
% Use the model to make predictions
NewPrediction = predict(myNet, NewImage);
% Process the predictions as needed (e.g., interpret class labels)
[NewMaxScore, NewPredictedClass] = max(NewPrediction);
% Save the top score, the predicted class, and the raw predictions:
NewPrediction = [NewMaxScore, NewPredictedClass, NewPrediction];
save('NewPredictionResults.mat','NewPrediction');
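The confusion-matrix step compares two label vectors element by element, so both need the same length. Here is a quick arithmetic check of the dataset sizes, assuming the 110 images per class and the 80/20 split described in the training script. The variable names are illustrative only.

```matlab
% Dataset sizes taken from the comments in the training script
imagesPerClass = 110;
numClasses = 2;
trainFraction = 0.8;
% Images per class that go to training and to testing
trainPerClass = round(trainFraction * imagesPerClass);
testPerClass = imagesPerClass - trainPerClass;
% Totals across both classes
numTrain = numClasses * trainPerClass;
numTest = numClasses * testPerClass;
fprintf('Training images: %d, test images: %d\n', numTrain, numTest);
```

Under these assumptions the test set holds 44 images, the same size as the real set of 44 images, which is what lets `confusionmat` pair the two prediction vectors.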
Calculating the confusion matrix:
%% Calculate the confusion matrix between the image recognizer &
% the real selected images; you can use the same approach with your new images
realPredictions = globalVariable(:,3);
C = confusionmat(single(NumericalPredictedLabels),realPredictions);
confusionchart(C,'Title','Tested Data vs Predicted Data')
% To draw a pie chart use the pie function (in a new figure, so the
% confusion chart is not overwritten):
figure
pie(C)
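For readers who have true labels for their images, summary metrics can be read off a 2x2 confusion matrix as sketched below. The counts are hypothetical and are not the results reported in this article. The sketch assumes rows are true classes and columns are predicted classes, in the order Gaza, Ukraine.

```matlab
% Hypothetical counts, for illustration only
% Rows = true class, columns = predicted class (1 = Gaza, 2 = Ukraine)
C = [18 4;
     3 19];
% Overall accuracy: correctly classified images over all images
accuracy = trace(C) / sum(C(:));
% Recall per class: correct predictions over the true count of that class
recall = diag(C) ./ sum(C, 2);
% Precision per class: correct predictions over all predictions of that class
precision = diag(C) ./ sum(C, 1)';
fprintf('Accuracy: %.3f\n', accuracy);
fprintf('Recall (Gaza, Ukraine): %.3f, %.3f\n', recall(1), recall(2));
fprintf('Precision (Gaza, Ukraine): %.3f, %.3f\n', precision(1), precision(2));
```

Note that the script above passes two sets of predictions to `confusionmat`, so these metrics apply as described only when the first argument holds the known labels.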
The embedded function for reading and resizing the images:
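%% The supporting function:
function I = readFunctionTrain(filename)
% Read an image and resize it to the size required by the network (AlexNet)
I = imread(filename);
% Copy a grayscale image into three channels, since AlexNet expects RGB
if size(I, 3) == 1
I = cat(3, I, I, I);
end
I = imresize(I, [227 227]);
end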
Here is the confusion matrix, where Class 1 = Gaza’s war and Class 2 = Ukraine’s war.
What does this matrix mean in words and as a pie chart?
The classifier accuracy is 86% and can be improved by fine-tuning on more unbiased data. Based on these results alone, we are talking about a genocide on all fronts happening to the people of Gaza, compared with all other catastrophic wars, including the most recent one in Ukraine.
I urge you all to rerun the code with your own unbiased data to see how AI classifiers empathize with Gaza more than current governments!
Tip: A good reference is available for fine-tuning AlexNet in Matlab.