MNIST Handwritten Digits Classification Using a Convolutional Neural Network
Asad Kazmi
AI Educator | Simplifying AI | I Help You Win with AI | AI won’t steal your job, but someone who masters it might. Master AI. Stay Unstoppable.
The MNIST handwritten digits classification problem involves recognizing digits (0–9) from grayscale images. The MNIST dataset is a benchmark dataset for image classification tasks, particularly useful for testing deep learning algorithms. It contains 60,000 training images and 10,000 test images, each a 28x28 pixel grayscale image.
The MNIST dataset serves as a "Hello World" example for machine learning and deep learning practitioners. Solving it helps in understanding fundamental concepts of computer vision, CNN architectures, and model evaluation.
One of the most foundational steps in building effective machine learning/deep learning models for image classification is data preprocessing: transforming raw data into a format a machine can learn from.
When I first encountered a computer vision (image classification) task, it felt overwhelming. How do you ensure your model is ready to learn from thousands of images? How do you avoid overfitting, and how do you make sure your model generalizes well?
But I realized the process breaks down into a few key stages: data loading, preprocessing, and augmentation. Once the data is ready, we design and train a deep learning model and evaluate its performance on unseen test data.
In this article, we'll walk through a detailed example using the MNIST dataset to showcase how each of these steps contributes to creating a robust deep learning model.
Through trial, error, and learning, I settled on the following steps, from preparing the MNIST dataset to training and evaluating the model:
1. Data Loading and Preprocessing
The first step is to load the dataset and prepare it for use in a machine learning model. We start by loading the dataset using the mnist.load_data() function from Keras. This function automatically splits the dataset into training and testing sets.
from tensorflow.keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
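To sanity-check what was loaded, a quick inspection of the shapes and pixel value range (a minimal sketch using the variables above) looks like this:
print(X_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(X_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)
print(X_train.dtype, X_train.min(), X_train.max())  # uint8 0 255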
Normalization
Raw pixel values in the MNIST dataset range from 0 to 255. To optimize training and make the model converge faster, we normalize these pixel values by dividing them by 255, converting the values to a range of [0, 1].
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
Reshaping and Adding Channel Dimension
Since the images are grayscale, we need to reshape the data to include a channel dimension. The model expects the input to be in the shape (28, 28, 1), where 28x28 represents the pixel dimensions, and 1 denotes the grayscale channel.
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
One-Hot Encoding
Next, the target labels (0-9) are converted into one-hot encoded vectors. For instance, the label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]. This encoding is necessary for multi-class classification with the softmax activation function in the final layer.
from tensorflow.keras.utils import to_categorical
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
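As a quick sanity check, encoding a single label shows the expected one-hot vector (using the to_categorical import above):
print(to_categorical(3, 10))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]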
2. Splitting Data
Although MNIST already comes with separate training and test sets, we also carve a validation set out of the training data. This lets us monitor performance during training and helps prevent overfitting.
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
3. Data Augmentation
To further improve the model's ability to generalize, we apply data augmentation. This technique artificially increases the size of the training dataset by applying random transformations such as rotations, zooms, and shifts. (Flips are deliberately avoided: a mirrored digit is usually no longer a valid example of the same class.) Keras's ImageDataGenerator class makes this process straightforward.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data augmentation
datagen = ImageDataGenerator(
    rotation_range=10,       # randomly rotate images by up to 10 degrees
    zoom_range=0.1,          # randomly zoom in/out by up to 10%
    width_shift_range=0.1,   # randomly shift images horizontally by up to 10%
    height_shift_range=0.1   # randomly shift images vertically by up to 10%
)
datagen.fit(X_train)  # Fit the generator on the training data only
Calling datagen.fit(X_train) fits the generator on the training data only; the random transformations themselves are applied on the fly when datagen.flow(X_train, y_train) is called during training. The validation and test sets are never passed through the generator, so they remain an unbiased measure of generalization.
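If you want to see what the augmented digits look like, here is a minimal sketch using matplotlib (assuming the datagen, X_train, and y_train defined above):
import matplotlib.pyplot as plt

# Take one augmented batch and plot the first nine images
images, labels = next(datagen.flow(X_train, y_train, batch_size=9))
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
for img, ax in zip(images, axes.ravel()):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
plt.show()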
4. Model Architecture
With our data preprocessed, it’s time to design the neural network. We’ll use a Convolutional Neural Network (CNN), which is particularly effective for image data.
Convolutional Layers
Convolutional layers apply filters to extract local features from the image. In our model, we use two convolutional layers with ReLU activation to capture spatial patterns.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, Input
# Define the model
model = Sequential([
    Input(shape=(28, 28, 1)),               # 28x28 grayscale images, 1 channel
    Conv2D(32, (3, 3), activation='relu'),  # first convolutional block
    BatchNormalization(),
    MaxPooling2D(),
    Conv2D(64, (3, 3), activation='relu'),  # second convolutional block
    BatchNormalization(),
    MaxPooling2D(),
    Flatten(),                              # flatten feature maps for the dense layers
    Dense(128, activation='relu'),
    Dropout(0.2),                           # dropout for regularization
    Dense(10, activation='softmax')         # output probabilities for the 10 digits
])
Key Layers Explained:
- Conv2D: applies learnable filters that extract local spatial features such as edges and strokes.
- BatchNormalization: normalizes activations to stabilize and speed up training.
- MaxPooling2D: downsamples feature maps, reducing computation and adding tolerance to small shifts.
- Flatten: converts the 2D feature maps into a 1D vector for the dense layers.
- Dense: fully connected layers that combine the extracted features; the final layer uses softmax to output a probability for each of the 10 digits.
- Dropout: randomly drops 20% of the units during training to reduce overfitting.
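To confirm the architecture is wired up as expected, you can print a summary showing each layer's output shape and parameter count:
model.summary()  # prints layer-by-layer output shapes and trainable parameter counts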
5. Compilation and Callbacks
After defining the model architecture, we need to compile it by specifying the loss function, optimizer, and evaluation metrics.
Loss Function
We use categorical_crossentropy for multi-class classification, as it measures the difference between the predicted probabilities and the true class labels.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
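To build intuition for what this loss measures, here is a small NumPy sketch (with illustrative values, not outputs from the actual model) computing categorical cross-entropy for a single example:
import numpy as np

y_true = np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])  # one-hot label for the digit 3
y_pred = np.array([0.01, 0.01, 0.02, 0.90, 0.01,
                   0.01, 0.01, 0.01, 0.01, 0.01])   # predicted softmax probabilities
loss = -np.sum(y_true * np.log(y_pred))             # cross-entropy = -log(0.90) ≈ 0.105
print(loss)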
Callbacks
Callbacks help improve training efficiency and prevent overfitting:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TensorBoard
callbacks = [
    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),    # stop when val_loss stops improving
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=1e-6),  # lower the learning rate on plateaus
    TensorBoard(log_dir='./logs')                                                # log metrics for visualization
]
6. Model Training
With our model compiled and callbacks set, we train the model using the augmented data generator. We specify a batch size of 32 and train the model for up to 10 epochs, although early stopping will stop training if performance plateaus.
history = model.fit(datagen.flow(X_train, y_train, batch_size=32),
                    epochs=10,
                    validation_data=(X_val, y_val),
                    callbacks=callbacks)
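The returned history object records the metrics for each epoch; a minimal sketch for plotting the learning curves (assuming matplotlib is available) is:
import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()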
7. Evaluation
Once training is complete, we evaluate the model on the test data to assess its final performance.
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f"Test Accuracy: {test_acc:.4f}, Test Loss: {test_loss:.4f}")
Putting it all together, this pipeline builds a CNN that classifies handwritten digits with high accuracy, combining careful preprocessing, data augmentation, and training callbacks for robustness.
Stay tuned for an in-depth look at the math behind neural networks. Don’t miss the next newsletter for more insights!