Effectively tackling the multiclass problem: Siamese Models

In a traditional classification problem, we typically train a neural network that takes an image, text, or any other vector as input and outputs a probability (usually via a softmax) for each class. If we want to know whether the input belongs to class A, B, or C, then we train the model on good quantities of A, B, and C data, and our network outputs 3 probabilities, one for each class. Then, if we feed in data for A, the network is supposed to output a high probability for class A and low probabilities for classes B and C.
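
For instance, a 3-class version of this setup looks roughly like the following sketch (PyTorch; the 128-d input and layer sizes are illustrative):

import torch
import torch.nn as nn

# Traditional classifier head: one output probability per class (A, B, C)
classifier = nn.Sequential(
    nn.Linear(128, 64),   # 128-d input vector (image/text features, etc.)
    nn.ReLU(),
    nn.Linear(64, 3),     # one logit per class
)

x = torch.randn(1, 128)                      # a single input example
probs = torch.softmax(classifier(x), dim=1)  # 3 probabilities that sum to 1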

This traditional approach works well in many problems but struggles in the following scenarios:

  • Imagine a classification problem with hundreds or even 1000 different classes. Using a softmax activation in the last layer will be computationally expensive and will increase training time: for every training example, the softmax carries out the calculation shown below, and since there are 1000 classes, it is repeated 1000 times.
softmax(z_i) = exp(z_i) / Σ_k exp(z_k), where the sum runs over all 1000 classes

If we are training on 100,000 data points, that means 100,000,000 probability calculations (I'm aware of broadcasting in vectorized calculations; I'm just expanding the arithmetic to convey the point).

  • Suppose we want to add another class: re-training the whole network will be time-consuming.
  • Across 1000 classes we might not have an equal distribution of data, and minority classes can end up being ignored. It can also happen that we have only 5 to 50 training examples per class, in which case the model will struggle to learn the distinctions between classes (lack of data).

One-Shot Learning Intuition

The FaceNet paper (2015) proposed an interesting solution to these types of problems. Instead of the traditional approach, we learn a similarity function, i.e. a degree of difference between 2 inputs. If the degree of difference between the inputs is less than a threshold, the inputs are classified as similar; otherwise, as different.

The network takes 2 inputs; both pass through the same network, and we get embeddings at the end. The loss function used in one-shot learning is distance-based (Euclidean/Manhattan): we compare the output embeddings of the two inputs, and if they are close to each other, i.e. within a threshold, we mark them as similar, else different. This type of network is also called a Siamese network because the two inputs pass through the same network. A minimal sketch is shown below.
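
Here is a minimal PyTorch sketch of the idea; the encoder architecture is illustrative, not the one from the notebook linked at the end:

import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self, embedding_dim=32):
        super().__init__()
        # One shared encoder: both inputs pass through the SAME weights
        self.encoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embedding_dim),
        )

    def forward(self, x1, x2):
        # Apply the same encoder to both inputs -> one embedding each
        return self.encoder(x1), self.encoder(x2)

model = SiameseNetwork()
e1, e2 = model(torch.randn(1, 128), torch.randn(1, 128))
distance = torch.pairwise_distance(e1, e2)  # Euclidean distance between embeddings
similar = distance < 0.5                    # the threshold is a tunable hyperparameter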


Loss Functions for Siamese Networks

To implement the Siamese network, we need a distance-based loss function. There are 2 widely used loss functions:

  • Contrastive Loss
  • Triplet Loss

Contrastive Loss Intuition:

L = (1 − y) · Dw² + y · max(0, m − Dw)², where Dw is the distance between the two output embeddings and m is the margin

For contrastive loss, each training example consists of input1, input2, and a label y. The value of y is 0 if both inputs belong to the same class and 1 otherwise.

If the 2 inputs belong to the same class, the left side of the equation activates and the loss is the squared distance between the embeddings, which we want to minimize.

If the two inputs belong to different classes, the right side of the equation activates. Here the inputs should be significantly distant from each other, i.e. embeddings of different classes should be at least m apart. If that is the case, the term returns 0; otherwise it returns a positive value, and minimizing it pushes Dw up towards the margin m. A minimal implementation follows below.
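
Putting the two terms together, here is a hedged PyTorch sketch of contrastive loss using the convention above (y = 0 for same-class pairs, y = 1 otherwise); the default margin of 1.0 is purely illustrative:

import torch

def contrastive_loss(e1, e2, y, margin=1.0):
    # e1, e2: output embeddings of the two inputs; y: tensor of 0s and 1s
    dw = torch.pairwise_distance(e1, e2)                    # Dw in the equation
    same_term = (1 - y) * dw.pow(2)                         # pulls similar pairs together
    diff_term = y * torch.clamp(margin - dw, min=0).pow(2)  # pushes dissimilar pairs at least `margin` apart
    return (same_term + diff_term).mean()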

Triplet Loss Intuition:

L = max(‖f(A) − f(P)‖² − ‖f(A) − f(N)‖² + α, 0), where f is the embedding network, A, P, and N are the anchor, positive, and negative inputs, and α (alpha) is the margin

For Triplet loss, the dataset structure will be as described below:

  • Input 1 - The base or anchor data point
  • Input 2 - A data point belonging to the same class as the anchor (the positive)
  • Input 3 - A data point belonging to a different class from the anchor (the negative)

Instead of 2 inputs, our network will have 3 inputs. Using the 3 output embeddings and the equation above, we try to reduce the distance between the anchor and the positive and increase the distance between the anchor and the negative, until the anchor-negative squared distance exceeds the anchor-positive squared distance by at least the margin alpha. A minimal sketch of the loss follows below.
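
A hedged PyTorch sketch of this loss, with the margin alpha defaulting to 0.2 purely for illustration:

import torch

def triplet_loss(anchor, positive, negative, alpha=0.2):
    # Squared Euclidean distances, matching the equation above
    d_pos = (anchor - positive).pow(2).sum(dim=1)  # anchor <-> positive
    d_neg = (anchor - negative).pow(2).sum(dim=1)  # anchor <-> negative
    # The loss is zero once d_neg exceeds d_pos by at least alpha
    return torch.clamp(d_pos - d_neg + alpha, min=0).mean()

PyTorch also ships a built-in torch.nn.TripletMarginLoss if you prefer not to hand-roll this (note that it uses non-squared distances by default).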



I have created a notebook to get you started with Siamese networks; it will help you define the network architecture. I hope it proves to be a good starting point.

Github Link - https://github.com/amit-raj-repo/Siamese-Network

The concepts covered above are explained in detail in the notes below. Hope this helps you on your next DL journey.

[Images: handwritten notes explaining the concepts above in detail]

