Learning to distill ML models
Ivan Isaev
ML tech-lead and senior engineer | Ex-Head of ML & DS | Ex-Head of Engineering | Kaggle Competitions Master
I’m investigating ML model distillation and learning how to do it. Below are my takeaways, with links to the code.
The approach was first introduced in the paper "Distilling the Knowledge in a Neural Network" by Geoffrey Hinton, Oriol Vinyals and Jeff Dean in March 2015.
There is also the term “data distillation”, but it refers to compressing the training data into fewer samples (more precisely, creating a much smaller synthetic dataset to use instead of the original one), and it is not related to model distillation.
Below are three interesting cases, with code, where knowledge distillation was used.
1. Distilling knowledge from an LGBM teacher into a neural network
The first, introductory example is distilling knowledge from an LGBM teacher into a neural network in the Santander Customer Transaction Prediction competition. In my opinion, this is a good example to start with to understand the approach. I created a PyTorch template for this notebook and pushed it to my GitHub, but I haven't tested it yet (feel free to start from it if you want).
The approach uses the predictions of student_model twice. As I understand it, in Experiment 2 the second output (y_pred_student) is trained to mimic the teacher model as closely as possible, while the first output (y_pred) tries to fit the ground-truth labels as well as possible.
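A minimal PyTorch sketch of a two-output student, assuming a binary target and tabular features as in Santander. The class `TwoHeadStudent`, the function `student_loss`, the layer sizes and the equal weighting of the two losses are hypothetical and not taken from the notebook; the sketch only illustrates one head (`y_pred`) being fit to the ground-truth labels while the other (`y_pred_student`) is fit to the teacher's probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadStudent(nn.Module):
    """Hypothetical student: shared body, one head for labels, one head for the teacher."""

    def __init__(self, n_features: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU())
        self.label_head = nn.Linear(hidden, 1)    # produces y_pred (fits ground truth)
        self.teacher_head = nn.Linear(hidden, 1)  # produces y_pred_student (mimics teacher)

    def forward(self, x):
        h = self.body(x)
        return self.label_head(h).squeeze(-1), self.teacher_head(h).squeeze(-1)


def student_loss(model, x, y_true, teacher_prob):
    """Hard-label loss plus soft-label (teacher) loss; equal weights are an assumption."""
    y_pred, y_pred_student = model(x)  # both outputs are logits
    hard = F.binary_cross_entropy_with_logits(y_pred, y_true)
    # BCE accepts targets anywhere in [0, 1], so teacher probabilities work as soft labels
    soft = F.binary_cross_entropy_with_logits(y_pred_student, teacher_prob)
    return hard + soft


# Toy usage: random features stand in for Santander data, a sigmoid for LGBM outputs
torch.manual_seed(0)
x = torch.randn(256, 20)
y_true = (x[:, 0] > 0).float()
teacher_prob = torch.sigmoid(2 * x[:, 0])  # pretend these came from the LGBM teacher

model = TwoHeadStudent(n_features=20)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
initial = student_loss(model, x, y_true, teacher_prob).item()
for _ in range(200):
    opt.zero_grad()
    loss = student_loss(model, x, y_true, teacher_prob)
    loss.backward()
    opt.step()
final = student_loss(model, x, y_true, teacher_prob).item()
y_pred, y_pred_student = model(x)
print(initial, final)
```

Sharing the body means both heads shape the same representation, so the teacher signal can regularise the label head; whether to predict with one head or a blend of both is something to check on validation data.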
I think this notebook demonstrates the idea of distillation best: we want a student model that mimics the behaviour of the teacher model but requires less effort than the teacher model.
2. Knowledge distillation from a trained LGBM into a Transformer model
This case, distilling an LGBM model into a Transformer model in the American Express - Default Prediction competition, illustrates another approach to training. The author uses knowledge distillation from a trained LGBM before fine-tuning on the train targets. Furthermore, both train and test data are used for knowledge distillation, which helps the Transformer learn the test data distribution.
With knowledge distillation, it is possible to successfully train a deeper Transformer.
Training uses 4 cycles of a cosine learning-rate schedule. In the first, cold-start cosine cycle, we pretrain the Transformer (i.e. knowledge distillation) on the concatenated rows of the LGBM OOF predictions and the LGBM test predictions, keeping them as probabilities between 0 and 1 (i.e. soft labels). During the second cosine cycle, we use a warm start, reduce the learning rate and train on the hard (0 or 1) train targets. For the third and fourth cycles, we repeat cycles one and two.
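The four-cycle schedule can be sketched in plain Python. This is a sketch under assumptions: the learning-rate values, the 10x reduction for the hard-label cycles and the helper names `cosine_lr`, `cycle_data` and `schedule` are hypothetical, and the model weights are assumed to carry over from one cycle to the next, so only the targets and the learning rate change.

```python
import math


def cosine_lr(step, steps_per_cycle, max_lr, min_lr=0.0):
    """Cosine annealing from max_lr (first step) down to min_lr (last step) of one cycle."""
    progress = step / max(1, steps_per_cycle - 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# (target kind, peak LR) per cycle. Cycle 1 starts from freshly initialised weights
# (our reading of "cold start"); later cycles continue from the previous weights.
# The LR values and the 10x reduction are illustrative assumptions.
CYCLES = [("soft", 1e-3), ("hard", 1e-4), ("soft", 1e-3), ("hard", 1e-4)]


def cycle_data(kind, train_x, oof_preds, test_x, test_preds, train_y):
    """Return the (rows, targets) pair used in a cycle of the given kind."""
    if kind == "soft":
        # Knowledge distillation: train and test rows concatenated,
        # LGBM OOF / test probabilities kept in [0, 1] as soft labels
        return train_x + test_x, oof_preds + test_preds
    # Fine-tuning: train rows only, with hard 0/1 targets
    return train_x, train_y


def schedule(steps_per_cycle):
    """Yield (cycle_number, target_kind, lr) for every optimisation step."""
    for number, (kind, max_lr) in enumerate(CYCLES, start=1):
        for step in range(steps_per_cycle):
            yield number, kind, cosine_lr(step, steps_per_cycle, max_lr)


# Toy usage with made-up rows and predictions
train_x, train_y, oof_preds = [[0.1], [0.7]], [0, 1], [0.2, 0.9]
test_x, test_preds = [[0.5]], [0.6]
rows, targets = cycle_data("soft", train_x, oof_preds, test_x, test_preds, train_y)
plan = list(schedule(steps_per_cycle=5))
print(rows, targets)
print(plan[0], plan[4], plan[5])
```

Because the LGBM probabilities are valid BCE targets, the same loss function can be used in every cycle; only the data returned by `cycle_data` and the peak learning rate differ between soft and hard cycles.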
You can try to reproduce the gold-medal solution by pre-training with LGBM OOF predictions as described in this discussion, removing the skip connection in the Transformer, and adding two more Transformer blocks to this public Transformer notebook.
3. Distilling Llama3-70b and Qwen2-72b into Gemma2-9b
The goal is to distill llama3-70b and qwen2-72b into gemma2-9b.
The code is available and is really great.
The steps are:
1. Post-pretraining of the three models (llama3-70b, qwen2-72b and gemma2-9b) on the UT data.
2. Fine-tuning the models to get 5-fold results and distilling into the 9b model using logits.
After obtaining the logit distributions, the 9b model is loaded for fine-tuning, and the distillation loss is added during the fine-tuning process.
Each fold consists of training the Llama3 and Qwen2 models, predicting to obtain the probability distribution for the training set, and finally fine-tuning the Gemma2-9b model.
3. Merging LoRA and quantizing: the LoRA layers of the 5-fold Gemma2-9b models are merged and then quantized to 8-bit using GPTQ.
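One way to assemble the teacher targets that the loss code reads from `batch['grads']` is sketched below, assuming the per-fold teacher predictions are out-of-fold logits over 3 classes (the loss code slices columns `:3` for Qwen2 and `3:` for Llama3). The function `build_teacher_targets` and the fold layout are hypothetical; the original repository may store these values differently.

```python
import torch


def build_teacher_targets(n_samples, fold_indices, qwen_fold_logits, llama_fold_logits):
    """Pack per-sample teacher logits into one (n_samples, 6) tensor (hypothetical helper).

    fold_indices[k] holds the indices of the samples held out in fold k;
    qwen_fold_logits[k] / llama_fold_logits[k] are the (len(fold_indices[k]), 3) logits
    predicted on those samples by the fold-k teachers, so no sample is scored by a
    teacher that was trained on it.
    """
    teacher = torch.full((n_samples, 6), float("nan"))
    for idx, q, l in zip(fold_indices, qwen_fold_logits, llama_fold_logits):
        idx = torch.as_tensor(idx, dtype=torch.long)
        teacher[idx, :3] = q  # columns 0-2: Qwen2, read as batch['grads'][:, :3]
        teacher[idx, 3:] = l  # columns 3-5: Llama3, read as batch['grads'][:, 3:]
    if torch.isnan(teacher).any():
        raise ValueError("some samples were not covered by any fold")
    return teacher


# Toy usage: 10 samples split into 5 folds of 2 held-out samples each
torch.manual_seed(0)
n = 10
folds = [list(range(k, n, 5)) for k in range(5)]
qwen = [torch.randn(len(f), 3) for f in folds]
llama = [torch.randn(len(f), 3) for f in folds]
teacher_logits = build_teacher_targets(n, folds, qwen, llama)
print(teacher_logits.shape)
```

The NaN fill makes a missing fold fail loudly instead of silently training the student against uninitialised targets.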
Regarding distillation, the losses are as follows:
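```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 2.0  # distillation temperature; not shown in the original snippet, this value is a placeholder

loss_fun = nn.CrossEntropyLoss()
divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
cos_loss_fn = nn.CosineEmbeddingLoss()

# Inside the training loop: `model` is the Gemma2-9b student, `batch` comes from the DataLoader
outputs = model(batch['input_ids'], use_cache=False)  # predict gemma2
logits = outputs.logits          # student logits, one column per class
grads = batch['grads']           # precomputed teacher logits (despite the name)
grads1 = grads[:, :3]            # qwen2
grads2 = grads[:, 3:]            # llama3
labels = batch['labels']

# Target of +1 tells CosineEmbeddingLoss that each pair of vectors should be similar
cos_target = torch.ones(logits.size(0), device=logits.device)

# Hard-label loss against the ground truth
loss_ce = loss_fun(logits, labels)

# Qwen2 teacher: KL divergence and cosine distance between temperature-softened distributions
loss_grad1 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads1 / T, dim=1),
)
cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1), cos_target)

# Llama3 teacher: the same two terms
loss_grad2 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads2 / T, dim=1),
)
cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1), cos_target)

# Equal-weight average of all five terms
loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) / 5
```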
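To sanity-check this loss, the same combination can be written as a function for any number of teachers. This is a hypothetical helper (`multi_teacher_loss`, with an arbitrary default `T=2.0`), not part of the original code; with two teachers it computes the same equal-weight average of five terms. When every teacher's logits equal the student's, the KL and cosine terms vanish and the loss reduces to the cross-entropy divided by 5.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def multi_teacher_loss(student_logits, labels, teacher_logits_list, T=2.0):
    """Average of CE on hard labels plus a KL and a cosine term per teacher (hypothetical)."""
    kl = nn.KLDivLoss(reduction="batchmean")
    cos = nn.CosineEmbeddingLoss()
    ones = torch.ones(student_logits.size(0), device=student_logits.device)

    log_p = F.log_softmax(student_logits / T, dim=1)  # softened student log-probabilities
    p = log_p.exp()                                   # softened student probabilities
    terms = [F.cross_entropy(student_logits, labels)]
    for t_logits in teacher_logits_list:
        q = F.softmax(t_logits / T, dim=1)            # softened teacher probabilities
        terms.append(kl(log_p, q))
        terms.append(cos(q, p, ones))
    return sum(terms) / len(terms)


# Sanity check on random logits (3 classes, batch of 4)
torch.manual_seed(0)
student = torch.randn(4, 3)
labels = torch.tensor([0, 1, 2, 1])
same_loss = multi_teacher_loss(student, labels, [student.clone(), student.clone()])
ce = F.cross_entropy(student, labels)
other_loss = multi_teacher_loss(student, labels, [torch.randn(4, 3), torch.randn(4, 3)])
print(same_loss.item(), ce.item() / 5, other_loss.item())
```

Note that Hinton et al. multiply the soft-target term by T² so that its gradient magnitude stays comparable to the hard-label term as the temperature changes; the loss above does not, so if you tune T it may be worth trying that scaling too.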