Implementing L2SP Regularization in an Open Source Project
L2-SP Regularization Summarized; Created by: Dinesh Palli

I've been spending the past two weekends working on an open-source project at Charité, and the day before yesterday marked a milestone: the branch I was contributing to was merged into the main codebase! This was the first time I directly adjusted the loss function in a student-teacher model.

While I can't disclose all the details since the project isn't publicly available yet, the main goal, as I understood it, was to implement elastic net regularization. Initially, I confused L2-SP regularization with elastic net regularization, so let me clarify the difference.

Elastic net combines two types of regression penalties: L1 (lasso) and L2 (ridge). It strikes a balance between them, controlled by a mix ratio (alpha) between 0 and 1. When alpha = 1, it's a pure L1 penalty, and when alpha = 0, it's a pure L2 penalty. Elastic net helps reduce overfitting (through L2/ridge) and enables feature selection by driving the weights of uninformative features to zero (through L1/lasso).

It also takes an input parameter for the reduction method (none, mean, or sum). The forward pass computes alpha * L1 + (1 - alpha) * L2. You can find the Python implementation code here.
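As a rough sketch (not the project's actual code), an elastic net penalty over a PyTorch model's parameters could look like this; the class name, parameter names, and the choice to penalize weights rather than prediction errors are my assumptions for illustration:

```python
import torch
import torch.nn as nn


class ElasticNetPenalty(nn.Module):
    """Illustrative elastic net penalty on a model's weights.

    alpha mixes the two terms: alpha = 1 is pure L1 (lasso),
    alpha = 0 is pure L2 (ridge).
    """

    def __init__(self, alpha: float = 0.5, reduction: str = "mean"):
        super().__init__()
        assert reduction in ("none", "mean", "sum")
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, model: nn.Module) -> torch.Tensor:
        per_param = []
        for param in model.parameters():
            l1 = param.abs().sum()          # lasso term
            l2 = param.pow(2).sum()         # ridge term
            per_param.append(self.alpha * l1 + (1 - self.alpha) * l2)
        stacked = torch.stack(per_param)
        if self.reduction == "mean":
            return stacked.mean()
        if self.reduction == "sum":
            return stacked.sum()
        return stacked                      # reduction == "none"
```

In use, such a penalty would simply be scaled and added to the task loss before the backward pass.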

However, this wasn't exactly what was required for our project. Instead, we needed the "L2-SP loss" (an L2 penalty toward the starting point).

Traditional L2 regularization adds a penalty term to the loss function, penalizing the sum of squared weights. In contrast, L2-SP is commonly used in transfer learning. During fine-tuning, it penalizes the model if the weights deviate too far from their initial values (the pre-trained model's weights), rather than penalizing deviation from zero. This was clearly explained in this Medium article by Arjun Deshmukh, which helped me understand it better. Further reading in these articles: 1, 2.

To implement this, I wrote a class with a function that loops over each parameter, calculates the squared difference between the current weight and the initial weight, and accumulates it in an l2_sp_loss variable. After the loop, it multiplies l2_sp_loss by the l2sp_alpha regularization coefficient, which determines the strength of the L2-SP regularization. Finally, it returns the computed L2-SP loss as a PyTorch tensor. You can find the code here. Currently, the L2-SP loss is applied to the student model at every training step. The formula is given by:

L2SP(w) = l2sp_alpha * ||w - w⁰||₂² = l2sp_alpha * Σⱼ (wⱼ - w⁰ⱼ)²

Here, w⁰ represents the parameter vector of the model pretrained on the source problem, serving as the starting point for fine-tuning. Using this initial vector as the reference in the L2 penalty, the formula keeps the fine-tuned weights from drifting far from the pretrained values.
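A minimal sketch of how such a class could look in PyTorch, assuming the starting-point weights are snapshotted when the loss object is constructed; the class name L2SPLoss and its interface are illustrative, not the project's actual code:

```python
import torch
import torch.nn as nn


class L2SPLoss(nn.Module):
    """Illustrative L2-SP penalty toward the pretrained starting point w0."""

    def __init__(self, model: nn.Module, l2sp_alpha: float = 1e-3):
        super().__init__()
        self.l2sp_alpha = l2sp_alpha
        # Detached snapshot of the starting-point (pretrained) weights w0.
        self.initial_params = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def forward(self, model: nn.Module) -> torch.Tensor:
        device = next(model.parameters()).device
        l2_sp_loss = torch.tensor(0.0, device=device)
        for name, param in model.named_parameters():
            w0 = self.initial_params[name].to(param.device)
            # Squared deviation from the starting point, accumulated over all parameters.
            l2_sp_loss = l2_sp_loss + (param - w0).pow(2).sum()
        return self.l2sp_alpha * l2_sp_loss
```

In a training step this would be added to the student's task loss, e.g. `loss = task_loss + l2sp(student_model)`, before calling `loss.backward()`.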

The improvement in the model's performance was evaluated with the F1 score. [In machine learning, the F1 score is a metric that combines precision (the fraction of positive predictions that are correct) and recall (the fraction of actual positives that were correctly identified) into a single value. It is the harmonic mean of precision and recall, providing a balanced measure of a model's performance in classification tasks. A high F1 score indicates that the model correctly identifies positive instances while keeping false positives low.]
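For reference, F1 = 2 * precision * recall / (precision + recall). For example, a model with precision 0.80 and recall 0.60 has F1 = 2 * 0.48 / 1.40 ≈ 0.69.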

Lastly, I wrote tests for the l2sp loss function using pytest. These tests cover various scenarios, such as verifying that the L2SP loss is positive when model parameters deviate from initial weights, ensuring it returns zero when alpha is zero (no regularization), checking that it applies the penalty correctly with a non-zero alpha, and ensuring it doesn't modify the model parameters during computation. The test file is available here.
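Illustrative versions of such tests could look like the following; they assume the L2SPLoss sketch above is in scope, and the actual project tests differ in names and fixtures:

```python
import pytest
import torch
import torch.nn as nn

# from l2sp import L2SPLoss  # hypothetical import; see the sketch above


def make_model():
    torch.manual_seed(0)
    return nn.Linear(4, 2)


def test_l2sp_loss_positive_after_deviation():
    model = make_model()
    l2sp = L2SPLoss(model, l2sp_alpha=0.1)
    # Perturb the weights away from the starting point.
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)
    assert l2sp(model).item() > 0


def test_l2sp_loss_zero_when_alpha_zero():
    model = make_model()
    l2sp = L2SPLoss(model, l2sp_alpha=0.0)
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)
    assert l2sp(model).item() == pytest.approx(0.0)


def test_l2sp_does_not_modify_parameters():
    model = make_model()
    before = [p.detach().clone() for p in model.parameters()]
    L2SPLoss(model, l2sp_alpha=0.1)(model)
    for b, p in zip(before, model.parameters()):
        assert torch.equal(b, p)
```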

The next task is to write a design doc and a data loader for this dataset - "Weakly Supervised Cell Segmentation in Multi-modality High-Resolution Microscopy Images".

It's been a rewarding experience contributing to this open-source project and diving deep into loss function implementations. Additionally, I contributed implementations of several augmentations for the input images and a data loader for the LIVECell dataset. I'm excited to contribute more, see the project's public release, and continue learning!
