QKV and Multi-Head Attention in LLMs
In the realm of Natural Language Processing (NLP), Large Language Models (LLMs) like GPT-3 and BERT have revolutionized how machines understand and generate human language. At the heart of these models lie QKV and Multi-Head Attention. The idea sounded cryptic to me at the very beginning, and it took me a few weeks to figure it out.
The following is how they are explained in the original paper, "Attention Is All You Need".
Query (Q): a representation of the current word, used to ask "which other words are relevant to me?"
Key (K): a representation of each word that the query is matched against; how well a query matches a key determines the attention weight.
Value (V): the actual content of each word, which is summed up, weighted by those attention weights, to form the output.
Well, it is still a bit hard to digest. Let’s look at an example below.
A simple sentence like "Tom is going to fish at the river bank" is easy for us to understand. To let computers understand it, we need to encode every word into numbers, a process called word embedding. Assuming a simple six-dimensional space, the word "River" can be represented by the embedding [-0.9, 0.9, -0.2, 0.4, 0.2, 0.6]. Words with higher "similarity" will be close to each other. For example, Group 1) River, Fish, and Fisherman. Group 2) Hospital, PostOffice, and Restaurant. It becomes interesting when we try to figure out where to put the word "Bank". It is a polysemous word that can be interpreted differently depending on the context of the sentence it appears in. Should it be closer to Group 1 or Group 2?
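To make "closeness" concrete, here is a minimal NumPy sketch. Only the "river" vector comes from the text above; the other embeddings and the cosine-similarity measure are assumptions added purely for illustration.

```python
import numpy as np

# Hypothetical 6-dimensional embeddings; only "river" uses the vector from
# the text, the rest are invented for illustration.
embeddings = {
    "river": np.array([-0.9, 0.9, -0.2, 0.4, 0.2, 0.6]),
    "fish":  np.array([-0.8, 0.8, -0.1, 0.5, 0.1, 0.5]),
    "money": np.array([ 0.7, -0.6, 0.8, -0.3, 0.4, -0.2]),
}

def cosine_similarity(a, b):
    """Close to 1.0 for vectors pointing the same way, low or negative otherwise."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity(embeddings["river"], embeddings["fish"]))   # high (~0.99)
print(cosine_similarity(embeddings["river"], embeddings["money"]))  # low (negative)
```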
Now, let’s look at the sentence again,
Tom is going to fish at the river bank.
When we read it, we know "bank" cannot be the place where you withdraw money. Why? Well, the presence of the words "River" and "Fish" contributes more to our understanding of the context than the rest of the sentence. Therefore, they should have high attention scores and be closer to "bank".
How does a computer determine that it should pay more attention to "River" and "Fish" and not the others? This is where Q (Query) and K (Key) come in. They are two linear transformations that help answer the question: within this sentence, how similar is each word to every other word?
Firstly, both take the same input embeddings (let's put the positional embedding aside for now), assumed here to be 6-dimensional.
Applying the linear transformations of K and Q to the input embeddings gives us the keys and queries, as in the sketch below.
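A rough sketch of that step, with random matrices as stand-ins for the learned projection weights (an assumption; real models learn these during training). V is obtained with the same kind of projection, so it is included here as well. The nine rows correspond to the nine words of the example sentence.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model = 9, 6                   # 9 tokens, 6-dimensional embeddings
X = rng.normal(size=(seq_len, d_model))   # input embeddings (stand-in values)

# Learned projection matrices; random stand-ins here.
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_model))
W_v = rng.normal(size=(d_model, d_model))

Q = X @ W_q   # queries
K = X @ W_k   # keys
V = X @ W_v   # values
print(Q.shape, K.shape, V.shape)          # (9, 6) (9, 6) (9, 6)
```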
The output then goes through the steps of MatMul, Scale, Mask, and SoftMax to get the attention weights, and a final MatMul with V. The result is a weighted sum of the values, where the weights are determined by how well each key matches the query. So the new embedding, compared to the original one, captures more contextual relationships.
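Here is a minimal NumPy sketch of those steps. The function name and the random Q/K/V inputs are my own; the mask handling follows the usual convention of setting masked positions to a large negative number before the softmax.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """MatMul -> Scale -> (optional) Mask -> SoftMax -> MatMul with V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # MatMul + Scale
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # Mask: block disallowed positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # SoftMax over each row
    return weights @ V, weights                # weighted sum of values

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(9, 6)) for _ in range(3))
output, attn = scaled_dot_product_attention(Q, K, V)
print(output.shape, attn.shape)   # (9, 6) (9, 9)
```

Each row of `attn` tells us, for one word, how much weight it places on every other word in the sentence.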
For example, the word "bank" has its highest attention scores with "bank", "river", and "fish", so the model will focus more on these input words.
Why do we have to go through this complicated QKV transformation?
If we are asked to describe what is in a picture, rather than scanning from the top left corner, pixel by pixel, our brain will immediately focus on the most prominent elements, like a boy in the scene. This process is highly efficient and effective, demonstrating the power of attention.
If you consider one set of QKV linear projections as a so-called attention head, then multi-head attention is simply having multiple sets of QKV and concatenating their outputs. The benefit of multiple heads is that they can capture different aspects of similarity. For instance, one head can focus on nearby nouns, while another might look at the verb-object relation. Back to the picture example above, one "head" might detect the boy, and another the ball.
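A minimal sketch of that idea, again with random stand-in weights; the dimensions (a 12-dimensional model split across 2 heads) and the final output projection `W_o` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 9, 12, 2
d_head = d_model // n_heads
X = rng.normal(size=(seq_len, d_model))   # input embeddings (stand-in values)

def attention(Q, K, V):
    """Scaled dot-product attention for one head."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

# Each head gets its own Q/K/V projections; the per-head outputs are
# concatenated and mixed back together with an output projection W_o.
heads = []
for _ in range(n_heads):
    W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
    heads.append(attention(X @ W_q, X @ W_k, X @ W_v))

W_o = rng.normal(size=(d_model, d_model))
multi_head_out = np.concatenate(heads, axis=-1) @ W_o
print(multi_head_out.shape)   # (9, 12)
```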
This is an intuitive explanation of QKV and multi-head attention. If you want to know the mathematical part of it, the original paper “Attention Is All You Need” is a good place to start. Have fun!