Effectively tackling the multiclass problem: Siamese Models
In a traditional classification problem, we typically train a neural network that takes an image, text, or any feature vector as input and outputs a probability for each class (usually via a softmax). If we want to know whether an input belongs to class A, B, or C, we train the model on good quantities of A, B, and C data, and the network outputs 3 probabilities, one per class. If we then feed in an example of A, the network is supposed to output a high probability for class A and low probabilities for classes B and C.
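To make the setup concrete, here is a minimal sketch of such a 3-class classifier in PyTorch. The input size and hidden layer are hypothetical choices; any encoder would do:

```python
import torch
import torch.nn as nn

# A small MLP standing in for any classifier (input size 16 is arbitrary).
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 3),  # one logit per class: A, B, C
)

x = torch.randn(1, 16)                  # a single input example
probs = torch.softmax(model(x), dim=1)  # 3 probabilities that sum to 1
```

During training, the class with the highest probability is compared against the true label, and the softmax is recomputed over all classes for every example.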
This traditional approach works well in many problems but struggles in the following scenarios:
- Imagine a classification problem with hundreds or thousands of different classes. Using a softmax activation in the last layer is computationally expensive and increases training time. For every training example, softmax carries out the calculation shown in the image below, and since there are 1000 classes, it is repeated 1000 times.
If we are training on 100,000 data points, that is 100,000,000 probability calculations per pass over the data (I'm aware of broadcasting in vector calculations; I'm just expanding the calculations to convey the point).
- Suppose we want to add another class: the network has to be re-trained, which is time-consuming.
- Across 1000 classes, the data is rarely evenly distributed, so minority classes may be ignored. It can also happen that we have only 5 to 50 training examples per class; in that case, learning to distinguish the classes is a problem for the model (lack of data).
One-Shot Learning Intuition
The FaceNet paper of 2015 proposed an interesting solution to these problems. Instead of the traditional approach, we learn a similarity function, i.e. the degree of difference between 2 inputs. If the degree of difference between the inputs is less than a threshold, the inputs are classified as similar; otherwise, as different.
The network takes 2 inputs; both pass through the same network, and we get an embedding for each. The loss function used in one-shot learning is based on distance (Euclidean/Manhattan): we compare the output embeddings of the two inputs, and if they are close to each other, i.e. within a threshold, we mark them as similar, else different. This type of network is also called a Siamese network because the two inputs pass through the same network.
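The forward pass described above can be sketched as follows. The encoder sizes and the threshold value are illustrative assumptions; in practice the threshold is tuned on validation data:

```python
import torch
import torch.nn as nn

# Shared encoder: the *same* weights embed both inputs (sizes are hypothetical).
encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

x1, x2 = torch.randn(1, 16), torch.randn(1, 16)
e1, e2 = encoder(x1), encoder(x2)  # both inputs pass through the same network

# Euclidean distance between the two embeddings
distance = torch.pairwise_distance(e1, e2)
threshold = 0.5                    # assumed value; tune on held-out pairs
same = distance < threshold       # True -> "similar", False -> "different"
```

Adding a new class needs no re-training: we just compare a new input's embedding against one stored reference embedding per class.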
Loss Functions for Siamese Network
To implement the Siamese network, we need a distance-based loss function. There are 2 widely used loss functions:
- Contrastive Loss
- Triplet Loss
Contrastive Loss Intuition:
For contrastive loss, the dataset will have input1, input2, and y. The value of y will be 0 if both inputs belong to the same class and 1 otherwise.
If the 2 inputs belong to the same class, the left side of the equation activates and the loss is the squared distance Dw between the embeddings, which we want to minimize.
If the two inputs belong to different classes, the right side of the equation activates: the inputs should be significantly distant from each other, i.e. embeddings of different classes should be at least m apart. If that is already the case, the term is 0; otherwise it is positive, and minimizing it pushes Dw out toward the margin m.
Triplet Loss Intuition:
For Triplet loss, the dataset structure will be as described below:
- Input 1 - Base or Anchor data point
- Input 2 - Data point belonging to the same class as the anchor (the positive)
- Input 3 - Data point belonging to a different class with respect to the anchor (the negative)
Instead of 2 inputs, our network takes 3. Using the 3 output embeddings and the equation above, we try to reduce the distance between the anchor and the positive and increase the distance between the anchor and the negative, such that the difference of the two squared distances exceeds the margin alpha. Below is an illustration of the same:
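In code, triplet loss is a one-liner over the three embeddings (the alpha value here is an assumed hyperparameter; PyTorch also ships a built-in `nn.TripletMarginLoss`):

```python
import torch

def triplet_loss(anchor, positive, negative, alpha=0.2):
    d_pos = (anchor - positive).pow(2).sum(dim=1)  # squared distance anchor <-> positive
    d_neg = (anchor - negative).pow(2).sum(dim=1)  # squared distance anchor <-> negative
    # Zero loss once the negative is at least `alpha` farther than the positive.
    return torch.clamp(d_pos - d_neg + alpha, min=0).mean()
```

All three inputs are embedded by the same shared encoder, just as in the two-input case.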
I have created a notebook to get you started with Siamese networks; it will help you define the network architecture. I hope it proves to be a good starting point.
Github Link - https://github.com/amit-raj-repo/Siamese-Network
The concepts covered above are explained in detail in the notes below. Hope this helps you on your next DL journey.