Why are LLMs obsessed with "Attention"?

It's been a long time since I released a new edition of this newsletter. All I can say is that life gets in the way, things happen, and we lose focus and attention and keep slacking.

Attention is important because it keeps us focused and on course with our calling.

It is quite honestly one thing that I have been struggling with in recent months. But I remember my father telling me, when I was a child, to write down and remember the important things. At other times, he would ask me to repeat the same task over and over again until I mastered it. Now I see that he was indirectly telling me to pay attention.

Attention is a very important skill for learning something new and doing your best work. In a way, you could say it is an obsession with something.

You see, attention is cognitively expensive and very taxing. It requires time, commitment, and discipline to build. It is not something that is given but earned, and Large Language Models know this too.


Sequential information, as the name suggests, contains information in a linearly ordered manner, which is also time-dependent. Examples of such types of data include text, weather forecast data, time-series data, protein sequences, and material sequences.

Such sequential information has a unique characteristic that other data types lack: an inherent temporal or spatial order, which allows it to encode patterns, dependencies, and context. This temporal behavior is crucial for understanding and predicting future behavior or states. It is also this property that makes modeling such data mathematically challenging.

Modeling time-dependent data is not easy. In 2014, Cho et al. proposed the RNN Encoder-Decoder architecture for sequence modeling. In this approach, an encoder recurrent neural network transforms the input into a fixed-size vector, and a decoder recurrent network transforms that vector back into an output sequence.

A workflow of an RNN.

In the same year, in December, Sutskever et al. introduced a similar architecture to handle longer sequences. They utilized LSTM networks to encode and decode sequences of variable length. In this model, the encoder processes the input sequence and produces a final hidden state, which contains combined information from the entire input sequence seen so far. The decoder then uses this hidden state to generate the output sequence. You can say that the hidden state contains the information of the entire block of text.
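
To make the encoder side of this concrete, here is a minimal sketch of an LSTM encoder in PyTorch producing the final hidden state that a decoder would start from. The dimensions and variable names are illustrative assumptions on my part, not the exact configuration used in the original paper.

import torch
from torch import nn

#an encoder LSTM reads the whole input sequence one step at a time
encoder = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

x = torch.randn(1, 15, 32)          #a dummy 15-step input sequence (illustrative sizes)
outputs, (h_n, c_n) = encoder(x)    #h_n is the final hidden state

#a decoder would be initialized from this single fixed-size summary of the input
print(h_n.shape)                    #torch.Size([1, 1, 64])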

Both vanilla RNNs and LSTMs process data one step at a time, which makes them inherently sequential and slow, especially for long sequences. The key difference is that LSTMs add gating mechanisms designed to handle long-term dependencies. This means that they can retain and use information from earlier time steps in a sequence when making predictions or generating output at later time steps.

But they are still not capable enough.

The LSTM was a significant innovation, as it opened the door to compressing information and using it to generate plausible text. But it still lacked the ability to process text with longer context.

In parallel, Bahdanau and his team recognized the need for a longer contextual window in language models.

In 2014, they came up with a model that would allow the decoder to focus on different parts of the input sequence rather than relying on a single fixed-size context vector.

This was the attention mechanism.

Attention

Attention is all about looking at the most important information in a given text and storing it for later use.

How does it work?

Attention helps a computer focus on the most important words in a sentence. Imagine reading a newspaper or an article and only paying attention to the important details that help you understand the context. The computer does the same thing, giving more importance to the keywords. This way, it can understand the story or information much more clearly and accurately.

Illustration of how the attention mechanism processes the input during training.

Mathematically, it can be defined as a scoring function that computes a weight for each element in the input data. The score is calculated using the dot product of the query (Q) and the key (K), as shown in the image above. Here, the dot product aims to capture the similarity between Q and K.
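
Written out in full, this is the scaled dot-product form that the Transformer later made standard (the division by the square root of the key dimension d_k is a stabilizing detail introduced with the Transformer, not part of the earliest attention formulations):

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\]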

Here is a general workflow,

We start by calculating the dot-product scores between Q (queries) and K (keys), which are linear projections of the input sequence. This is the core of attention.

#a single linear layer produces the projections for Q, K, and V (PyTorch-style; assumes torch.nn is imported as nn)
qkv_proj = nn.Linear(n_embd, 3 * n_embd)

#passing the input through the projection and splitting it into Q, K, and V
q, k, v = qkv_proj(x).chunk(3, dim=-1)

#the scaled dot product of Q and K gives the raw attention scores
att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)

Next, we apply a softmax function to the scores to obtain a set of weights -- values that measure the importance of each and every word.

#normalizing the scores with softmax gives the attention weights
a = torch.softmax(att, dim=-1)

Lastly, we compute the context vector as a weighted sum of the values V.

#the weighted sum of the values V gives the context vectors
context = a @ v
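
Putting the pieces together, here is a minimal, self-contained sketch of single-head dot-product attention in PyTorch. The embedding size, sequence length, and variable names are illustrative assumptions, not the code of any particular model.

import torch
from torch import nn

n_embd = 64                                   #illustrative embedding size
qkv_proj = nn.Linear(n_embd, 3 * n_embd)      #shared projection for Q, K, and V

def attention(x):
    #x has shape (batch, sequence_length, n_embd)
    q, k, v = qkv_proj(x).chunk(3, dim=-1)
    att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)  #similarity scores
    a = torch.softmax(att, dim=-1)                          #attention weights
    return a @ v, a                                         #context vectors, weights

x = torch.randn(1, 10, n_embd)                #a dummy 10-token sequence
context, weights = attention(x)
print(context.shape, weights.shape)           #torch.Size([1, 10, 64]) torch.Size([1, 10, 10])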

Attention became an integral part of processing sequential information, but it was still coupled with recurrent models. While attention mechanisms showed promise for processing sequential data, their integration with recurrent neural networks limited their ability to fully exploit parallelization for faster computation and better scalability.

A simple flowchart of the attention mechanism.

In 2017, Vaswani and his team came up with a solution to address the limitations of recurrent networks. They decided to ditch recurrence entirely and develop a new model called the "Transformer". This model depends only on the attention mechanism.

The code illustrates how the attention mechanism is implemented.

To make attention more effective, they increased the width of the mechanism by creating multiple copies of it, known as attention heads. This, in turn, allowed the model to process the input in parallel.

A flowchart of the attention mechanism
The code above captures a single attention layer together with additional operations, such as a linear layer and a couple of normalization layers. This network is named the Transformer block.
The Transformer architecture is essentially built by widening these blocks and stacking more layers of them.
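
For readers who want to see the shape of such a block in code, here is a minimal sketch of a Transformer block using PyTorch's built-in multi-head attention. The layer sizes, names, and pre-norm arrangement are illustrative assumptions, not the exact design of any specific model.

import torch
from torch import nn

class TransformerBlock(nn.Module):
    def __init__(self, n_embd=64, n_head=4):
        super().__init__()
        #multi-head attention runs several attention "copies" in parallel
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.norm1 = nn.LayerNorm(n_embd)
        self.norm2 = nn.LayerNorm(n_embd)
        #a small feed-forward (linear) sub-network applied after attention
        self.ff = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        #self-attention: the sequence attends to itself, plus a residual connection
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out
        #feed-forward sub-layer, again with a residual connection
        x = x + self.ff(self.norm2(x))
        return x

x = torch.randn(1, 10, 64)          #a dummy 10-token sequence
print(TransformerBlock()(x).shape)  #torch.Size([1, 10, 64])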

Attention via the transformer architecture has now become a cornerstone of modern LLMs such as the GPT series, Gemini series, LLaMA series, Claude series, Mistral series, and many more. In these models, the attention mechanism (particularly self-attention) allows the model to weigh the importance of different words in a sentence relative to each other, enabling a deep understanding of context and semantics.

I built a small attention mechanism and mapped the sentence "Hi, how are you doing, hope everything is well?" into a heat map to show the similarities between the input characters.

Here is the result.

Heat map of the attention weights produced by the network.

The heat map reveals how the model calculated the similarities and generated the attention weights for the given input.
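
If you want to reproduce something similar, here is a rough sketch of how attention weights can be rendered as a heat map. It reuses the attention() function sketched earlier, assumes matplotlib is installed, and fakes a character-level "vocabulary" with a random embedding table purely for illustration.

import torch
from torch import nn
import matplotlib.pyplot as plt

text = "Hi, how are you doing, hope everything is well?"
chars = list(text)

#map each character to an id, then to a random embedding (illustration only)
vocab = sorted(set(chars))
ids = torch.tensor([[vocab.index(c) for c in chars]])
embed = nn.Embedding(len(vocab), 64)
x = embed(ids)                              #shape (1, sequence_length, 64)

_, weights = attention(x)                   #attention weights from the earlier sketch
plt.imshow(weights[0].detach().numpy(), cmap="viridis")
plt.xticks(range(len(chars)), chars)
plt.yticks(range(len(chars)), chars)
plt.colorbar()
plt.show()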

In my example, attention is calculated at the character level. In the real world, LLMs like GPT-4 don't employ this method. They instead rely on token-level attention, where words are broken into subwords and the similarity is calculated between those tokens.

Tokens are the basic units of text that a model uses for processing.
Visualization of subword tokens.
Heat map illustration of similarity scores via attention weights.
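
To make subword tokenization concrete, here is a small example using the tiktoken library. The cl100k_base encoding is the one associated with GPT-4-era models; the exact split shown in the comment is illustrative and may differ slightly between library versions.

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "Hi, how are you doing, hope everything is well?"

ids = enc.encode(text)                   #integer token ids
pieces = [enc.decode([i]) for i in ids]  #the subword string for each id
print(ids)
print(pieces)   #e.g. ['Hi', ',', ' how', ' are', ' you', ' doing', ...]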

As you saw, attention is quite impressive in that it can compress massive amounts of information and retrieve the relevant parts if trained properly. But there is a major drawback: it is computationally expensive.

Here are several reasons:

  • Quadratic Complexity: The standard attention mechanism compares each element of the input sequence with every other element. This results in quadratic complexity, meaning the number of operations grows with the square of the number of elements, which can be extremely taxing for long sequences (see the small sketch after this list).
  • Large Matrix Multiplications: These mechanisms rely on large matrix multiplications to compute attention scores and weighted sums. Such operations are computationally intensive, particularly as the length of the input sequence increases.
  • Memory Usage: Self-attention requires storing large intermediate matrices, including the query, key, and value matrices, along with attention scores. This can lead to high memory consumption, straining hardware resources, especially with long sequences or large batch sizes.
  • Multi-Head Attention: Multi-head attention involves running multiple sets of attention mechanisms in parallel to capture different aspects of the input. This further increases both computational and memory demands, as each head requires its own matrix multiplications and transformations.
  • Backpropagation: Training attention mechanisms requires calculating gradients for all operations, adding to the computational load. Backpropagation through these mechanisms involves complex derivatives, further intensifying the resource requirements.
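
As a quick illustration of the quadratic-complexity point above, the sketch below prints how the attention score matrix grows as the sequence gets longer. The float32 assumption and the sequence lengths are just for illustration.

#the attention score matrix has one entry for every pair of positions: n x n
for n in (1_000, 10_000, 100_000):
    entries = n * n
    megabytes = entries * 4 / 1e6           #assuming 4-byte float32 scores
    print(f"{n:>7} tokens -> {entries:>15,} scores (~{megabytes:,.0f} MB)")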

But, despite the computational expense, attention is used because:

  • Effectiveness: Attention mechanisms significantly improve model performance on various tasks by providing dynamic, context-dependent representations.
  • Scalability: While computationally intensive, attention mechanisms can be efficiently scaled with parallel processing on modern hardware (e.g., GPUs and TPUs).
  • Versatility: Attention is applicable across a wide range of tasks, from natural language processing to image recognition.

So, why are LLMs obsessed with "Attention"?

Simply put, to have larger and better contextual information.

Final thoughts

I hope this edition of the newsletter was informative. I have covered the basics of the attention mechanism without dwelling on many details. But attention itself is a vast subject, and as LLMs grow more powerful, research on more effective attention mechanisms grows as well.

If you have come this far, then please share your thoughts and what you have learned.

PS: As I was writing this article, I felt like documenting the different types of attention mechanisms out there and benchmarking them on their similarity performance.

Faith+Hope+Love


