Learning to distill ML models

I have been investigating the topic of ML model distillation and learning how to do it. These are my takeaways, with links to the code.

The approach was first introduced in the paper "Distilling the Knowledge in a Neural Network" by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean in March 2015.
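For reference, the classic loss from that paper combines cross-entropy on the hard labels with a KL-divergence term between the temperature-softened teacher and student distributions. Below is a minimal PyTorch sketch of that loss; the temperature and weighting values are illustrative, not prescribed by the paper.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft targets: KL divergence between temperature-softened distributions.
    # Scaling by T*T keeps gradient magnitudes comparable, as suggested in the paper.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # alpha balances mimicking the teacher vs. fitting the true labels (illustrative value).
    return alpha * soft + (1.0 - alpha) * hard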

There is also the term "data distillation", but it refers to compressing the training data, or more precisely to creating a much smaller synthetic dataset in place of the original one, and it is not related to model distillation.

Below are three interesting cases, with code, where knowledge distillation was used.

1. Distilling the knowledge from an LGBM teacher to a neural network

The first introductory example of model distillation is distilling the knowledge from an LGBM teacher to a neural network in the Santander Customer Transaction Prediction competition. In my opinion this is a good example to start from to understand the approach. I created a PyTorch template for this notebook, but I haven't tested it yet (you can start from it if you want). I pushed it to my GitHub.

The approach is to use the predictions of student_model twice. As I understand it, in Experiment 2 the second output (y_pred_student) mimics the teacher model as closely as possible, while the first output (y_pred) tries to learn from the ground-truth labels as well as possible.
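To make this concrete, here is a minimal sketch of such a two-output student in PyTorch. The class name, layer sizes and loss weighting are my own illustration, not code from the notebook: one head is trained against the LGBM teacher's predicted probabilities (soft labels), the other against the ground-truth 0/1 targets.

import torch.nn as nn
import torch.nn.functional as F

class TwoHeadStudent(nn.Module):
    # Hypothetical two-output student: a shared body with two linear heads.
    def __init__(self, n_features, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU())
        self.head_hard = nn.Linear(hidden, 1)   # -> y_pred, learns from ground truth
        self.head_soft = nn.Linear(hidden, 1)   # -> y_pred_student, mimics the teacher
    def forward(self, x):
        h = self.body(x)
        return self.head_hard(h), self.head_soft(h)

def combined_loss(y_pred, y_pred_student, y_true, teacher_prob, alpha=0.5):
    # Hard part: BCE against the 0/1 labels; soft part: BCE against LGBM probabilities.
    hard = F.binary_cross_entropy_with_logits(y_pred, y_true)
    soft = F.binary_cross_entropy_with_logits(y_pred_student, teacher_prob)
    return alpha * hard + (1 - alpha) * soft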

I think this notebook demonstrates the idea of distillation best: we want a student model that mimics the behaviour of the teacher model but requires less effort than the teacher.

2. Knowledge distillation from a trained LGBM to a transformer model

This case of distilling an LGBM model into a Transformer model in the American Express - Default Prediction competition illustrates another approach to training. The author uses knowledge distillation from a trained LGBM before fine-tuning with the train targets. Furthermore, both train and test data are used for knowledge distillation, which helps the Transformer learn the test data distribution.

When using knowledge distillation, it is possible to train a deeper transformer successfully.


Training is done using 4 cycles of a cosine learning-rate schedule. In the first, cold-start cosine cycle, we pretrain (i.e. knowledge distillation) the Transformer using concatenated rows of both LGBM OOF preds and LGBM test preds, leaving the probabilities between 0 and 1 (i.e. soft labels). During the second cosine cycle, we use a warm start, reduce the learning rate and train with the hard (0 or 1) train targets. For the third and fourth cycles, we repeat cycles one and two.
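A rough sketch of this alternating schedule in PyTorch, using dummy data and illustrative hyperparameters (nothing here is taken from the actual notebook), could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy data and a toy student model; shapes and sizes are placeholders.
X = torch.randn(256, 16)
soft_labels = torch.rand(256, 1)            # LGBM OOF/test probabilities in [0, 1]
hard_labels = (soft_labels > 0.5).float()   # 0/1 train targets
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

def run_cycle(targets, lr, epochs=5):
    # One cosine cycle; the model is not reinitialized, so later cycles warm-start.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for _ in range(epochs):
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(model(X), targets)
        loss.backward()
        opt.step()
        sched.step()

# Cycle 1: cold start, distill on soft labels; cycle 2: warm start, lower LR,
# hard targets; cycles 3 and 4 repeat cycles 1 and 2.
for targets, lr in [(soft_labels, 1e-3), (hard_labels, 1e-4)] * 2:
    run_cycle(targets, lr)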

You can try to reproduce the gold medal solution by pre-training the transformer with LGBM OOFs using the instructions from this discussion, removing the skip connection in the transformer, and adding two more transformer blocks to this public transformer notebook.

3. Distillation of Qwen2-72b and Llama3-70b into Gemma2-9b

An amazing distillation of Qwen2-72b and Llama3-70b into Gemma2-9b took 1st place (gold medal) in the LMSYS - Chatbot Arena Human Preference Predictions competition!

The idea is to distill llama3-70b and qwen2-72b into gemma2-9b.

The code is there, and it is really great.

The steps include:

1. Post-pretrain the three models (llama3-70b, qwen2-72b and gemma2-9b) using UT data.

2. Fine-tune the models for 5-fold results and distill to the 9b model with logits.

After obtaining the logits distribution, load the 9b model for fine-tuning and incorporate the distillation loss during the fine-tuning process.

Each fold includes training the Llama3 and Qwen2 models, predicting to obtain the probability distribution for the training set, and finally fine-tuning the Gemma2-9b model.

3. Merge LoRA and quantize. Here, the LoRA layers of the 5-fold Gemma2-9b models are merged and then quantized to 8-bit using GPTQ (a sketch of this step follows the loss code below).

Regarding distillation, the losses are as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1.0  # distillation temperature; the actual value used in the solution is not shown here

loss_fun = nn.CrossEntropyLoss()
divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
cos_loss_fn = nn.CosineEmbeddingLoss()

outputs = model(batch['input_ids'], use_cache=False)  # student (gemma2-9b) forward pass
logits = outputs.logits
grads = batch['grads']      # concatenated teacher logits for the 3 classes
grads1 = grads[:, :3]       # qwen2 teacher logits
grads2 = grads[:, 3:]       # llama3 teacher logits
labels = batch['labels']

# Hard-label cross-entropy against the ground-truth targets
loss_ce = loss_fun(logits, labels)

# KL divergence between the temperature-softened student and qwen2 teacher distributions
loss_grad1 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads1 / T, dim=1)
)
# Cosine-embedding loss pulling the softened student and qwen2 distributions together
cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1),
                        torch.ones(logits.size()[0]).to(logits.device))

# The same two terms for the llama3 teacher
loss_grad2 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads2 / T, dim=1)
)
cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1),
                        torch.ones(logits.size()[0]).to(logits.device))

# Average of the five loss terms
loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) / 5
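
For step 3, here is a minimal sketch of merging a LoRA adapter into the base model and quantizing the result to 8-bit with GPTQ, assuming the peft and transformers (GPTQConfig) APIs. The model id, adapter path and calibration dataset are placeholders, and the actual solution handles the adapters of all 5 folds rather than a single one.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Merge the LoRA adapter weights into the base gemma2-9b model
# (paths are placeholders, not the ones from the solution).
base = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b", torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, "path/to/lora_adapter").merge_and_unload()
merged.save_pretrained("gemma2-9b-merged")

# Quantize the merged model to 8-bit with GPTQ using a small calibration dataset.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
gptq_config = GPTQConfig(bits=8, dataset="c4", tokenizer=tokenizer)
quantized = AutoModelForCausalLM.from_pretrained(
    "gemma2-9b-merged", quantization_config=gptq_config, device_map="auto"
)
quantized.save_pretrained("gemma2-9b-gptq-8bit")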

