Advanced Attention Mechanisms-I


I would recommend you go through this blog first to develop the intuition behind the infamous attention mechanism.


ATTENTION

|| paper ||

(quick recap)

It helped solve the bottleneck of having to encode the entire context of the text in a fixed-length vector. Each decoder hidden state acts as a query, and the encoder hidden states act as the keys and values.

Each decoder state has direct connections to all encoder states, allowing it to focus on a particular part of the source sentence: the decoder state is dotted with every encoder state to produce an alignment score.

Attention mechanism overview (NMT: Hindi -> English) (image by author)

A crucial property of this basic (dot-product) attention is that it introduces no additional parameters. In statistical machine translation models, we had to use a separate model for alignment, which is not the case with attention.

THE MATH

Let the encoder hidden states be H_1, H_2, ..., H_n ∈ R^h, and the decoder hidden state at time t be S_t ∈ R^h.

At time t,

Attention Score, e_t at time t (image by author)

Next up, we take the softmax to get the attention distribution,

Attention Distribution (image by author)

This can be used to get the weighted sum of the encoder's hidden states, i.e. “the infamous attention mechanism”,

Attention Output (image by author)

Finally, we concatenate the attention output, a_t, with the decoder hidden state S_t.
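Written out with the notation above (this reconstructs the equations the figures depict):

e_t = [ S_t^T H_1, S_t^T H_2, ..., S_t^T H_n ] ∈ R^n        (attention scores)
α_t = softmax(e_t) ∈ R^n                                    (attention distribution)
a_t = Σ_{i=1}^{n} α_{t,i} H_i ∈ R^h                         (attention output)
[a_t ; S_t] ∈ R^{2h}                                        (concatenated vector passed on to the decoder)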

FOR HACKERS

Here’s a high-level overview of the vanilla attention mechanism (courtesy of “sooftware’s repo”).

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class AdditiveAttention(nn.Module):
    """
    Implementation of the (additive) attention mechanism proposed in the paper,
    "Neural Machine Translation by Jointly Learning to Align and Translate".

    Inputs : Query, Key, Value
    Outputs : Context, Attention vector
    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, hidden_dim: int) -> None:
        super(AdditiveAttention, self).__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        # Additive (Bahdanau-style) alignment scores: (batch, seq_len)
        score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        # Attention-weighted sum of the values: (batch, 1, hidden_dim)
        context = torch.bmm(attn.unsqueeze(1), value)
        return context, attn

SELF ATTENTION

|| paper ||

All seemed well with attention, but there was an issue. We couldn’t do the computation for multiple decoder states in parallel. We had to wait for the previous states to get computed first. (Note in the attention diagram that the recurrent connections are still present.)

This hampered long-range dependency modeling, parallel computation, and the total computational complexity per layer. So, we remove the recurrent connections.

Now, self-attention connects all positions with a constant number of sequentially executed operations, compared to a recurrent layer requiring O(n) sequential operations.

Self-attention (image by author)

This is a single attention layer. Stack multiple such layers (6 in the original paper) to capture further intricate patterns and relations in the text.

FOR HACKERS

from typing import Optional

import numpy as np
# (torch, nn, F, Tensor and Tuple are imported in the previous snippet)


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".
    Computes the dot products of the query with all keys, divides each by sqrt(dim),
    and applies a softmax function to obtain the weights on the values.

    Inputs : Query, Key, Value, Mask
    Outputs : Context, Attention vector
    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Scaled dot products between queries and keys: (batch, q_len, k_len)
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            # Masked (True) positions get -inf so that softmax assigns them zero weight
            score.masked_fill_(mask.view(score.size()), -float('Inf'))

        attn = F.softmax(score, dim=-1)
        context = torch.bmm(attn, value)
        return context, attn

Note how the decoder stack uses a mask to prevent positions from attending to subsequent positions. How do we do it? The upper-triangular entries of the matrix QK^T are set to negative infinity, which effectively acts as a mask.

Why so? Because in the next softmax step...

softmax of masked entries (image by author)

... those entries become zero.
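To tie this back to the snippet above, here is a minimal usage sketch of ScaledDotProductAttention with a causal (upper-triangular) mask; the batch size, sequence length, and dimension are toy values chosen purely for illustration.

import torch

batch, seq_len, dim = 2, 5, 16
q = torch.randn(batch, seq_len, dim)
k = torch.randn(batch, seq_len, dim)
v = torch.randn(batch, seq_len, dim)

# True above the diagonal: these positions get -inf scores and hence zero attention weight
causal_mask = torch.triu(torch.ones(batch, seq_len, seq_len, dtype=torch.bool), diagonal=1)

attention = ScaledDotProductAttention(dim)
context, attn = attention(q, k, v, mask=causal_mask)
# attn[b, i, j] is zero for every j > i, i.e. no position attends to the future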


KV CACHE ATTENTION

"Those who can't remember the past are condemned to repeat it." 
                                                         ~ George Santayana        

Folks accustomed to the concept of dynamic programming will find this technique very intuitive.

Unnecessary previous token calculation at each step at T = 4 (image by author)

Note how at every timestep we recompute attention for tokens that have already been generated, paying the cost of repeated computation. At T = 4, we are only interested in calculating the last row. Effectively, at each generation step we recalculate the same attention over the previous tokens, when all we actually need is the attention for the new token.

Repeated Computation visualization

What can we do in such a situation intuitively?

Simply store the previously computed outputs and retrieve them instead of recalculating them. This is exactly what the “KV cache” does: by caching the previous keys and values, we only need to compute the attention for the new token.

KV caching

The only downside is increased VRAM requirements for storing the vectors in the cache memory.
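To make the idea concrete, here is a minimal single-head sketch of a KV cache (a hypothetical helper class, not any particular library's API):

import torch

class KVCache:
    # Append-only cache of the keys and values of one attention layer (single head).
    def __init__(self):
        self.keys = None    # (batch, tokens_so_far, dim)
        self.values = None  # (batch, tokens_so_far, dim)

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, 1, dim) -- projections of the newest token only
        if self.keys is None:
            self.keys, self.values = k_new, v_new
        else:
            self.keys = torch.cat([self.keys, k_new], dim=1)
            self.values = torch.cat([self.values, v_new], dim=1)
        return self.keys, self.values

# At generation step t we compute q_t, k_t, v_t for the new token only:
#     k, v    = cache.update(k_t, v_t)                   # (batch, t, dim)
#     score   = q_t @ k.transpose(1, 2) / dim ** 0.5     # (batch, 1, t) -- just the last row
#     context = score.softmax(dim=-1) @ v                # (batch, 1, dim)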


SLIDING WINDOW ATTENTION

|| paper ||

Now we have a working mechanism that lets each decoder state attend to the most relevant encoder states, and we can do the computation for all decoder states simultaneously. But again, researchers found scope to optimize it further.

The self-attention operation scales up quadratically with the sequence length.

Behold the LONGFORMER!!

LONGFORMER

It scales linearly with sequence length (or so they claim) by combining windowed local-context self-attention with task-motivated global attention. Unlike previous approaches, it can process long sequences without truncating or chunking, allowing a much simpler approach that concatenates the available context and processes it in a single pass.

Sliding Window Attention

Longformer’s attention pattern is a sliding window of attention surrounding each token. Given a fixed size window of 2w, each token attends to w tokens on each side. One can use multiple such stacked layers to create a “large receptive field where top layers have access to all input locations and have the capacity to build representations that incorporate information across the entire input.”

So, in a transformer with l layers, the receptive field size at the top layer is: l x w (w is the window size).

Sliding Window attention

Dilated Sliding Window Attention

Often, the window can be dilated to further broaden the receptive field by having gaps of size dilation, d. The receptive field becomes: l x w x d.

A good way to build the intuition is to consider its role in multi-head attention: heads without dilation focus on capturing the local context, while heads with dilation capture longer-range context, simply because they span a wider window.

Dilated Sliding Window attn.
The Math

effective window size without dilation = w

effective window size with dilation = w + d(w-1)

effective increase in window size = d(w-1)

Clearly, even for small values of d, the effective window becomes quite large. So, the authors use small window sizes in the lower layers and increase the window size towards the higher layers. This lets the top layers learn a higher-level representation of the entire sequence while the lower layers capture local information, striking a balance between efficiency & performance.
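For intuition, here is one way to realize the (optionally dilated) sliding-window pattern as a dense boolean mask, using the same masking convention as the ScaledDotProductAttention sketch earlier (True = blocked). Note that a dense mask does not give the linear-time behaviour; the actual implementation relies on custom banded kernels, as discussed below.

import torch

def sliding_window_mask(seq_len: int, w: int, d: int = 1) -> torch.Tensor:
    # Each token i may attend to tokens j with |i - j| <= w * d and (i - j) a multiple of d,
    # i.e. w positions on each side, spaced d apart (d = 1 is the plain sliding window).
    i = torch.arange(seq_len).unsqueeze(1)   # (seq_len, 1)
    j = torch.arange(seq_len).unsqueeze(0)   # (1, seq_len)
    offset = (i - j).abs()
    allowed = (offset <= w * d) & (offset % d == 0)
    return ~allowed                          # expand to (batch, seq_len, seq_len) before use

# sliding_window_mask(8, w=2) blocks everything outside a band of width 2 on each side;
# sliding_window_mask(8, w=2, d=2) keeps every other position inside a band of width 4.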

Global Attention

The authors feel windowed and dilated attention is not flexible enough to learn task-specific representations, so they propose a certain “global attention.”

Linear Projections for Global Attention

They mention using two sets of projections in the paper: Q_s, K_s, V_s to compute attention scores for sliding window attention, and Q_g, K_g, V_g to compute attention scores for global attention.

Global + Sliding Window attention

The implementation requires a form of “banded matrix multiplication” that is not supported in off-the-shelf PyTorch / TensorFlow, which is why the authors came up with three custom implementation techniques. Refer to the paper for more details on these.
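Sticking with the toy dense-mask picture from the sliding-window sketch above (again, for intuition only, not the banded kernels the paper uses), global positions can be carved out like this; global_idx is a hypothetical list of task tokens (e.g. a [CLS]-style token) that attend, and are attended to, everywhere.

def global_plus_window_mask(seq_len: int, w: int, global_idx) -> torch.Tensor:
    # Reuses sliding_window_mask() from the previous sketch; True still means "blocked".
    mask = sliding_window_mask(seq_len, w)
    for g in global_idx:
        mask[g, :] = False   # the global token attends to every position
        mask[:, g] = False   # every position attends to the global token
    return mask

# e.g. global_plus_window_mask(512, w=64, global_idx=[0])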

And yes, we can again use KV-cache for faster retrieval of previous outputs.

Sliding Window attention KV cache (image by author)

MHA (Multi Head Attention)

|| paper ||

So far, so good! But what if two tokens, t_i & t_j, are correlated in more than one way? Single-head attention won’t be able to capture that.

We go back to the self-attention paper by Google Brain (NeurIPS, 2017). They proposed linearly projecting the queries, keys, and values h times with different learned projections to their respective dimensions. The attention function is then performed in parallel on all of these projected versions, after which the outputs are concatenated and projected once more to get the final values.

Mathematically,

MHA math
MHA visualization

Its advantage is that attention is performed across multiple heads independently, taking into account all kinds of associations between t_i & t_j.

FOR HACKERS

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need"
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Inputs : Query, Key, Value, Mask
        - Query (batch, q_len, d_model): In transformer, three different ways:
            Case 1: come from previous decoder layer
            Case 2: come from the input embedding
            Case 3: come from the output embedding (masked)

        - Key (batch, k_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)

        - Value (batch, v_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)

        - Mask (-): tensor containing indices to be masked

    Outputs : Output, Attention
        - Output (batch, output_len, dimensions): tensor containing the attended output features.
        - Attention (batch * num_heads, q_len, k_len): tensor containing the attention (alignment) weights over the encoder outputs.

    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD

        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND

        return context, attn        

MQA (Multi Query Attention)

|| paper ||

It’s similar to MHA; just that all the different query heads share the same key and value tensors. The paper claims it greatly reduces the size of these tensors and hence the memory bandwidth requirements of “incremental decoding.”

The KV cache size is reduced by a factor of h, simply because the “key” & “values” heads are getting reused.

MQA

FOR HACKERS

# Note: `Attention` here is the base class from the referenced repo (knotgrass/attention);
# it is expected to provide `query`, `key`, `value` projections and a
# `self_attention(query, key, value, mask)` helper. `BoolTensor` comes from `torch`.
class MultiQueryAttention(Attention):
    """
    Multi-Query Attention is similar to Multi-Head Attention but it reduces KV cache memory
    as all the different heads use the same "key" & "value" tensors.

    Inputs : Query 
                  (batch, q_len, d_model): This can come from multiple sources within the transformer architecture:
                  
                  Case 1: From the previous decoder layer (in the decoder stack).
                  Case 2: From the input embedding (in the encoder stack).
                  Case 3: From the output embedding (masked for autoregressive decoding).

             Key 
                  (batch, k_len, d_model): In MQA, the keys share the same projection across all heads and can come from:
                  
                  Case 1: The output of the encoder.
                  Case 2: The input embeddings (self-attention).
                  Case 3: The output embedding (masked for autoregressive decoding).
            
             Value 
                  (batch,v_len,d_model): Similar to keys, values are shared across all heads, with sources being:
                  
                  Case 1: The output of the encoder.
                  Case 2: The input embeddings.
                  Case 3: The output embedding (masked).
            
            
     Outputs : 
            Output 
                  (batch,output_len,d_model): A tensor containing the attended output features, where the attention mechanism has aggregated relevant information.
            
            Attention 
                  (batch×num_heads,v_len): A tensor containing the attention weights (alignment) across the shared key-value pairs from the encoder outputs, representing the focus of each query across the input sequence. 
    
      Reference : https://github.com/knotgrass/attention/blob/main/attn/attention.py
    """
    def __init__(self, word_size: int = 512, embed_dim: int = 64, n_query:int=8) -> None:
        super().__init__(word_size, embed_dim)
        self.n_query = n_query
        self.proj = nn.Linear(in_features=embed_dim * n_query,
                              out_features=embed_dim, bias=False)
        delattr(self, 'query')
        self.querys = nn.ModuleList([
            nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
            for _ in range(n_query)
        ])
        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)

    def forward(self, x: Tensor, mask:Optional[BoolTensor]=None) -> Tensor:
        K = self.key(x)
        V = self.value(x)
        Z_s = torch.cat([
            self.self_attention(query(x), K, V, mask) for query in self.querys
        ], dim=1)
        Z = self.proj(Z_s)
        return Z        

UPTRAINING

|| paper ||

The paper claims “uptraining” works better than selecting a single key and value head or randomly initializing new key or value heads from scratch. They show that language model checkpoints can be uptrained to use MQA with a small fraction of the original training compute. This is a cost-effective method that helps obtain fast multi-query and high-quality MHA checkpoints.

There are two steps:

  1. First, they convert the checkpoint: the key and value projection matrices of all heads are mean pooled (averaged) into single key and value projection matrices (a rough sketch follows below).
  2. Then, they continue pre-training for a small fraction of the original compute so that the weights adapt to the new structure.
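As a rough sketch of step 1 (illustrative shapes only, not the paper's code), mean pooling the per-head key and value projection matrices looks like this:

import torch

num_heads, d_model, d_head = 8, 512, 64

# MHA checkpoint: one (d_model, d_head) key/value projection per head
k_proj_per_head = [torch.randn(d_model, d_head) for _ in range(num_heads)]
v_proj_per_head = [torch.randn(d_model, d_head) for _ in range(num_heads)]

# MQA checkpoint: a single shared projection, obtained by averaging over the heads
k_proj_shared = torch.stack(k_proj_per_head, dim=0).mean(dim=0)  # (d_model, d_head)
v_proj_shared = torch.stack(v_proj_per_head, dim=0).mean(dim=0)  # (d_model, d_head)

# Step 2 (not shown): continue pre-training for a small fraction of the original compute
# so the rest of the network adapts to the shared key/value projections.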

Mean Pooling key and value heads
Performance vs Time comparison

Note how uptrained MQA takes less inference time even though it performs better than MHA-large. Refer to the paper for further details on uptraining steps and results discussion.


GQA (Grouped Query Attention)

|| paper ||

This idea is nothing fancy. It is an adaptation of MHA & MQA by Google Research. Instead of all query heads attending to a single key and value head, the query heads are divided into (let’s say) G groups.

Each group shares a single key head and value head, effectively making each group an MQA. So, GQA is a collection of MQAs: with G = 1 it reduces to MQA, and with G equal to the number of query heads it recovers MHA.

By intuition, while converting a multi-head checkpoint to a GQA checkpoint, we construct each group’s key and value head by mean pooling all the original heads within that group.

GQA
Performance vs Time comparison

We can observe that the performance of GQA is on par with MHA-XXL, even though it takes way less inference time. Refer to the paper for more details on the experimental results.
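Here is a minimal sketch of grouped-query scoring, assuming toy shapes and a simple repeat of each group's shared key/value head across its query heads (illustration only, not the paper's implementation):

import torch

batch, seq_len, d_head = 2, 6, 64
num_heads, num_groups = 8, 2                 # 4 query heads per group
heads_per_group = num_heads // num_groups

q = torch.randn(batch, num_heads, seq_len, d_head)   # one projection per query head
k = torch.randn(batch, num_groups, seq_len, d_head)  # one shared key head per group
v = torch.randn(batch, num_groups, seq_len, d_head)  # one shared value head per group

# Line each group's K/V up with the query heads belonging to that group
k_exp = k.repeat_interleave(heads_per_group, dim=1)  # (batch, num_heads, seq_len, d_head)
v_exp = v.repeat_interleave(heads_per_group, dim=1)

score = q @ k_exp.transpose(-2, -1) / d_head ** 0.5  # (batch, num_heads, seq_len, seq_len)
attn = score.softmax(dim=-1)
context = attn @ v_exp                               # (batch, num_heads, seq_len, d_head)

# num_groups = 1 recovers MQA, num_groups = num_heads recovers MHA, and the KV cache
# shrinks by a factor of num_heads / num_groups.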


So far, we've been able to reduce how much KV cache we need to store, but some quality degradation still remains.

In the 2nd part of the blog, we’ll look into ways where we can optimize without performance degradation.


Acknowledgement & References

It is worth noting that the techniques mentioned are quite recent and do not yet have many resources covering them, so your understanding regarding any inadvertent errors is greatly appreciated. Given the pace of innovation, these concepts might well become obsolete in a few months. All the concepts mentioned above have been sourced from authentic resources to the best of the author's knowledge. It is recommended not to use the provided code snippets in production; they have been adapted from various sources, and their sole purpose is to provide practical relevance. Thank you.

Laboratory for Computational Social Systems (LCS2)

Sourish Dasgupta

Google Research

#genai #llms #attention #transformers #datascience #nlp #research


