Difference between Causal and Flash Attention
In the past two weeks, my work has largely revolved around the attention mechanism. In the last edition of my newsletter, I wrote about "Why LLMs are obsessed with 'Attention'?" and concluded that it is all about learning the contextual relationships between words.
I also wrote an extensive article explaining the performance of three different attention mechanisms -- causal attention, flash attention, and sparse attention.
In this edition of the newsletter, I will show you how to code two of these attention mechanisms and explain the difference between them.
But before diving into the attention mechanisms, it is worth understanding how the input data is processed before it is fed into the attention layer.
Preprocessing Input Data
It is important to understand that the input data, or sentences, are first converted into integers: each word in the data is given a unique integer value. These integers are then passed into an embedding layer that creates a dense vector for every word.
In one of my posts, I explained the importance of matrices in modern engineering. You can read about it here.
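To make the word-to-integer step concrete, here is a toy sketch. The sentence and the mapping are made up purely for illustration; real systems use a trained tokenizer rather than a hand-built dictionary.
sentence = "the cat sat on the mat".split()
vocab = {word: idx for idx, word in enumerate(sorted(set(sentence)))}  # hypothetical word-to-ID map
token_ids = [vocab[word] for word in sentence]
print(token_ids)
>>> [4, 0, 3, 2, 4, 1]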
For instance, you can define an embedding layer using the following code:
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 64
n_embd = 512
emb = nn.Embedding(vocab_size, n_embd)
Here, vocab_size refers to the number of unique words in a given document, and n_embd refers to the size of the embedding space (the embedding dimension). The embedding dimension matters a lot: a higher dimension gives the model more capacity to learn, but at the cost of higher computational demands.
Embedding space refers to the idea of representing an integer (an individual word) with a vector of real numbers.
Assuming I have data with 64 unique words and I want to represent a word X (its integer ID) as a vector of 512 elements, I would write:
X = torch.tensor([7])  # integer ID of a single word (arbitrary example)
y = emb(X)
print(y)
>>> tensor([ 1.6937e+00, -4.5502e-01, 7.9823e-01, 1.3988e+00, 9.5857e-01,
1.6602e+00, -2.2901e-01, 1.1099e+00, -7.8383e-01, 2.2650e-01,
-9.3003e-01, -1.0562e+00, -6.6863e-01, 1.7226e+00, -1.2740e+00,
...
3.6767e-01, 4.0440e-01], grad_fn=<EmbeddingBackward0>)
print(y.shape)
>>> torch.Size([1, 512])
Embedding layers allow us to generate dense, relatively low-dimensional vectors (compared to one-hot encodings). This makes the data more manageable and computationally efficient, and it helps capture semantic similarity between words.
Combined with the attention layers that follow, embeddings also let the model capture contextual information, meaning the same word can end up with a different representation depending on its context in the sentence.
Causal Attention
Now that we have the embeddings ready, we can feed them into the causal attention layer. But there is a particular way these embeddings need to be prepared first.
Let's start with the causal attention mechanism, one of the simplest attention mechanisms. It begins with a Linear layer that takes the output of the embedding layer as its input.
First of all, you must understand that the input sequence is projected into three copies corresponding to the query, key, and value; this is what allows the attention mechanism to attend each word to every other word. You can write this as:
qkv = nn.Linear(n_embd, 3*n_embd)
As you can see, the output dimension of the linear layer is multiplied by 3 so that a single projection produces the query, key, and value.
Now, when we pass the output of the embedding layer through this first layer of attention, we get:
project_qkv = qkv(y)
print(project_qkv.shape)
>>> torch.Size([1, 1536])
After that, we can split the projection into the individual Q, K, and V.
A simple way to split it is to use the .chunk method, shown in the short sketch below. However, this flat split becomes awkward once we add batch and attention-head dimensions. A more scalable alternative is to reshape the projection into the shape (batch, sequence, 3, heads, head_dim).
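For reference, a minimal sketch of the .chunk split on the flat projection from above (before any head dimensions are introduced) could look like this:
q_flat, k_flat, v_flat = project_qkv.chunk(3, dim=-1)  # split the last dimension into three equal parts
print(q_flat.shape)
>>> torch.Size([1, 512])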
After reshaping, we reorder the qkv projection before splitting it into the individual projections. This can be done with the .permute(2, 0, 3, 1, 4) method.
The entire step looks like this:
B = 1
T, C = y.size()
n_head = 8
d_k = n_embd // n_head
project_qkv = project_qkv.reshape(B, T, 3, n_head, d_k)
print(project_qkv.shape)
>>> torch.Size([1, 1, 3, 8, 64])
project_qkv = project_qkv.permute(2, 0, 3, 1, 4)
print(project_qkv.shape)
>>> torch.Size([3, 1, 8, 1, 64])
q, k, v = project_qkv[0], project_qkv[1], project_qkv[2]
Now that we have q, k, and v, we can start calculating the attention scores by following the scaled dot-product attention steps below.
The first step is to compute a scaled dot product between q and k.
attn_scores = (q @ k.transpose(-2, -1))
scale = d_k ** -0.5
attn_scores = attn_scores * scale
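One detail this walkthrough skips, but which the full module below includes, is the causal mask: before the softmax, every position that lies in the future is set to negative infinity so each token can only attend to tokens at or before its own position (this is what makes the attention "causal"). A minimal sketch of that step, with block_size as an assumed maximum sequence length, looks like this:
block_size = 1024  # assumed maximum sequence length for the mask
mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
attn_scores = attn_scores.masked_fill(mask[:, :, :T, :T] == 0, float('-inf'))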
Next, we apply the softmax function to turn the scores into attention probabilities:
attn_probs = F.softmax(attn_scores, dim=-1)
Lastly, we compute the final dot product between the attention probabilities and v.
attn_output = (attn_probs @ v)
print(attn_output.shape)
>>> torch.Size([1, 8, 1, 64])
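Taken together, the three steps above implement the standard scaled dot-product attention formula:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V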
Before moving ahead, we need to rearrange the output shape back to the input shape.
attn_output = attn_output.transpose(1, 2)
print(attn_output.shape)
>>> torch.Size([1, 1, 8, 64])
attn_output = attn_output.reshape(B, T, C)
print(attn_output.shape)
>>> torch.Size([1, 1, 512])
Below is the overall code for the causal attention module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.d_k = config.n_embd // config.n_head
        self.scale = self.d_k ** -0.5
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # lower-triangular mask so each position can only attend to itself and earlier positions
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        # project to q, k, v and split into heads: shape (3, B, n_head, T, d_k)
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.n_head, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_scores = attn_scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)
        attn_output = (attn_probs @ v).transpose(1, 2).reshape(B, T, C)
        attn_output = self.resid_dropout(self.out_proj(attn_output))
        return attn_output
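To try the module, you need a small config object carrying the fields referenced above. Here is a hypothetical sketch; the field values are arbitrary and only meant for illustration.
from dataclasses import dataclass

@dataclass
class Config:  # hypothetical config with the fields the module expects
    n_embd: int = 512
    n_head: int = 8
    dropout: float = 0.1
    bias: bool = True
    block_size: int = 1024

attn = CausalSelfAttention(Config())
x = torch.randn(2, 16, 512)  # (batch, sequence, embedding)
print(attn(x).shape)
>>> torch.Size([2, 16, 512])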
Now let's move on to Flash Attention.
Flash Attention
Flash Attention is a more advanced attention mechanism designed for efficiency and performance. Most of the parts are similar to causal attention, with a few exceptions. The initial step of projecting the embeddings into q, k, and v remains the same, so assuming we have obtained q, k, and v, we move on to the remaining parts.
We will create a tensor of the same shape and type as q, but filled with zeros. This will allow us to accumulate results across multiple blocks of the attention computation.
output = torch.zeros_like(q)
Initializing the output tensor with zeros also provides a neutral starting point for adding each block's results without interference from pre-existing values.
Additionally, defining output tensors helps in efficient memory management. It avoids the overhead of dynamic resizing during computations. This approach enables faster and more efficient processing.
Now we enter the most important part of flash attention -- Block processing.
Block Processing - Outer Loop
We start by processing the input in blocks for the queries. We create an outer for loop that iterates over the sequence length in steps of block_size.
The block_size refers to the size of the smaller segments, or "blocks", into which the input sequence is divided for processing. You can think of it as a small window that the computation slides over the sequence.
For instance, if you have a sequence length of 512, you can select a block size of 64, which means the outer loop runs 512 / 64 = 8 times.
We use this outer loop to extract the current block of queries into q_block.
# x is the original input of shape (batch, seq_len, n_embd); q, k, v have been
# rearranged to shape (batch * heads, seq_len, head_dim), as in the module below
for i in range(0, x.size(1), block_size):
    i_end = min(i + block_size, x.size(1))
    q_block = q[:, i:i_end]
    m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)
    l = torch.zeros((q.shape[0], i_end - i), device=q.device)
Next, we initialize m and l for numerical stability and accumulation. m holds the running maximum of the attention scores for each row (initialized to negative infinity), and l accumulates the sum of the exponentials of the attention scores (initialized to zero).
Inner Block
We then process the input in the same context window for the keys and values using the inner loop.
The inner loop iterates over the sequence length in steps of block_size. For each iteration, j_end defines the end index for the current block of keys and values, which are extracted into k_block and v_block.
for j in range(0, x.size(1), block_size):
    j_end = min(j + block_size, x.size(1))
    k_block = k[:, j:j_end]
    v_block = v[:, j:j_end]
Next, we compute the attention scores.
The interesting thing about this approach is that each q_block attends to all the k_blocks inside the inner loop.
attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
We now exponentiate the attention scores in a numerically stable way.
If you remember, in causal attention we applied a softmax directly. Here, the softmax is computed incrementally: we first take the row-wise maximum of attn_block and combine it with the running maximum m, then exponentiate the scores after subtracting this new maximum.
m_new updates m with the largest attention score seen so far (for numerical stability), and exp_attn holds the exponentials of the adjusted attention scores.
m_new = torch.maximum(m, attn_block.max(dim=-1)[0])
exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))
Next, we accumulate the exponentials of the attention scores.
l_new updates l with the running sum of the exponentials (rescaling the previous sum whenever the running maximum changes), and output_block is the current block's contribution, obtained by multiplying the exponential attention scores with the value block. This lets us keep track of the running maximums and row sums.
l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)
output_block = torch.matmul(exp_attn, v_block)
We then update the output tensor for the current query block: the previously accumulated output is rescaled by exp(m - m_new) to account for the new running maximum, the current block's contribution is added, and m and l are carried forward to the next iteration. Once the inner loop over the key/value blocks finishes, the accumulated output for this query block is normalized by the final row sums.
output[:, i:i_end] = output[:, i:i_end] * torch.exp(m - m_new).unsqueeze(-1) + output_block
m, l = m_new, l_new
# after the inner loop over j completes:
output[:, i:i_end] /= l.unsqueeze(-1)
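If this running-max bookkeeping feels abstract, here is a tiny self-contained sketch with made-up numbers. It shows that processing one row of scores in two chunks this way gives the same result as a plain softmax over the whole row:
import torch

scores = torch.tensor([[1.0, 3.0, 0.5, 2.0]])         # toy attention scores for a single query row
values = torch.tensor([[1.0], [2.0], [3.0], [4.0]])   # toy values, one per key

ref = torch.softmax(scores, dim=-1) @ values          # reference: plain softmax over the full row

m = torch.tensor([float('-inf')])                     # running maximum
l = torch.tensor([0.0])                               # running sum of exponentials
acc = torch.zeros(1, 1)                               # running (unnormalized) output
for s_chunk, v_chunk in zip(scores.split(2, dim=-1), values.split(2, dim=0)):
    m_new = torch.maximum(m, s_chunk.max(dim=-1)[0])
    p = torch.exp(s_chunk - m_new.unsqueeze(-1))
    l = l * torch.exp(m - m_new) + p.sum(dim=-1)
    acc = acc * torch.exp(m - m_new).unsqueeze(-1) + p @ v_chunk
    m = m_new
print(torch.allclose(acc / l.unsqueeze(-1), ref))
>>> True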
Finally, we project the output back to the original embedding dimension using another linear layer.
from einops import rearrange  # einops is used to merge the heads back into the embedding dimension
output = rearrange(output, '(b h) t d -> b t (h d)', h=n_head)
proj = nn.Linear(n_embd, n_embd)
output = proj(output)
So that was flash attention. I have written the overall code below.
import torch
import torch.nn as nn
from einops import rearrange

class SimpleFlashAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout_p = config.dropout
        self.causal = True  # assuming causal for a GPT-like model
        self.block_size = config.block_size  # size of blocks for tiling, now configurable

    def forward(self, x):
        b, t, c = x.size()
        qkv = self.qkv(x).view(b, t, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        # fold the heads into the batch dimension: (b * h, t, head_dim)
        q, k, v = [rearrange(x, 'b t h d -> (b h) t d') for x in (q, k, v)]
        output = torch.zeros_like(q)
        for i in range(0, t, self.block_size):
            i_end = min(i + self.block_size, t)
            q_block = q[:, i:i_end]
            m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)
            l = torch.zeros((q.shape[0], i_end - i), device=q.device)
            for j in range(0, t, self.block_size):
                if self.causal and j > i:
                    # every key in this block lies in the future of every query in q_block; skip it
                    continue
                j_end = min(j + self.block_size, t)
                k_block = k[:, j:j_end]
                v_block = v[:, j:j_end]
                attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * self.scale
                if self.causal:
                    # mask entries where the key position (j + col) is ahead of the query position (i + row)
                    causal_mask = torch.triu(torch.ones(i_end - i, j_end - j, dtype=torch.bool, device=attn_block.device), diagonal=i - j + 1)
                    attn_block.masked_fill_(causal_mask, float('-inf'))
                m_new = torch.maximum(m, attn_block.max(dim=-1)[0])
                exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))
                l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)
                output_block = torch.matmul(exp_attn, v_block)
                # rescale the previously accumulated output and add the new block's contribution
                output[:, i:i_end] = output[:, i:i_end] * torch.exp(m - m_new).unsqueeze(-1) + output_block
                m, l = m_new, l_new
            # normalize the accumulated output for this query block by the final row sums
            output[:, i:i_end] /= l.unsqueeze(-1)
        output = rearrange(output, '(b h) t d -> b t (h d)', h=self.n_head)
        return self.proj(output)
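As a quick sanity check of the shapes, you can instantiate it with the same kind of hypothetical Config used earlier; note that here block_size plays the role of the tile size rather than a maximum sequence length.
cfg = Config(n_embd=512, n_head=8, dropout=0.0, bias=True, block_size=64)  # hypothetical config; block_size is the tile size
flash = SimpleFlashAttention(cfg)
x = torch.randn(2, 256, 512)  # (batch, sequence, embedding)
print(flash(x).shape)
>>> torch.Size([2, 256, 512])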
Keep in mind that this is a basic, minimal version of Flash Attention, written to teach you the core concepts.
The Core Idea Behind Flash Attention
What is the advantage of using flash attention over causal attention?
Let's say you have a long document and you need to summarize it. Instead of reading the entire document at once, you read and summarize it paragraph by paragraph. This is similar to how Flash Attention works – it breaks down a long sequence into smaller blocks (paragraphs), processes each block independently, and then combines the results. This approach is more efficient and manageable, especially for very long documents (or sequences).
Conclusion
Flash Attention exemplifies the cutting-edge in efficiency and performance, particularly for handling long sequences. By breaking down these sequences into smaller, manageable blocks and leveraging parallel processing, Flash Attention ensures both numerical stability and computational efficiency. This approach is akin to summarizing a lengthy document paragraph by paragraph, providing a more scalable and practical solution for real-world applications.
In contrast, while Causal Attention lays the foundation by introducing the basic concepts and operations, Flash Attention goes a step further by optimizing and refining these processes. Both mechanisms have their unique advantages, and understanding their differences equips you with the knowledge to choose the right tool for your specific needs.
If you have come this far, stay tuned for the next edition of Peceive, where I simplify complex AI concepts.