Multi-Task Learning (MTL) with Multi-Layer Perceptron (MLP) and Deep Learning Techniques
Image by Peace, Love, Happiness from Pixabay


Multi-task learning is a method in machine learning where multiple related tasks are learned simultaneously, leveraging shared information among them to improve performance. Instead of training a separate model for each task, MTL trains a single model to handle multiple tasks. In other words, the model learns different tasks within the same network: for one input record/vector (independent variables), it produces multiple outputs (targets, or dependent variables).

We'll start by exploring the concepts behind MTL, its benefits, and its drawbacks. Then, we'll look at how it works, using architecture, code, and visual workflow examples. I have taken a Kaggle dataset to help illustrate the concept, and we will explore everything through this example. By the end of this article, you'll have a solid understanding of MTL and how it can be applied in real-world scenarios.

Concepts Behind Multi-Task Learning (MTL):

In MTL, some layers or parameters are shared across tasks, allowing the model to learn common features that benefit all tasks. The model is trained on different tasks simultaneously, and the parameters are updated based on the combined loss from all tasks.

In addition to shared layers, MTL models typically have task-specific layers that handle the unique aspects of each task. The final output layer of the model provides the desired output for each task.

So, what are the advantages and disadvantages of MTL?

On the plus side, MTL can improve the performance of individual tasks when they are related. It can also act as a regularizer, preventing the model from overfitting on a single task. Additionally, MTL can be seen as a form of transfer learning.

However, there are some drawbacks to consider. For example, conflicting gradients from different tasks can affect the learning process, making it challenging to balance the learning across tasks. Furthermore, as the number of tasks increases, the complexity and computational cost of MTL can grow significantly.
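A common way to manage this balancing problem is to weight each task's loss before summing them. The snippet below is only a minimal sketch of that idea with made-up numbers; the weights w_thal and w_heart are hypothetical hyperparameters, and the example later in this article simply adds the two losses without weighting.

# Minimal sketch: weighting each task's loss before combining (illustration only).
# w_thal and w_heart are hypothetical hyperparameters you would tune per use case.
import torch

w_thal, w_heart = 1.0, 0.5

loss_thal = torch.tensor(0.9)    # placeholder loss values, just for the arithmetic
loss_heart = torch.tensor(0.4)

# Down-weighting one task can reduce how much its gradients dominate the shared layers.
total_loss = w_thal * loss_thal + w_heart * loss_heart
print(total_loss)  # tensor(1.1000)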


Architecture, Code, and Visual Workflow

We are going to explore Multi-Task Learning through a real-world use case. I take a Kaggle dataset (Heart Disease Dataset) to predict two targets. It has 12 independent variables or features, such as age, sex, chest pain type, and resting blood pressure. The two target variables (dependent variables) are thal (thalassemia) and heart disease.

Two Tasks:

Task 1: Predicting Thalassemia (Multi-Class Classification). The first task is to predict the type of thalassemia a patient has, if any. This is a multi-class classification problem, where we need to predict one of three outcomes: reversed thalassemia, fixed thalassemia, or normal (no thalassemia).

Task 2: Predicting Heart Disease (Binary Classification). The second task is to determine whether a patient has heart disease or not. This is a binary classification problem, where we need to predict one of two outcomes: yes or no.


Multi-Task Learning with MLP


Image 1: MTL Neural Network for Two Tasks. Image by the author


Let's take a closer look at the neural network architecture we're using for our Multi-Task Learning (MTL) tasks. As shown in Image 1, our model has two hidden layers that act as a shared representation, learning jointly for both tasks. Each task then has its own separate hidden layer. The output layers are determined by the target of each task, with one layer for binary classification (heart disease) and another for multi-class classification (thalassemia).


Multi-Task Learning Code

Now, let's take a look at the code that brings this architecture to life. The code snippet below replicates the architecture we saw in Image 1. If you're interested in exploring further, I've also included a reference to my Kaggle notebook where you can see the code in action [Reference section]. Please review the code thoroughly to gain a complete understanding.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class MultiTaskNet(nn.Module):
    def __init__(self):
        super(MultiTaskNet, self).__init__()
        
        # Two shared hidden layers (parameters here learn the general nature of the input and its relationship with both targets)
        self.shared_fc1 = nn.Linear(12, 32) 
        self.shared_fc2 = nn.Linear(32, 64)
        
        # Task-specific layers for thalassemia prediction (Task 1)
        self.thal_fc1 = nn.Linear(64, 32)
        self.thal_fc2 = nn.Linear(32, 3)  # 3 classes for thalassemia
        
        # Task-specific layers for heart disease prediction (Task 2)
        self.heart_fc1 = nn.Linear(64, 16)
        self.heart_fc2 = nn.Linear(16, 1)  # 1 output for heart disease
    
    def forward(self, x):
        # Shared representation learned jointly for both tasks
        x = F.relu(self.shared_fc1(x))
        x = F.relu(self.shared_fc2(x))
        
        thal_out = F.relu(self.thal_fc1(x))
        thal_out = self.thal_fc2(thal_out)  # Task 1: thalassemia logits (CrossEntropyLoss applies softmax)
        
        heart_out = F.relu(self.heart_fc1(x))
        heart_out = torch.sigmoid(self.heart_fc2(heart_out))  # Task 2: heart disease probability (sigmoid for BCELoss)
        
        return thal_out, heart_out

model = MultiTaskNet()
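
Before training, a quick sanity check (my own addition, not part of the Kaggle notebook) confirms that the two heads return the expected shapes:

# Sanity check (my addition): pass a random batch of 4 records with 12 features.
dummy_input = torch.randn(4, 12)
thal_logits, heart_probs = model(dummy_input)
print(thal_logits.shape)   # torch.Size([4, 3]) -> 3 thalassemia classes
print(heart_probs.shape)   # torch.Size([4, 1]) -> heart disease probability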

-----------------------------------------------------------------------------------------------

# Cost functions
criterion_thal = nn.CrossEntropyLoss()  # Multi-class loss (applies softmax internally)
criterion_heart = nn.BCELoss()          # Binary loss (expects sigmoid probabilities)

# Optimizer (one optimizer updates shared and task-specific parameters together)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss_thal = 0.0
    running_loss_heart = 0.0
    
    for inputs, labels_thal, labels_heart in train_loader:
        optimizer.zero_grad()  # Reset accumulated gradients to zero before each batch
        
        outputs_thal, outputs_heart = model(inputs)
        
        loss_thal = criterion_thal(outputs_thal, labels_thal)
        loss_heart = criterion_heart(outputs_heart.squeeze(), labels_heart)
        
        loss = loss_thal + loss_heart  # Combined loss from both tasks
        loss.backward()   # Backpropagation: compute gradients of the combined loss
        optimizer.step()  # Update all parameters using those gradients
        
        running_loss_thal += loss_thal.item()
        running_loss_heart += loss_heart.item()
        
    if epoch%10==0:
        print(f'Epoch {epoch+1}/{num_epochs}, Loss Thal: {running_loss_thal/len(train_loader)}, Loss Heart: {running_loss_heart/len(train_loader)}')        
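
One thing the loop above assumes is a train_loader that yields (inputs, labels_thal, labels_heart) batches. Below is a minimal sketch of how such a loader could be built; X, y_thal, and y_heart are hypothetical placeholder arrays standing in for the prepared dataset, not variables from the notebook.

# Minimal sketch of the assumed train_loader (hypothetical placeholder data).
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

X = np.random.rand(303, 12).astype("float32")                   # e.g. 303 records x 12 features
y_thal = np.random.randint(0, 3, size=303)                      # class indices 0..2 for CrossEntropyLoss
y_heart = np.random.randint(0, 2, size=303).astype("float32")   # 0.0 / 1.0 targets for BCELoss

train_dataset = TensorDataset(torch.tensor(X),
                              torch.tensor(y_thal, dtype=torch.long),
                              torch.tensor(y_heart))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)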


Multi-Task Learning Visual Workflow

Before we dive into the workflow of our Multi-Task Learning project, I want to clarify an important point about how neural networks operate. You may have learned about Multi-Layer Perceptrons (MLPs) or Artificial Neural Networks (ANNs) from a neuron-centric perspective, where each neuron performs a series of operations on the input data, such as multiplying it by weights and adding bias.

However, in my articles, I've presented this operation in a different way, one that I believe is more accurate and intuitive. The truth is, a neuron's operation can be thought of as a simple matrix multiplication of the input vectors and weights, followed by the addition of a bias vector. This perspective can help simplify the complex workings of neural networks and make them easier to understand. If you are not familiar with this view, that's okay; this article will give you an in-depth understanding of how a neural network really works.

Important note to understand the linear layer (matrix multiplication): the weight matrix has shape (current layer's number of neurons, previous or input layer's number of neurons). We initialize the linear layer with an input size and an output size, e.g. nn.Linear(input vector size (12), output vector size (32)). For example, if we have 12 input units (input vector size) in the MLP/FFN and 32 neurons in the hidden layer, the weight matrix size will be W(32, 12). This matrix (transposed) is multiplied with the input vectors, giving a new vector, and then the bias vector is added to the linearly transformed vector. This is the linear transformation of vectors: a vector is transformed from one dimension (12) to another dimension (32), or to the same dimension.

Image 2: Linear Layer operation. Source: pytorch.org

Also, I want to include one more piece of evidence for why it is a matrix multiplication. From Image 2 you can understand the linear layer operation, y = xAᵀ + b, where x is the input vector (12 independent variables in our case), A is the weight matrix, and b is the bias vector. For simplicity, I didn't use the bias vector in this workflow.
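
To make this concrete, here is a small check I've added (not from the original notebook), using the shapes of our first shared layer:

# Illustrative check: nn.Linear stores its weight as W(output size, input size).
layer = nn.Linear(12, 32)           # input vector size 12, output vector size 32
print(layer.weight.shape)           # torch.Size([32, 12])

x = torch.randn(1, 12)                        # one record with 12 features
manual = x @ layer.weight.T + layer.bias      # input times transposed weight matrix, plus bias
print(torch.allclose(manual, layer(x)))       # True: identical to the layer's own output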

Note: The numbers and calculations in the images are for illustration purposes only; they will help you understand the workflow.


Image 3: 1st Hidden layer operation. Image by author

I have taken 32 as the batch size, and we have 12 independent variables. As shown in Image 3, the input data A, with shape (32, 12), is multiplied by the 1st hidden layer weight matrix W1 (transposed W1), which has shape (32, 12), resulting in the 1st hidden layer output O1 with shape (32, 32). The 1st hidden layer has 32 neurons, which means each vector now has 32 features (from 12 to 32). ReLU activation is also applied.

"If you Look Closely the neuron operation and this matrix multiplication is the same."


Image 4: 2nd Hidden Layer Output. Image by author

The 1st hidden layer output O1 (32, 32) is then multiplied by the 2nd hidden layer weight matrix W2 (transposed W2), which has shape (64, 32), resulting in the 2nd hidden layer output O2 with shape (32, 64) (ReLU activation applied). The 2nd layer has 64 neurons, so each of the 32 vectors now has 64 features.
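
The shape changes from Images 3 and 4 can be traced with a few lines of code (my own illustration, using random data in place of the real records):

# Tracing the shared-layer shapes from Images 3 and 4 (random data, illustration only).
A = torch.randn(32, 12)              # batch of 32 records, 12 features each
O1 = F.relu(model.shared_fc1(A))     # (32, 12) x W1(32, 12)^T -> (32, 32)
O2 = F.relu(model.shared_fc2(O1))    # (32, 32) x W2(64, 32)^T -> (32, 64)
print(O1.shape, O2.shape)            # torch.Size([32, 32]) torch.Size([32, 64])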

So far we have seen the shared hidden layers. Now we will look at the task-specific hidden layers.

Image 5: 1st Task specific hidden layer.

We have seen in the architecture (Image 1) and the code that the thalassemia prediction task-specific hidden layer has 32 neurons (width), and the heart disease prediction task-specific hidden layer has 16 neurons. So the output O2 is multiplied with two weight matrices here.

O2 is multiplied by the thalassemia prediction task's hidden layer weight matrix W31 (transposed W31), which has shape (32, 64), resulting in output O31 with shape (32, 32). The same O2 is multiplied by the heart disease prediction task's hidden layer weight matrix W32 (transposed W32), which has shape (16, 64), resulting in output O32 with shape (32, 16), as shown in Image 5.
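
Continuing the same illustration, the two task-specific hidden layers applied to O2 give:

# Task-specific hidden layers applied to the shared output O2 (illustration only).
O31 = F.relu(model.thal_fc1(O2))     # (32, 64) x W31(32, 64)^T -> (32, 32)
O32 = F.relu(model.heart_fc1(O2))    # (32, 64) x W32(16, 64)^T -> (32, 16)
print(O31.shape, O32.shape)          # torch.Size([32, 32]) torch.Size([32, 16])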


First, we will look into the heart disease (Task 2) output layer. Here we have only one neuron in the output layer. The output O32 is multiplied by the heart disease task output layer weight matrix W42 (transposed W42), which has shape (1, 16). This results in output logits with shape (32, 1). Since we have taken 32 as the batch size, we get a logit score for each of these 32 records in the final layer, as shown in Image 6.

Image 6: Heart Disease output layer. Image by author

The output logits for the 32 records are passed through the sigmoid activation, which converts the logits into probability scores in [0, 1]. These probabilities are then rounded, or converted into outputs using a threshold value, as shown in Image 7.


Image 7: Heart Disease Output. Image by author
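
In code, this probability-to-label step looks roughly like this (continuing the same illustrative tensors):

# Heart disease head: logits -> probabilities -> 0/1 labels (illustration only).
heart_logits = model.heart_fc2(O32)          # (32, 1) raw logits
heart_probs = torch.sigmoid(heart_logits)    # probabilities in [0, 1]
heart_preds = (heart_probs > 0.5).long()     # threshold at 0.5 gives the final 0/1 output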

For example, the first 3 patients have heart disease and the 4th patient doesn't, as shown in Image 7 (for illustration purposes only). Deep learning is simple, but to understand it you have to know the basics, like optimizers, activation functions, gradients, cost functions, backward propagation, and so on. As you know, we have only explored the working mechanism here, not the fundamental concepts. Now let's look at Task 1 (multi-class classification).

The output O31 is multiplied by the Task 1 output layer weight matrix W41 (transposed W41), which has shape (3, 32), resulting in a multi-class output logits matrix with shape (32, 3).

Image 8: Thalassemia output layer. Image by author

I didn't show the logits matrix here. Softmax activation is then applied to these logits (row-wise), giving a probability score for each of the 3 classes for each record. From these probability scores, we can get the output: whichever class has the highest probability score will be the predicted output.
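
The same step for the thalassemia head, again just as an illustration:

# Thalassemia head: logits -> row-wise softmax -> predicted class (illustration only).
thal_logits = model.thal_fc2(O31)             # (32, 3) logits
thal_probs = F.softmax(thal_logits, dim=1)    # softmax across the 3 classes of each row
thal_preds = thal_probs.argmax(dim=1)         # class with the highest probability per record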

From this, you can understand how the cost function and loss calculation work. We have seen many weight matrices; they are all parameters, and they get updated during training.

"By Understanding these DeepLearning Concepts very well we can use it based on our use case and requirements." To illustrate this point, I want to give you One article link, Where I have explained how we can use Image features and Text features to achieve our use case.

Link: From Pixels to Words: How Model Understands?


I hope this article has helped you gain a better understanding of one of the fundamental learning methods in deep learning: Multi-Task Learning. If you found this article informative and useful, follow JAIGANESAN N for more articles on deep learning and AI. Additionally, if you have some spare time, I'd love for you to check out my Medium articles, where I share more insights and knowledge on Advanced GenAI, Deep Learning, and AI Concepts. You can find my Medium profile in the comments below. Thank you for reading, and I look forward to sharing more knowledge with you in the future!


References:

  1. Multi-Task Learning Kaggle Implementation Notebook
  2. Heart Disease Dataset
  3. Neural Network/ Multi-Layer Perceptron (MLP) Working (My article)


