Why are LLMs obsessed with "Attention"?
It's been a long time since I released a new edition of this newsletter. All I can say is that life gets in the way, things happen, and we lose focus and attention and keep slacking.
Attention is important because it keeps us focused and on course with our calling.
It is, quite honestly, one thing I have been struggling with in recent months. But I remember my father telling me, when I was a child, to write down and remember important things. Other times he would ask me to repeat the same task over and over again to master it. Now I realize he was indirectly telling me to pay attention.
Attention is a very important skill for learning something and doing your best work. In a way, you could say it is an obsession over something.
You see, attention is cognitively expensive and taxing. It requires time and a great deal of commitment and discipline to build. It is not something that is given but earned, and Large Language Models know about it.
Sequential information, as the name suggests, contains information in a linearly ordered manner, which is also time-dependent. Examples of such types of data include text, weather forecast data, time-series data, protein sequences, and material sequences.
Such sequential information has a unique characteristic that other data types don't: its inherent temporal or spatial order which allows it to encode patterns, dependencies, and context. This temporal behavior is crucial for understanding and predicting future behavior or states. It is this property that makes modeling them mathematically challenging.
Modeling time-dependent data is not easy. In 2014, Cho et al. proposed the RNN Encoder-Decoder neural network architecture for sequence modeling. In this approach, the input is encoded into a fixed-size vector by a recurrent neural network, and a second recurrent network decodes that vector into the output.
In the same year, in December, Sutskever et al. introduced a similar architecture to handle large sequences. They utilized LSTM networks to encode and decode sequences of variable length. In this model, the encoder processes the input sequence and produces a final hidden state. This hidden state contains combined information from the entire input sequence from the past. The decoder then uses this hidden state to generate the output sequence. You can say that the hidden state contains the information of the entire block of text.
Both RNNs and LSTMs process data one step at a time, which makes them inherently sequential and slow, especially for long sequences. The key difference is that LSTMs are designed, through gating, to handle long-term dependencies: they retain and use information from earlier time steps in a sequence when making predictions or generating output at later time steps.
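To make the "one step at a time" idea concrete, here is a minimal sketch of a single recurrent update in NumPy. The dimensions and random weights are assumptions for illustration, not a trained model:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (hypothetical): 4-dim inputs, 3-dim hidden state.
input_dim, hidden_dim = 4, 3
W_xh = rng.normal(size=(hidden_dim, input_dim))   # input-to-hidden weights
W_hh = rng.normal(size=(hidden_dim, hidden_dim))  # hidden-to-hidden weights
b_h = np.zeros(hidden_dim)

def rnn_step(x_t, h_prev):
    """One recurrent step: the new hidden state mixes the current
    input with everything summarized so far in h_prev."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

# Process a 5-step sequence one step at a time (inherently sequential).
h = np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    h = rnn_step(x_t, h)

# h is now the fixed-size summary a decoder would start from.
print(h.shape)  # (3,)
```

Notice that each step depends on the previous hidden state, which is exactly why these models cannot be parallelized across time steps.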
But they are still not capable enough.
The innovation of LSTMs was significant, as it opened the door to compressing information and using it to generate plausible text. But they still lacked the ability to process text with longer context.
In parallel, Bahdanau and his team understood the need for a longer contextual window for language models.
In 2014, they came up with a model that could focus on different parts of the input sequence rather than relying on a single fixed-size context vector.
This was the attention mechanism.
Attention
Attention is all about identifying the most important information in a given text and storing it for later use.
How does it work?
Attention helps a computer focus on the most important words in a sentence. Imagine reading a newspaper or an article and only paying attention to the important details that help you understand the context. The computer does the same thing, giving more importance to the keywords. This way, it can understand the story or information much more clearly and accurately.
Mathematically, it can be defined as a scoring function that computes a weight for each element in the input data. The score is calculated as the dot product of the query (Q) and the key (K). Here, the dot product aims to capture the similarity between Q and K.
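For reference, the scaled dot-product form later popularized by the Transformer paper writes this as:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

where d_k is the key dimension; dividing by the square root of d_k keeps the dot products from growing too large before the softmax.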
Here is a general workflow,
We start by calculating the dot-product scores between Q (queries) and K (keys), both of which are linear projections of the input sequence.
#using a single linear layer to initialize the projection for Q, K, and V
qkv_proj = Linear(n_embd, 3 * n_embd)
#passing the input through the projection and splitting into Q, K, and V
q, k, v = qkv_proj(x).chunk(3, dim=-1)
#applying the dot product between queries and keys
att = q @ k.transpose(-2, -1)
Next, we apply a softmax function to the scores to obtain a set of weights -- values that measure the importance of each word.
#calculating softmax
a = softmax(att, dim=-1)
Lastly, we compute the context vector as a weighted sum of the values V.
out = a @ v
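Putting the three steps together, here is a self-contained sketch of single-head dot-product attention in NumPy. The sequence length, embedding size, and random weights are illustrative assumptions; real models use trained parameters:

```python
import numpy as np

rng = np.random.default_rng(42)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, W_qkv):
    """Single-head self-attention over a sequence x of shape (T, d)."""
    d = x.shape[-1]
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)  # project input into Q, K, V
    att = (q @ k.T) / np.sqrt(d)               # scaled similarity scores
    a = softmax(att, axis=-1)                  # each row sums to 1
    return a @ v, a                            # weighted sum of values

T, d = 6, 8                                    # 6 tokens, 8-dim embeddings
x = rng.normal(size=(T, d))
W_qkv = rng.normal(size=(d, 3 * d))

out, weights = attention(x, W_qkv)
print(out.shape)      # (6, 8)
print(weights.shape)  # (6, 6) -- this is what an attention heatmap visualizes
```

The (T, T) weight matrix is the object you would plot as a heatmap: entry (i, j) says how much token i attends to token j.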
Attention became an integral part of processing sequential information, but it was still coupled with recurrent models, and that coupling limited the ability to fully exploit parallelization for faster computation and better scalability.
In 2017, Vaswani and his team came up with a solution to address the limitations of recurrent networks. They ditched recurrence entirely and developed a new model called the "Transformer". This model relies only on the attention mechanism.
To make attention more effective, they widened the mechanism by creating multiple parallel copies of it, known as multi-head attention. This, in turn, lets the model process the input in parallel.
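Here is a minimal sketch of the multi-head idea, assuming random projections and toy dimensions: the embedding is split across heads, every head attends over the full sequence independently, and the results are concatenated back together.

```python
import numpy as np

rng = np.random.default_rng(0)

T, d, n_heads = 6, 8, 2        # toy sizes; d must be divisible by n_heads
head_dim = d // n_heads

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

q, k, v = (rng.normal(size=(T, d)) for _ in range(3))

# split the embedding dimension into independent heads: (n_heads, T, head_dim)
split = lambda m: m.reshape(T, n_heads, head_dim).transpose(1, 0, 2)
qh, kh, vh = split(q), split(k), split(v)

# every head attends over the full sequence, and all heads run in parallel
scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(head_dim)  # (n_heads, T, T)
out = softmax(scores) @ vh                               # (n_heads, T, head_dim)

# concatenate the heads back into the model dimension
out = out.transpose(1, 0, 2).reshape(T, d)
print(out.shape)  # (6, 8)
```

Each head can learn to track a different kind of relationship (syntax, coreference, and so on), which is part of what made the widened mechanism more effective.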
Attention via the transformer architecture has now become a cornerstone of modern LLMs such as the GPT series, Gemini series, LLaMA series, Claude series, Mistral series, and many more. In these models, the attention mechanism (particularly self-attention) allows the model to weigh the importance of different words in a sentence relative to each other, enabling a deep understanding of context and semantics.
I created an attention mechanism and mapped the sentence "Hi, how are you doing, hope everything is well?" as a heat map to show the similarities between the input characters.
Here is the result.
The heatmap reveals how the model calculated the similarity and generated the attention weights across the given input.
In my example, I explained how attention is calculated at the character level. In the real world, LLMs like GPT-4 don't employ this method. They instead rely on token-level attention, where words are broken into subwords and the similarity is calculated between those.
Tokens are the basic units of text -- often subwords rather than whole words -- that a model uses for processing.
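As a toy illustration of subword tokenization: the greedy longest-match scheme and the tiny hand-made vocabulary below are simplifying assumptions for this sketch; production tokenizers (such as BPE) learn their vocabularies from data.

```python
def tokenize(word, vocab):
    """Greedily match the longest vocabulary entry at each position;
    fall back to single characters for anything unknown."""
    tokens, i = [], 0
    while i < len(word):
        for j in range(len(word), i, -1):  # try the longest match first
            if word[i:j] in vocab:
                tokens.append(word[i:j])
                i = j
                break
        else:
            tokens.append(word[i])         # unknown character falls through
            i += 1
    return tokens

# A tiny hand-made vocabulary (an assumption for illustration only).
vocab = {"do", "ing", "every", "thing", "hope"}

print(tokenize("doing", vocab))       # ['do', 'ing']
print(tokenize("everything", vocab))  # ['every', 'thing']
```

Attention is then computed between these token embeddings rather than between raw characters.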
As you saw, attention is quite impressive in that it can compress massive amounts of information and retrieve the relevant parts when trained properly. But there is a major drawback: it is computationally expensive. The score matrix compares every pair of tokens, so the cost grows quadratically with sequence length.
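A quick back-of-the-envelope sketch of that quadratic growth:

```python
def score_matrix_entries(seq_len):
    """Every token attends to every token: seq_len * seq_len scores."""
    return seq_len * seq_len

for n in (1_000, 2_000, 4_000):
    print(n, score_matrix_entries(n))
# doubling the sequence length quadruples the work:
# 1000 -> 1,000,000   2000 -> 4,000,000   4000 -> 16,000,000
```

This is why long context windows are costly in both compute and memory, and why so much research targets cheaper approximations of attention.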
But despite the computational expense, attention is used because it gives the model direct access to every part of the context at once, no matter how far apart two words are.
So, why are LLMs obsessed with "Attention"?
Simply put, to have larger and better contextual information.
Final thoughts
I hope this edition of the newsletter was informative. I have covered the basics of the attention mechanism without dwelling on too many details. Attention itself is a vast subject, and as LLMs grow more powerful, research on effective attention mechanisms grows as well.
If you have come this far, then please share your thoughts and the things you have learned.
PS: As I was writing this article I felt like documenting the different types of attention mechanisms out there and benchmarking them based on the similarity performance.
Faith+Hope+Love