Knowledge Distillation in Deep Learning: Part 1
Himanshu S.
Engineering Leader, ML & LLM / Gen AI enthusiast, Senior Engineering Manager @ Cohesity | Ex-Veritas | Ex-DDN
The key to learning is not remembering but understanding. What is understanding? We will get to that question a little later. First, let me share some real-life experiences, my own and my son's.
Recently, I was reading a paper that offered a mathematical perspective on learning. One of the equations included an orthogonality component. My first instinct was to understand what it meant mathematically and why it was there. Next, I considered it philosophically, asking what the math was capturing and whether it was apparent in nature. Then I wondered if it could be explained through analogy. Following that, I sought to understand it in geometric terms and even built a dynamical system in my head to explain how the process evolved.
The final step in my learning journey was to connect these ideas seamlessly: from the probabilistic nature of things, to the mathematics, to the physical world through analogies, to the geometry, and finally to its flow and implications for learning. I was able to jump back and forth between these layers, linking ideas together.
Another example I can link here is my son learning to skate. He had learnt the basics in a somewhat crooked way. Despite having good speed, he kept falling around turns. His school teacher continued to coach him but could not find the fault. So I brought in another instructor to check the basics. We both agreed that he skated with a bent left ankle and that he wasn't using his upper body much. The instructor explained balance, weight, body aerodynamics, center of mass, and the position of the legs and hands around turns. The orthogonal component, beyond the technical, was his desire to go faster and faster, even around turns.
He then learnt to maintain a good leg position while twisting his upper body outwards around turns. With controlled speed, this gave him the balance needed to turn effectively without falling. In the end, with this particular learning arrangement, he won the race, finishing with the fastest time and leaving the other participants far behind.
Both stories illustrate how unique perspectives, derived from diverse observations, enhance learning. Those observations and perspectives can come from different teachers as well.
Having narrated these stories, let us go back to the question I asked at the beginning: what is understanding? This process, this ability to connect diverse perspectives across multiple layers of insight, is what I call understanding.
The question that arises is: if this is true of human learning, would it work with deep learning networks too? Especially considering what Hinton proposed in his 2015 paper, and the recently infamous DeepSeek distillation of OpenAI models.
What is Knowledge Distillation in Learning?
Knowledge distillation (KD) is a technique in which a student model is trained to replicate the behaviour of one or more teacher models. It can be viewed as a learning process in which a single teacher or multiple teachers guide a student, each offering a unique perspective. This diversity in guidance helps the student absorb refined knowledge without becoming confused by conflicting information.
But this opens up many questions. For example:
Question: What if the student does not have the capacity to absorb all the knowledge the teacher is willing to impart?
Question: Can we train the student on the complete set of output probability distributions generated by the teacher, allowing it to independently learn its own weights?
Question: Can we train the student to mimic the teacher's final attention weights?
Question: What happens if the teacher provides well-calibrated outputs? Would that lead to true learning, or would we need to increase the temperature parameter to obtain a wider range of probability distributions?
Question: Should the student learn from the teacher on a feature-wise basis?
Question: Who should generate the questions: the teacher or the student?
Question: What happens if multiple teachers provide conflicting signals? How can we ensure that the student benefits from this diversity of knowledge without becoming confused?
I will address all these questions in a series of articles on distillation, culminating in a working prototype of the process itself.
For this introductory article, I will focus on the core process of distillation.
Brief Reintroduction to Training in Deep Learning:
I assume the audience is already familiar with training procedures in transformers; if not, please refer to my earlier articles on the subject. In a typical transformer model, after processing data through all the layers (self-attention, feed-forward networks, layer normalization, etc.), the network outputs a probability distribution. This distribution is compared against the hard labels (e.g., true/false). The model computes the loss by comparing the predicted distribution to the hard labels, and then backpropagates this loss to update the model weights, thereby improving the probability distribution.
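As a quick refresher, a minimal sketch of that conventional supervised step might look like the following; the model, optimizer, and data names here are placeholders chosen for illustration, not part of any specific codebase.

```python
import torch
import torch.nn.functional as F

def supervised_step(model, optimizer, inputs, hard_labels):
    """One standard training step: predict, compare against hard labels, backpropagate."""
    logits = model(inputs)                        # raw scores, shape [batch_size, num_classes]
    loss = F.cross_entropy(logits, hard_labels)   # softmax + comparison with the hard labels
    optimizer.zero_grad()
    loss.backward()                               # backpropagate the loss
    optimizer.step()                              # update the model weights
    return loss.item()
```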
How Distillation Works in Practice:
Consider an image recognition network as a simple example. First, you train a teacher model using a set of images. Then, you use the same images to train a student model. Before training the student on a given image, you let the teacher predict its probability distribution. This predicted distribution serves as a soft label for the student, instead of a hard label (such as a one-hot vector like [1, 0]). The student model uses this soft label as its target probability distribution.
After a round of training, when the student model produces its final probability distribution for an image, it compares its predictions for each class with the teacher’s corresponding probabilities. In effect, each class can be thought of as having its own prediction that the student must match. The student then adjusts its weights during backpropagation so that its output distribution aligns more closely with that of the teacher.
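As a concrete illustration, the teacher's soft labels can be obtained by running the already-trained teacher in inference mode and taking the softmax of its logits. This is only a minimal sketch, assuming a frozen teacher model and a batch of images.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()                  # the teacher only provides targets; its weights are never updated
def make_soft_labels(teacher, images):
    teacher.eval()
    teacher_logits = teacher(images)             # shape: [batch_size, num_classes]
    return F.softmax(teacher_logits, dim=-1)     # a full probability distribution per image
```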
However, this may leave the impression that something is missing. In conventional training, we also use hard labels (e.g., one-hot vectors representing true/false). If we know the true label of an image from the beginning, we should use that information as well. In distillation, both the teacher's soft labels and the hard labels are used to train the student model.
The introduction of hard labelling creates a conflict regarding which signal to prioritize: the hard labels or the teacher's soft labels. Therefore, the loss function must be designed to incorporate both inputs. A practical approach is to allow users to determine the relative weighting of each component.
Furthermore, we want to understand the rationale behind the teacher’s outputs. Simply knowing the teacher’s answer is not sufficient; understanding the reasoning process is crucial. To achieve this, we introduce a temperature variable that softens the teacher’s probability distribution, thereby providing richer contextual information. Consequently, the student model must apply the same temperature scaling to align its outputs with the teacher’s.
To illustrate the process, imagine the scalar output logits arranged as a vector in an N-dimensional space. You then divide this vector by the temperature parameter and apply the softmax function, which yields a probability distribution represented as another vector in the same N-dimensional space. When you perform this operation for both the teacher and the student, the temperature-scaled softmax softens the differences between the logits. This softening can be interpreted as a reduction in the “angular” difference between the two probability vectors compared to the raw logits.
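In code, this is simply a division of the logits by T before the softmax. The sketch below uses made-up logits purely for illustration; the temperature value is an assumption, not a recommendation.

```python
import torch
import torch.nn.functional as F

T = 4.0  # temperature; higher values flatten the distribution (this value is illustrative)

# Made-up logits for a batch of 2 examples over 3 classes.
teacher_logits = torch.tensor([[8.0, 2.0, 0.5], [1.0, 6.0, 3.0]])
student_logits = torch.tensor([[5.0, 1.0, 1.5], [0.5, 4.0, 2.0]])

# Dividing by T shrinks the gaps between logits, so the softmax assigns the
# "wrong" classes noticeable probability mass instead of collapsing to a near one-hot.
teacher_soft = F.softmax(teacher_logits / T, dim=-1)          # probabilities
student_log_soft = F.log_softmax(student_logits / T, dim=-1)  # log-probabilities, used for KL later
```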
By measuring the difference between these two probability distributions using the Kullback–Leibler (KL) divergence, you quantify how much the student's output deviates from the teacher's. To compute the overall loss, you combine the KL divergence loss (derived from the softened outputs) with the regular cross-entropy loss (derived from the hard labels), and you adjust their relative contributions based on your design preferences.
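Measuring that deviation is a single call in most frameworks. Below is a minimal sketch of just the soft-target term, before any weighting; the function name is my own.

```python
import torch
import torch.nn.functional as F

def soft_target_divergence(student_logits, teacher_logits, T=4.0):
    """KL divergence between the temperature-softened student and teacher distributions."""
    # F.kl_div expects log-probabilities as its first argument and probabilities as its second.
    return F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
```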
Let’s take a look at the loss function now.
Loss Function for Distillation
In knowledge distillation, the total loss combines two components: a standard cross-entropy loss with respect to the hard labels, and a KL divergence term that measures how closely the student’s softened probability distribution matches the teacher’s softened distribution. The hyperparameter α determines the relative weight given to each component.
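Written out, one common form of this combined objective (following Hinton et al., 2015) is shown below, where y is the hard label, z_s and z_t are the student and teacher logits, T is the temperature, and α is the mixing weight. Whether α multiplies the hard-label term or the soft-label term is a convention choice.

```latex
L_{\text{total}} \;=\; \alpha \cdot \mathrm{CE}\big(y,\ \mathrm{softmax}(z_s)\big)
\;+\; (1-\alpha)\cdot T^{2}\cdot \mathrm{KL}\!\Big(\mathrm{softmax}\big(z_t/T\big)\ \big\|\ \mathrm{softmax}\big(z_s/T\big)\Big)
```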
A key detail is the T^2 factor on the KL divergence term. Because the logits are divided by the temperature T before the softmax, the gradients produced by the soft-target term scale as roughly 1/T^2. Multiplying by T^2 compensates for this effect, ensuring that the soft-target and hard-label gradients remain on a comparable scale during backpropagation.
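A sketch of that loss in code, with alpha and T exposed as user-chosen hyperparameters (the defaults here are illustrative), might look like this:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_labels, T=4.0, alpha=0.5):
    """Weighted sum of the hard-label cross-entropy and the softened KL divergence term."""
    # Hard-label component: ordinary cross-entropy against the ground-truth classes.
    ce_loss = F.cross_entropy(student_logits, hard_labels)

    # Soft-label component: KL divergence between temperature-softened distributions,
    # rescaled by T^2 so its gradients stay comparable to the hard-label gradients.
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)

    return alpha * ce_loss + (1.0 - alpha) * kd_loss
```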
Once you compute this combined loss, you train the student model via gradient descent (backpropagation), iteratively reducing both the cross-entropy term (aligning with hard labels) and the KL divergence term (aligning with the teacher’s distribution). Over multiple training iterations, the student’s predictions converge toward both the teacher’s outputs and the ground truth labels. This procedure can be applied to any number of images or queries.
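For completeness, here is a rough sketch of how the whole student-training loop could be wired together, reusing the distillation_loss helper above; the student, teacher, dataloader, and optimizer objects are assumed placeholders rather than references to any particular library setup.

```python
import torch

def train_student(student, teacher, dataloader, optimizer, T=4.0, alpha=0.5, epochs=5):
    teacher.eval()                                   # the teacher stays frozen during distillation
    student.train()
    for _ in range(epochs):
        for images, hard_labels in dataloader:
            with torch.no_grad():
                teacher_logits = teacher(images)     # soft-label source; no gradients needed
            student_logits = student(images)
            loss = distillation_loss(student_logits, teacher_logits, hard_labels, T, alpha)
            optimizer.zero_grad()
            loss.backward()                          # pulls the student toward both targets
            optimizer.step()
```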
This overview sets the stage for a deeper exploration of distillation. In future articles, we’ll expand on these principles, address the remaining challenges, and finalize a practical teacher-student distillation model.