Advanced Attention Mechanisms-I
I would recommend you go through this blog first to develop the intuition behind the infamous attention mechanism.
ATTENTION
|| paper ||
(quick recap)
It helped solve the bottleneck of having to encode the entire source context in a single fixed-length vector. Each decoder hidden state acts as a query, and the encoder hidden states act as the values.
Each decoder state has direct connections to all encoder states, letting it focus on the most relevant parts of the source sentence: the decoder state is dotted with every encoder state to produce the attention scores.
A crucial part of this attention mechanism is that no additional parameters are needed. In statistical machine translation, a separate model had to be trained for alignment, which is not the case with attention.
THE MATH
Let the encoder hidden states be H_1, H_2, ..., H_N ∈ R^h and the decoder hidden state at time t be S_t ∈ R^h.
At time t, we take the dot product of S_t with every encoder state to get the attention scores: e^t = [S_t^T H_1, ..., S_t^T H_N] ∈ R^N.
Next up, we take the softmax to get the attention distribution: α^t = softmax(e^t) ∈ R^N.
This is used to take the weighted sum of the encoder's hidden states, i.e. “the infamous attention mechanism”: a_t = Σ_{i=1}^{N} α_i^t H_i ∈ R^h.
Finally, we concatenate the attention output a_t with the decoder hidden state S_t: [a_t ; S_t] ∈ R^{2h}.
FOR HACKERS
Here’s a high-level overview of the vanilla (additive) attention mechanism (courtesy of “sooftware’s repo”).
# Shared imports for all the code snippets in this post
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, BoolTensor
from typing import Optional, Tuple


class AdditiveAttention(nn.Module):
    """
    Implementation of the (additive) attention mechanism proposed in the paper,
    "Neural Machine Translation by Jointly Learning to Align and Translate".
    Inputs : Query, Key, Value
    Outputs : Context, Attention vector
    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, hidden_dim: int) -> None:
        super(AdditiveAttention, self).__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        # additive (Bahdanau) scoring: score = v^T tanh(W_k K + W_q Q + b)
        score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        # context: attention-weighted sum of the values
        context = torch.bmm(attn.unsqueeze(1), value)
        return context, attn
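A quick, hypothetical usage sketch (the shapes and variable names below are illustrative assumptions, not from the referenced repo):

# Illustrative usage: one decoder state attending over 5 encoder states
batch, seq_len, hidden_dim = 2, 5, 16
attention = AdditiveAttention(hidden_dim)
decoder_state = torch.randn(batch, 1, hidden_dim)           # query
encoder_states = torch.randn(batch, seq_len, hidden_dim)    # keys and values
context, attn = attention(decoder_state, encoder_states, encoder_states)
print(context.shape, attn.shape)   # torch.Size([2, 1, 16]) torch.Size([2, 5])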
SELF ATTENTION
|| paper ||
All seemed well with attention, but there was an issue: we couldn’t compute multiple decoder states in parallel. Each decoder state had to wait for the previous one to be computed first. (Note in the attention diagram that the recurrent connections are still present.)
This hampered long-range dependency modelling, parallel computation, and the total computational complexity per layer. So, we remove the recurrent connections.
Now, self-attention connects all positions with a constant number of sequentially executed operations, whereas a recurrent layer requires O(n) sequential operations.
This is a single attention layer. Stack multiple such layers (6 in the original paper) to capture more intricate patterns and relations in the text.
FOR HACKERS
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".
    Computes the dot products of the query with all keys, divides each by sqrt(dim),
    and applies a softmax function to obtain the weights on the values.
    Inputs : Query, Key, Value, Mask
    Output : Context, Attention vector
    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # (batch, q_len, dim) x (batch, dim, k_len) -> (batch, q_len, k_len)
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
        if mask is not None:
            # masked positions are set to -inf so they become zero after the softmax
            score.masked_fill_(mask.view(score.size()), -float('Inf'))
        attn = F.softmax(score, dim=-1)
        context = torch.bmm(attn, value)
        return context, attn
Note how the decoder stack uses a mask to prevent positions from attending to subsequent positions. How is this done? The upper-triangular entries of the matrix QK^T are set to negative infinity, which effectively acts as the mask.
Why? Because in the next softmax step those entries become zero.
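As a small illustration (not part of the original snippet), here is one way to build such a causal mask and pass it to the ScaledDotProductAttention module above; the helper name and shapes are assumptions of mine.

# Hypothetical helper: causal (look-ahead) mask; True marks positions to be blocked
def causal_mask(batch_size: int, seq_len: int) -> Tensor:
    return torch.triu(torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool), diagonal=1)

q = k = v = torch.randn(2, 4, 64)
attn_layer = ScaledDotProductAttention(dim=64)
context, attn = attn_layer(q, k, v, mask=causal_mask(2, 4))
# attn[:, i, j] == 0 for every j > i: no position attends to the future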
KV CACHE ATTENTION
"Those who can't remember the past are condemned to repeat it."
~ George Santayana
Folks accustomed to the concept of dynamic programming will find this technique very intuitive.
Note how at every timestep we recompute attention for the tokens that have already been generated, at the cost of repeated computation. At T = 4, we are only interested in the last row of the attention matrix: at each generation step we keep recalculating the same attention for the previous tokens, when all we really need is the attention for the new token.
What can we do in such a situation intuitively?
Simply store the previous outputs and retrieve them instead of recalculating them. This is exactly what the “KV cache” does: by caching the previous keys and values, we only need to compute the attention for the new token.
The only downside is increased VRAM requirements for storing the vectors in the cache memory.
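Here is a minimal, single-head sketch of decoding with a KV cache, assuming one new token arrives per step; the class and attribute names are illustrative, not taken from any of the referenced repos.

# Toy single-head self-attention with a KV cache (illustrative sketch)
class CachedSelfAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.scale = dim ** 0.5
        self.k_cache: list = []   # keys of all previously processed tokens
        self.v_cache: list = []   # values of all previously processed tokens

    def forward(self, x_t: Tensor) -> Tensor:
        # x_t: (batch, 1, dim) -- embedding of the NEW token only
        q = self.q_proj(x_t)
        self.k_cache.append(self.k_proj(x_t))    # compute K, V once and cache them
        self.v_cache.append(self.v_proj(x_t))
        K = torch.cat(self.k_cache, dim=1)        # (batch, t, dim): past + current keys
        V = torch.cat(self.v_cache, dim=1)        # (batch, t, dim): past + current values
        attn = F.softmax(q @ K.transpose(1, 2) / self.scale, dim=-1)   # (batch, 1, t)
        return attn @ V                           # attention output for the new token only

Only the last row of the attention matrix is ever computed; the price, as noted above, is the VRAM needed to hold the cached keys and values.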
SLIDING WINDOW ATTENTION
|| paper ||
Now we have a working mechanism that enables each decoder state to attend to the most relevant encoder states, and we can do the computation for all decoder states simultaneously. But again, researchers found scope to optimize it further.
The self-attention operation scales up quadratically with the sequence length.
Behold the LONGFORMER!!
LONGFORMER
It scales linearly with sequence length (or so they claim) by combining windowed local-context self-attention with task-motivated global attention. Unlike previous approaches, it can process long sequences without truncating or chunking, allowing a much simpler approach that concatenates the available context and processes it in a single pass.
Sliding Window Attention
Longformer’s attention pattern is a sliding window of attention surrounding each token. Given a fixed size window of 2w, each token attends to w tokens on each side. One can use multiple such stacked layers to create a “large receptive field where top layers have access to all input locations and have the capacity to build representations that incorporate information across the entire input.”
So, in a transformer with l layers, the receptive field size at the top layer is: l x w (w is the window size).
Dilated Sliding Window Attention
Often, the window can be dilated to further broaden the receptive field by having gaps of size dilation, d. The receptive field becomes: l x w x d.
A good way to build intuition is to consider its role in multi-head attention: some heads without dilation focus on capturing the local context, while the dilated heads focus on longer-range context, simply because they span a wider stretch of the input.
The Math
effective window size without dilation = w
effective window size with dilation = w + d(w-1)
effective increase in window size = d(w-1)
Clearly, even for small values of d, we get quite large effective windows. So, they use small window sizes for the lower layers and increase the window size in the higher layers. This lets the top layers learn a higher-level representation of the entire sequence while the lower layers capture local information, striking a balance between efficiency & performance.
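To make the pattern concrete, here is a naive sketch that builds a (dilated) sliding-window mask and reuses the ScaledDotProductAttention module above. It still materialises the full n x n score matrix, so it illustrates the pattern but not Longformer's linear-memory banded implementation; the function name and parameters are my own, and I parameterise dilation as the stride between attended positions (d = 1 meaning no dilation), which differs slightly from the gap-based counting above.

# Naive (dilated) sliding-window mask; True marks positions that must NOT be attended to.
# w = tokens attended on EACH side, d = stride between attended positions (d=1: no gaps).
def sliding_window_mask(seq_len: int, w: int, d: int = 1) -> Tensor:
    i = torch.arange(seq_len).unsqueeze(1)    # query positions
    j = torch.arange(seq_len).unsqueeze(0)    # key positions
    offset = j - i
    inside = (offset.abs() <= w * d) & (offset % d == 0)
    return ~inside

seq_len, dim = 8, 64
x = torch.randn(1, seq_len, dim)
mask = sliding_window_mask(seq_len, w=2, d=2).unsqueeze(0)   # (1, seq_len, seq_len)
context, attn = ScaledDotProductAttention(dim)(x, x, x, mask)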
Global Attention
The authors feel windowed and dilated attentions are not flexible enough to learn task-specific representations, so they propose a certain “global attention.”
Linear Projections for Global Attention
They mention using two sets of projections in the paper: Q_s, K_s, V_s to compute attention scores for sliding window attention, and Q_g, K_g, V_g to compute attention scores for global attention.
The implementation requires a form of “banded matrix multiplication” that is not supported out of the box by PyTorch/TensorFlow, which is why the authors provide three custom implementation techniques. Refer to the paper for more details.
And yes, we can again use KV-cache for faster retrieval of previous outputs.
MHA (Multi Head Attention)
|| paper ||
So far, so good! But what if two tokens, t_i & t_j, are correlated in more than one way? Single-head attention won’t be able to capture that.
We go back to the self-attention paper by Google Brain (NeurIPS, 2017). They proposed linearly projecting the queries, keys, and values h times with different learned projections to their respective dimensions. The attention function is then applied in parallel to all these projected versions of the vectors, after which the outputs are concatenated and projected once more to get the final values.
Mathematically, MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V).
Its advantage is that it performs attention across multiple heads independently, capturing all kinds of associations between t_i & t_j.
FOR HACKERS
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need".
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)
    Inputs : Query, Key, Value, Mask
        - Query (batch, q_len, d_model): In the transformer, it comes in three different ways:
            Case 1: from the previous decoder layer
            Case 2: from the input embedding
            Case 3: from the output embedding (masked)
        - Key (batch, k_len, d_model): In the transformer, it comes in three different ways:
            Case 1: from the output of the encoder
            Case 2: from the input embeddings
            Case 3: from the output embedding (masked)
        - Value (batch, v_len, d_model): In the transformer, it comes in three different ways:
            Case 1: from the output of the encoder
            Case 2: from the input embeddings
            Case 3: from the output embedding (masked)
        - Mask: tensor containing indices to be masked
    Outputs : Output, Attention
        - Output (batch, output_len, d_model): tensor containing the attended output features.
        - Attention (batch * num_heads, q_len, k_len): tensor containing the attention (alignment) weights.
    Reference : https://github.com/sooftware/attentions/blob/master/attentions.py
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        # project, then split d_model into num_heads heads of size d_head
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)        # BxK_LENxNxD
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD

        # fold the head dimension into the batch dimension so one bmm handles all heads
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # undo the folding and concatenate the heads back into d_model
        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND
        return context, attn
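A quick, hypothetical sanity check of the module above (the shapes are illustrative assumptions):

# Self-attention: query, key, and value are all the same sequence
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)                  # (batch, seq_len, d_model)
context, attn = mha(x, x, x)
print(context.shape, attn.shape)             # torch.Size([2, 10, 512]) torch.Size([16, 10, 10])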
MQA (Multi Query Attention)
|| paper ||
It’s similar to MHA, except that all the different query heads share the same key and value tensors. The paper claims this greatly reduces the size of these tensors and hence the memory bandwidth requirements of “incremental decoding.”
The KV cache size is reduced by a factor of h, simply because the single “key” & “value” head is reused across all query heads, as the sizing sketch below illustrates.
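As a rough, back-of-the-envelope illustration (the model dimensions below are assumptions of mine, not numbers from the paper):

# KV-cache size: 2 (keys + values) x layers x tokens x kv_heads x d_head x bytes per element
layers, seq_len, n_heads, d_head, fp16_bytes = 32, 4096, 32, 128, 2
mha_cache = 2 * layers * seq_len * n_heads * d_head * fp16_bytes   # every head keeps its own K, V
mqa_cache = 2 * layers * seq_len * 1 * d_head * fp16_bytes         # one K, V head shared by all queries
print(f"MHA: {mha_cache / 2**30:.1f} GiB, MQA: {mqa_cache / 2**20:.1f} MiB per sequence")
# MHA: 2.0 GiB, MQA: 64.0 MiB per sequence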
FOR HACKERS
class MultiQueryAttention(Attention):
    """
    Multi-Query Attention is similar to Multi-Head Attention but reduces KV-cache memory,
    as all the different query heads share the same "key" & "value" tensors.
    Inputs :
        Query (batch, q_len, d_model): can come from multiple sources within the transformer:
            Case 1: the previous decoder layer (in the decoder stack).
            Case 2: the input embedding (in the encoder stack).
            Case 3: the output embedding (masked, for autoregressive decoding).
        Key (batch, k_len, d_model): in MQA the keys share a single projection across all heads and can come from:
            Case 1: the output of the encoder.
            Case 2: the input embeddings (self-attention).
            Case 3: the output embedding (masked, for autoregressive decoding).
        Value (batch, v_len, d_model): like the keys, values are shared across all heads, with sources being:
            Case 1: the output of the encoder.
            Case 2: the input embeddings.
            Case 3: the output embedding (masked).
    Outputs :
        Output (batch, output_len, d_model): tensor containing the attended output features.
    Note : `Attention` is the single-query base class from the reference repo; it provides the
           scaled dot-product `self_attention` method and the original query/key/value projections.
    Reference : https://github.com/knotgrass/attention/blob/main/attn/attention.py
    """
    def __init__(self, word_size: int = 512, embed_dim: int = 64, n_query: int = 8) -> None:
        super().__init__(word_size, embed_dim)
        self.n_query = n_query
        self.proj = nn.Linear(in_features=embed_dim * n_query,
                              out_features=embed_dim, bias=False)
        delattr(self, 'query')  # drop the base class's single query projection ...
        self.querys = nn.ModuleList([  # ... and replace it with one query projection per head
            nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
            for _ in range(n_query)
        ])
        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)

    def forward(self, x: Tensor, mask: Optional[BoolTensor] = None) -> Tensor:
        # single shared key / value projection, reused by every query head
        K = self.key(x)
        V = self.value(x)
        # concatenate the per-head outputs along the feature dimension, then project back to embed_dim
        Z_s = torch.cat([
            self.self_attention(query(x), K, V, mask) for query in self.querys
        ], dim=-1)
        Z = self.proj(Z_s)
        return Z
UPTRAINING
|| paper ||
The paper claims “uptraining” works better than selecting a single key and value head or randomly initializing new key and value heads from scratch. They show that language model checkpoints can be uptrained to use MQA with a small fraction of the original training compute, a cost-effective way to obtain fast multi-query models from existing high-quality MHA checkpoints.
There are two steps: first, the multi-head checkpoint is converted into a multi-query checkpoint (the key and value projection heads are mean-pooled into a single key and value head); second, the converted checkpoint is uptrained, i.e., pre-training is continued for a small fraction of the original training steps so the model adapts to its new structure.
Note how uptrained MQA takes less inference time while performing better than MHA-Large. Refer to the paper for further details on the uptraining steps and a discussion of the results.
GQA (Grouped Query Attention)
|| paper ||
This idea is nothing fancy. It is an adaptation of MHA & MQA by Google Research. Instead of all query heads attending to a single key and value head, the query heads are divided into (let’s say) G groups.
Each group shares a single key head and value head, effectively making each group an MQA. So, GQA is a collection of MQAs.
Intuitively, when converting a multi-head checkpoint to a GQA checkpoint, we construct each group’s key and value head by mean-pooling all the original heads within that group, as sketched below.
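A minimal sketch of that mean-pooling step, assuming the per-head key (or value) projection weights are stacked as (n_heads, d_head, d_model); the function name and weight layout are my own, not the paper's code.

# Hypothetical conversion: mean-pool per-head K/V projection weights into G grouped heads.
# n_groups = 1 recovers MQA; n_groups = n_heads recovers the original MHA.
def mean_pool_kv_heads(kv_weight: Tensor, n_groups: int) -> Tensor:
    n_heads, d_head, d_model = kv_weight.shape
    assert n_heads % n_groups == 0
    grouped = kv_weight.view(n_groups, n_heads // n_groups, d_head, d_model)
    return grouped.mean(dim=1)                                   # (n_groups, d_head, d_model)

k_weight = torch.randn(8, 64, 512)                               # 8 original key heads
gqa_k_weight = mean_pool_kv_heads(k_weight, n_groups=2)          # 2 grouped key heads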
We can observe that the performance of GQA is on par with MHA-XXL, even though it takes way less inference time. Refer to the paper for more details on the experimental results.
So far, we've managed to reduce the amount of KV cache that needs to be stored, but some quality degradation still remains.
In the 2nd part of the blog, we’ll look into ways where we can optimize without performance degradation.
Acknowledgement & References
It is worth noting that the techniques mentioned are fairly recent and don't have many resources available as of today, so pointing out any errors would be greatly appreciated. Given the pace of innovation, it is worth considering that these concepts may become obsolete in a few months. All the concepts mentioned above have been sourced from authentic resources to the best of the author's knowledge. It is recommended not to use the provided code snippets in production; they have been adapted from various sources, and their sole purpose is to provide practical relevance. Thank you.
#genai #llms #attention #transformers #datascience #nlp #research