Visualizing Neural Network Predictions on the MNIST Dataset Using Keras in Google Colab
Padam Tripathi (Learner)
AI Architect | Generative AI, LLM | NLP | Image Processing | Cloud Architect | Data Engineering (Hands-On)
We provide images of handwritten digits, and the model learns to recognize which digit (0 to 9) each image shows.
This notebook trains a simple fully connected neural network on the MNIST dataset and, at the end, visualizes the training history and a sample of test predictions.
Step 1: Set Up Google Colab
Step 2: Copy and Paste the Notebook Code
Copy and paste the following code into your Google Colab notebook cells.
# Install necessary libraries
!pip install tensorflow matplotlib
# Step 3: Import Libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical
# Step 4: Load and Prepare the Data
# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the input data
x_train = x_train / 255.0
x_test = x_test / 255.0
# One-hot encode the labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
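Normalization and one-hot encoding are simple array operations. As a minimal NumPy-only sketch (the helper name `one_hot` is hypothetical, not a Keras API), the following reproduces the values that dividing by 255.0 and `to_categorical(y, 10)` produce for integer labels:

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: one-hot encode integer class labels.

    Row i of the identity matrix is the one-hot vector for class i,
    so indexing np.eye with the labels builds the encoded matrix.
    """
    labels = np.asarray(labels, dtype=np.int64)
    return np.eye(num_classes, dtype=np.float32)[labels]

# MNIST pixels are uint8 integers in [0, 255];
# dividing by 255.0 rescales them to floats in [0.0, 1.0].
pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = pixels / 255.0

# Three example labels; each becomes a length-10 vector with a single 1.0.
encoded = one_hot([5, 0, 4], 10)
print(scaled)
print(encoded.shape)  # (3, 10)
```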
# Step 5: Build the Model
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])
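A `Dense` layer with n inputs and m units stores an n × m weight matrix plus m biases, while `Flatten` only reshapes each 28 × 28 image into a 784-element vector and has no weights. This short worked count shows where the trainable parameters of this architecture come from; `model.summary()` should report the same total:

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

# Flatten turns each 28x28 image into a 784-element vector (no parameters).
layer_sizes = [28 * 28, 128, 64, 10]

# Pair each layer's input size with its number of units.
per_layer = [dense_params(n_in, n_out)
             for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [100480, 8256, 650]
print(total)      # 109386
```

Almost all of the parameters sit in the first `Dense` layer, because it connects every one of the 784 pixels to each of its 128 units.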
# Step 6: Compile the Model
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
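With a softmax output layer and one-hot labels, `categorical_crossentropy` is, per sample, the negative log of the probability the model assigns to the true class, averaged over the batch. A minimal NumPy sketch of that formula follows; the function names are hypothetical, and the `eps` clipping is a simplified stand-in for the clipping Keras applies internally for numerical stability:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean over the batch of -sum(y_true * log(y_prob)) per sample."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))

y_true = np.eye(10)[[3]]                # one sample whose true class is 3
uniform = softmax(np.zeros((1, 10)))    # equal logits: probability 0.1 per class
confident = softmax(np.array([[0.0] * 3 + [10.0] + [0.0] * 6]))

print(categorical_crossentropy(y_true, uniform))    # log(10), about 2.3026
print(categorical_crossentropy(y_true, confident))  # close to 0
```

An untrained 10-class model that guesses uniformly starts near a loss of log(10) ≈ 2.30, which is a useful sanity check for the first epoch.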
# Step 7: Train the Model
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
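`validation_split=0.2` holds out the last 20% of the training arrays (Keras selects them before any shuffling) for validation, and `fit` uses a default `batch_size` of 32 when none is passed. A sketch of the resulting arithmetic for MNIST's 60,000 training images, assuming that default batch size:

```python
import math

n_train_total = 60_000     # MNIST training images
validation_split = 0.2     # as passed to model.fit above
batch_size = 32            # assumed Keras default when fit() gets no batch_size

# The first (1 - validation_split) fraction is used for training;
# the remaining last samples become the validation set.
split_at = int(n_train_total * (1.0 - validation_split))
n_fit, n_val = split_at, n_train_total - split_at

# Each epoch needs enough batches to cover every training sample once.
steps_per_epoch = math.ceil(n_fit / batch_size)

print(n_fit, n_val)      # 48000 12000
print(steps_per_epoch)   # 1500
```

This is why the Colab progress bar counts 1500 steps per epoch rather than one step per image.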
# Step 8: Evaluate the Model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc}')
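For one-hot labels and softmax outputs, the reported accuracy is the fraction of samples whose highest-probability class matches the true class. A NumPy sketch of that computation; the helper name is hypothetical and the arrays are illustrative values, not model output:

```python
import numpy as np

def accuracy_from_probs(y_true_onehot, y_prob):
    """Share of rows where argmax(prediction) equals argmax(one-hot label)."""
    true_labels = np.argmax(y_true_onehot, axis=1)
    pred_labels = np.argmax(y_prob, axis=1)
    return float(np.mean(pred_labels == true_labels))

# Illustrative 3-class example with 4 samples (not real MNIST output).
y_true = np.eye(3)[[0, 1, 2, 1]]
y_prob = np.array([
    [0.8, 0.1, 0.1],   # predicts 0, correct
    [0.2, 0.7, 0.1],   # predicts 1, correct
    [0.5, 0.3, 0.2],   # predicts 0, wrong (true class is 2)
    [0.1, 0.6, 0.3],   # predicts 1, correct
])

print(accuracy_from_probs(y_true, y_prob))  # 0.75
```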
# Step 9: Data Visualization
# Plot the training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Model Accuracy')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Model Loss')
plt.show()
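`history.history` is a plain dictionary that maps each metric name ('accuracy', 'val_accuracy', 'loss', 'val_loss') to a list with one value per epoch, which is why the plotting code can pass its entries straight to `plt.plot`. A small sketch for picking the epoch with the lowest validation loss; the helper name is hypothetical and the numbers are invented placeholders, not results from this notebook:

```python
def best_epoch(history_dict, key="val_loss"):
    """Return (1-based epoch, value) for the lowest entry under `key`."""
    values = history_dict[key]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Placeholder values shaped like history.history (invented for illustration).
fake_history = {
    "loss":     [0.30, 0.14, 0.10, 0.07],
    "val_loss": [0.16, 0.12, 0.11, 0.12],
}

epoch, val = best_epoch(fake_history)
print(epoch, val)  # 3 0.11
```

In the notebook, `best_epoch(history.history)` would apply the same logic to the real run. A validation loss that starts rising while the training loss keeps falling is the usual sign of overfitting.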
# Visualize some predictions
predictions = model.predict(x_test)
plt.figure(figsize=(10, 10))
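# Show the first 25 test images with their true and predicted digits
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_test[i], cmap=plt.cm.binary)
    plt.xlabel(f"True: {np.argmax(y_test[i])}, Pred: {np.argmax(predictions[i])}")
plt.tight_layout()  # keep the per-image labels from overlapping
plt.show()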
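The prediction grid shows the first 25 test images regardless of whether they were classified correctly. To inspect errors instead, the indices of misclassified samples can be collected first. A NumPy sketch on small illustrative label lists (the helper name is hypothetical; in the notebook, `np.argmax(predictions, axis=1)` and `np.argmax(y_test, axis=1)` would supply the real labels):

```python
import numpy as np

def misclassified_indices(true_labels, pred_labels):
    """Indices where the predicted digit differs from the true digit."""
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    return np.flatnonzero(pred_labels != true_labels)

# Illustrative labels only; not actual predictions from the model.
true_labels = [7, 2, 1, 0, 4, 1, 4, 9]
pred_labels = [7, 2, 1, 0, 9, 1, 4, 4]

wrong = misclassified_indices(true_labels, pred_labels)
print(wrong)  # [4 7]
```

To plot only mistakes, the grid loop could iterate over `enumerate(wrong[:25])`, using the position for the subplot index and the value to index `x_test`, so the subplot numbers stay within 1 to 25.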
Explanation of the Notebook
Running the Notebook
Output of this notebook: