Multi-task Supervised and Unsupervised learning for generalization improvement (Code Included)
Ibrahim Sobh - PhD
Intuitively, learning how to draw shapes (reconstruction) actually helps with learning how to distinguish between them (classification).
Generalization is a central concept in machine learning: learning functions from a finite set of data that perform well on new data.
Understanding generalization performance is particularly critical for powerful function classes, such as neural networks, which have well-known overfitting issues. Common strategies to reduce overfitting include dropout, early stopping, and data augmentation.
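For illustration, here is a minimal, hedged sketch of two of these strategies in Keras (the layer sizes, dropout rate, and patience value are illustrative only and are not part of the experiments below):

from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras.callbacks import EarlyStopping

# A small classifier with dropout between the hidden layers
inp = Input(shape=(784,))
h = Dense(256, activation='relu')(inp)
h = Dropout(0.5)(h)  # randomly drop 50% of the activations during training
out = Dense(10, activation='softmax')(h)
clf = Model(inp, out)
clf.compile(optimizer='SGD', loss='categorical_crossentropy', metrics=['acc'])

# Early stopping: halt training once the validation loss stops improving
early_stop = EarlyStopping(monitor='val_loss', patience=10)
# clf.fit(x_train, y_train, validation_data=(x_test, y_test), callbacks=[early_stop])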
Multi-task learning has been shown to improve generalization performance.
Supervised Auto-Encoder:
A Supervised Auto-Encoder (SAE) is a neural network that predicts both its inputs (reconstruction) and its outputs (targets).
It was shown that adding a reconstruction loss never harms performance compared to the corresponding neural network model, and in some cases can significantly improve classification accuracy.
Representation learning:
To perform well in prediction, a common goal is representation learning, where the inputs are transformed into a new representation for which it is easier to learn a simple predictor. Auto-encoders (AEs) are usually used to extract such a representation.
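For reference, a minimal sketch of a plain (unsupervised) auto-encoder in Keras, trained on reconstruction only, could look like this (the layer sizes are illustrative and roughly mirror the models used later):

from keras.layers import Input, Dense
from keras.models import Model

# Plain auto-encoder: compress the input to a small code Z, then reconstruct it
ae_input = Input(shape=(784,))
z = Dense(32, activation='relu')(ae_input)            # encoder: input -> Z
reconstruction = Dense(784, activation='sigmoid')(z)  # decoder: Z -> input

autoencoder = Model(ae_input, reconstruction)
autoencoder.compile(optimizer='SGD', loss='binary_crossentropy')
# autoencoder.fit(x_train, x_train, ...)  # the input is also the target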
Conversely, training a representation solely according to the supervised task, such as learning the hidden layers of a neural network, is likely an under-constrained problem; it will find solutions that fit the data well but do not capture the underlying patterns in the data and do not generalize well.
The combination of the two losses promises to balance extracting underlying structure with providing accurate prediction performance.
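In other words, the SAE minimizes a weighted sum of the two losses; a minimal sketch of the idea (the weight lam is a hypothetical hyper-parameter, set through loss_weights in the Keras code below):

def combined_loss(reconstruction_loss, classification_loss, lam=0.1):
    # Total SAE objective: reconstruction term plus a weighted supervised term
    return reconstruction_loss + lam * classification_loss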
In the work below, supervised auto-encoders (SAEs) are used to add an unsupervised auxiliary task (reconstruction) that improves generalization performance.
This article is inspired by the work in the paper: Supervised autoencoders: Improving generalization performance with unsupervised regularizers
Experiments
Two networks are trained and tested:
- NN: Neural Network for classification (one loss: classification) (left figure)
- SAE: Supervised Auto-Encoder for classification (two losses: classification and reconstruction) (right figure)
Both networks are trained on MNIST to classify handwritten digits.
As shown in the figure below, the NN overfits: there is a large gap between the train loss and the test loss. In other words, the network is not expected to perform well on new data.
On the other hand, the SAE, as shown below, performs much better, mainly because the added reconstruction loss enhances generalization.
Here is a single figure showing both the SAE and the NN:
As shown in the figure above, the SAE does not overfit compared to the NN. Moreover, the SAE has a lower loss on the test set.
SAE improves generalization performance over neural network (NN)
Code:
Imports
from keras.layers import Input, Dense
from keras.models import Model
from keras import utils
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
Data
# Settings
encoding_dim = 32
num_classes = 10

# Data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:2000]
y_train = y_train[:2000]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

# One-hot encode the labels for the categorical cross-entropy loss
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)
SAE Network with two outputs: Classification and Reconstruction
# The SAE: multi-task model
sae_input_img = Input(shape=(784,), name='input')

# Encoder: input to Z
encoded = Dense(256, activation='relu', name='encode_1')(sae_input_img)
encoded = Dense(128, activation='relu', name='encode_2')(encoded)
encoded = Dense(encoding_dim, activation='relu', name='z')(encoded)

# Classification head: Z to class
predicted = Dense(num_classes, activation='softmax', name='class_output')(encoded)

# Reconstruction decoder: Z back to the input
decoded = Dense(128, activation='relu', name='decode_1')(encoded)
decoded = Dense(256, activation='relu', name='decode_2')(decoded)
decoded = Dense(784, activation='sigmoid', name='reconst_output')(decoded)

# Take the input and produce both the reconstruction and the classification
supervisedautoencoder = Model(inputs=[sae_input_img], outputs=[decoded, predicted])
supervisedautoencoder.compile(optimizer='SGD',
                              loss={'class_output': 'categorical_crossentropy',
                                    'reconst_output': 'binary_crossentropy'},
                              loss_weights={'class_output': 0.1, 'reconst_output': 1.0},
                              metrics=['acc'])
supervisedautoencoder.summary()
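As a side note, once the SAE is built, the learned representation can be pulled out by layer name; a short sketch assuming the layer names defined above:

# Extract the encoder part (input -> Z) to reuse the learned representation
encoder = Model(supervisedautoencoder.input, supervisedautoencoder.get_layer('z').output)
# z_codes = encoder.predict(x_test)  # 32-dimensional codes for the test images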
NN Network with classification output only
# Single-task classification model
sc_input_img = Input(shape=(784,), name='input')

# Encoder: input to Z
encoded = Dense(256, activation='relu', name='nn_1')(sc_input_img)
encoded = Dense(128, activation='relu', name='nn_2')(encoded)
encoded = Dense(encoding_dim, activation='relu', name='z')(encoded)

# Classification: Z (softmax) to class
predicted = Dense(num_classes, activation='softmax', name='class_output')(encoded)

# Take the input and produce the classification only
supervisedclassifier = Model(sc_input_img, predicted)
supervisedclassifier.compile(optimizer='SGD',
                             loss='categorical_crossentropy',
                             metrics=['acc'])
supervisedclassifier.summary()
Train both networks: SAE and NN
# Multi-task training: SAE
SAE_history = supervisedautoencoder.fit(x_train,
                                        {'reconst_output': x_train, 'class_output': y_train},
                                        epochs=350, batch_size=32, shuffle=True, verbose=0,
                                        validation_data=(x_test,
                                                         {'reconst_output': x_test, 'class_output': y_test}))

# Single-task training: NN
SC_history = supervisedclassifier.fit(x_train, y_train,
                                      epochs=350, batch_size=32, shuffle=True, verbose=0,
                                      validation_data=(x_test, y_test))
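To compare the final test performance directly (in addition to the loss curves below), both trained models can also be evaluated; a short sketch assuming the objects defined above:

# Evaluate on the test set; the returned values follow the metrics_names order
print(supervisedclassifier.metrics_names)
print(supervisedclassifier.evaluate(x_test, y_test, verbose=0))

print(supervisedautoencoder.metrics_names)
print(supervisedautoencoder.evaluate(x_test, {'reconst_output': x_test, 'class_output': y_test}, verbose=0))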
Plot
plt.figure(figsize=(10, 7))
plt.plot(SAE_history.history['class_output_loss'], label='Train SAE')
plt.plot(SAE_history.history['val_class_output_loss'], label='Test SAE')
plt.plot(SC_history.history['loss'], label='Train NN')
plt.plot(SC_history.history['val_loss'], label='Test NN')
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()
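A similar plot can be made for classification accuracy; a hedged sketch (note that the exact history keys, e.g. 'class_output_acc' vs 'class_output_accuracy', depend on the Keras version):

plt.figure(figsize=(10, 7))
plt.plot(SAE_history.history['class_output_acc'], label='Train SAE')
plt.plot(SAE_history.history['val_class_output_acc'], label='Test SAE')
plt.plot(SC_history.history['acc'], label='Train NN')
plt.plot(SC_history.history['val_acc'], label='Test NN')
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()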
Auxiliary unsupervised reconstruction learning was shown to improve generalization and the performance of the supervised classification task.
Regards