Difference between Causal and Flash Attention
The overall architecture of the Transformer model. | Source: Attention? Attention!

In the past two weeks, my work largely revolved around the attention mechanism. In the last edition of my newsletter, I wrote about "Why LLMs are obsessed with 'Attention'?" and concluded that it is all about learning the contextual relationships between words.

I also wrote an extensive article explaining the performance of three different attention mechanisms -- causal attention, flash attention, and sparse attention.

In this edition of the newsletter, I will show you how to code two attention mechanisms and explain the difference between them.

But before diving into the attention mechanism, it is worth understanding how the input data is processed before it is fed into the attention layers.

Preprocessing Input Data

It is important to understand that the input data or sentences are converted into integers. Essentially, each word in the data is given a unique integer value. This value is then passed into an embedding layer that creates a dense vector for every word.
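For instance, here is a minimal sketch of this idea; the sentence and vocabulary below are made up purely for illustration:

sentence = "the cat sat on the mat"   # hypothetical example sentence
vocab = {word: idx for idx, word in enumerate(sorted(set(sentence.split())))}   # word -> unique integer id
token_ids = [vocab[word] for word in sentence.split()]

print(token_ids)
>>> [4, 0, 3, 2, 4, 1]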

In one of my posts, I explained the importance of matrices in modern engineering. You can read about it here.

For instance, you can define an embedding layer using the following code:

import torch
import torch.nn as nn

vocab_size = 64   # number of unique words in the vocabulary
n_embd = 512      # embedding dimension
emb = nn.Embedding(vocab_size, n_embd)

Here, vocab_size refers to the number of unique words in a given document and n_embd refers to the size of the embedding space. The embedding space is very important: a higher dimension corresponds to better learning capacity, but at the cost of higher computational demands.

Embedding space refers to representing an integer (an individual word) with a dense vector of real numbers.

Assuming I have data with 64 unique words and I want to represent a word X (its integer id) as a vector of 512 elements, I would write:

X = torch.tensor([5])   # hypothetical integer id of one word
y = emb(X)

print(y)
>>> tensor([ 1.6937e+00, -4.5502e-01,  7.9823e-01,  1.3988e+00,  9.5857e-01,
         1.6602e+00, -2.2901e-01,  1.1099e+00, -7.8383e-01,  2.2650e-01,
        -9.3003e-01, -1.0562e+00, -6.6863e-01,  1.7226e+00, -1.2740e+00,
         ...
         3.6767e-01,  4.0440e-01], grad_fn=<EmbeddingBackward0>)

print(y.shape)
>>> torch.Size([1, 512])        

Embedding layers allow us to generate dense vectors that are typically much lower-dimensional than sparse one-hot representations. This makes the data more manageable and computationally efficient. It also helps us capture the semantic similarity between words.

Apart from that, when combined with the attention layers that follow, these representations can capture contextual information, meaning that the same word can end up with a different representation depending on its context in the sentence.
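As a small illustration, one common way to measure semantic similarity in the embedding space is cosine similarity. The word ids below are made up, and with a freshly initialized embedding layer the score is meaningless; after training, related words tend to score higher:

import torch.nn.functional as F

vec_a = emb(torch.tensor([12]))   # hypothetical id of one word
vec_b = emb(torch.tensor([27]))   # hypothetical id of another word

similarity = F.cosine_similarity(vec_a, vec_b)   # a value between -1 and 1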

Causal Attention

Now that we have the embeddings ready, we can feed them into the causal attention layers. But there is a specific way to prepare these embeddings.

Let's start with the causal attention mechanism. It is one of the simplest attention mechanisms. It starts by defining a linear layer that takes the output of the embedding layer as input.

First of all, you must understand that the input sequence needs three projections, corresponding to the query, key, and value. This allows the attention mechanism to attend each word to every other word. You can write this as:

qkv = nn.Linear(n_embd, 3*n_embd)        

As you can see, the output dimension of the linear layer is multiplied by 3 so that the query, key, and value can all be retrieved from a single projection.

Now, when we pass the output of the embedding layer into the first layer of attention we get:

project_qkv = qkv(y)

print(project_qkv.shape)
>>> torch.Size([1, 1536])        

After that, we can split the projection into the individual Q, K, and V.
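For instance, a minimal sketch of this split on the flat projection using PyTorch's .chunk method:

q, k, v = project_qkv.chunk(3, dim=-1)   # three tensors of shape (1, 512) each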

The .chunk method shown above is the simplest way to split. However, it can become error-prone when scaling up. An alternative is to reshape the projection into a five-dimensional tensor using the following dimensions:

  • Batch size b of the output from the embedding layer.
  • Sequence length t, also referred to as the time dimension because sequences are time-dependent.
  • Number of attention heads n_head, which is the number of parallel copies of the attention mechanism.
  • Head dimension d_k, which represents the dimensionality of each attention head’s subspace. Every attention head works with an embedding space of d_k instead of n_embd, which distributes the input for efficient parallel processing. It is calculated as n_embd // n_head.

Lastly, we will reorder the qkv projection before splitting it into individual projections. This can be done using the .permute(2, 0, 3, 1, 4) method.

The entire step looks like this:

B = 1                      # batch size
T, C = y.size()            # sequence length (1 here, since y is a single embedding) and embedding dimension
n_head = 8                 # number of attention heads
d_k = n_embd // n_head     # dimensionality of each head's subspace

project_qkv = project_qkv.reshape(B, T, 3, n_head, d_k)
print(project_qkv.shape)
>>> torch.Size([1, 1, 3, 8, 64])

project_qkv = project_qkv.permute(2, 0, 3, 1, 4)
print(project_qkv.shape)
>>> torch.Size([3, 1, 8, 1, 64])

q, k, v = project_qkv[0], project_qkv[1], project_qkv[2]        

Now that we have q, k, and v, we can start calculating the attention scores. The easiest way is to follow the diagram given below.

A simple flowchart of the attention mechanism.

The first step is to compute the scaled dot product between q and k.

attn_scores = q @ k.transpose(-2, -1)
scale = d_k ** -0.5
attn_scores = attn_scores * scale

Next, we apply the softmax function to convert the attention scores into probabilities:

attn_probs = F.softmax(attn_scores, dim=-1)        

Lastly, we compute the final dot product between the attention probabilities and v.

attn_output = (attn_probs @ v)
print(attn_output.shape)
>>> torch.Size([1, 8, 1, 64])        

Before moving ahead, we need to rearrange the output shape back to the input shape.

attn_output = attn_output.transpose(1, 2)
print(attn_output.shape)
>>> torch.Size([1, 1, 8, 64])

attn_output = attn_output.reshape(B, T, C)
print(attn_output.shape)
>>> torch.Size([1, 1, 512])        

Below is the overall code for the causal attention model. Note that the full version also applies a causal mask, so each position can only attend to earlier positions, and adds dropout, both of which we skipped in the walkthrough above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.d_k = config.n_embd // config.n_head
        self.scale = self.d_k ** -0.5

        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.n_head, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_scores = attn_scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        attn_output = (attn_probs @ v).transpose(1, 2).reshape(B, T, C)
        attn_output = self.resid_dropout(self.out_proj(attn_output))

        return attn_output        
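Here is a quick, hypothetical example of how the class can be used; the config object and its values are mine, chosen only to match the attributes the class reads (n_embd, n_head, bias, dropout, block_size):

from types import SimpleNamespace

config = SimpleNamespace(n_embd=512, n_head=8, bias=False, dropout=0.1, block_size=64)
attn = CausalSelfAttention(config)

x = torch.randn(1, 64, 512)   # (batch, sequence length, embedding dimension)
out = attn(x)

print(out.shape)
>>> torch.Size([1, 64, 512])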

Now let's move on to Flash Attention.

Flash Attention

Flash Attention is a more advanced attention mechanism designed for efficiency and performance. Most of the parts are similar to causal attention, with a few exceptions. The initial part of projecting the embeddings into q, k, and v remains the same. So, assuming that we have obtained q, k, and v, we can move ahead to the remaining parts.

We will create a tensor of the same shape and type as q, but filled with zeros. This will allow us to accumulate results across multiple blocks of the attention mechanism.

output = torch.zeros_like(q)        

Initializing the output tensor with zeros also provides a neutral starting point for adding each block's results without interference from pre-existing values.

Additionally, pre-allocating the output tensor helps with efficient memory management: it avoids the overhead of dynamically resizing tensors during computation, which makes processing faster and more efficient.

Now we enter the most important part of flash attention -- Block processing.

Block Processing - Outer Loop

We start by processing the input in blocks for the queries. We create an outer for loop that iterates over the sequence length in steps of block_size.

The block_size refers to the size of the smaller segments or "blocks" into which the input sequence is divided for processing. Essentially, it is a context window.

For instance, if you have a sequence length of 512 then you can select a block size or context window of 64.
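For example, a quick sketch of how the outer loop would partition such a sequence (illustrative only):

seq_len, block_size = 512, 64
print(list(range(0, seq_len, block_size)))   # start index of each query block
>>> [0, 64, 128, 192, 256, 320, 384, 448]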

We use this outer loop to extract the current block of queries into q_block.

for i in range(0, x.size(1), block_size):
    i_end = min(i + block_size, x.size(1))  
    q_block = q[:, i:i_end]  
    m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)  
    l = torch.zeros((q.shape[0], i_end - i), device=q.device)          

Next, we initialize m and l for numerical stability and accumulation. m is filled with negative infinity to ensure stability in subsequent calculations, and l is filled with zeros to accumulate the exponentials of the attention scores.

Block Processing - Inner Loop

We then process the keys and values in blocks of the same size using an inner loop.

The inner loop iterates over the sequence length in steps of block_size. For each iteration, j_end defines the end index for the current block of keys and values, which are extracted into k_block and v_block.

    for j in range(0, x.size(1), block_size):
        j_end = min(j + block_size, x.size(1))  
        k_block = k[:, j:j_end]  
        v_block = v[:, j:j_end]          

Next, we compute the attention scores.

The interesting thing about this approach is that one q_block attends to all the k_blocks inside the inner loop.

        attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale         

We now compute the exponential of the attention scores to stabilize them.

If you remember, in causal attention we applied softmax directly. Here, we first use torch.maximum to compute the element-wise maximum of m and the row-wise maximum of attn_block, and then compute the element-wise exponential of the difference between attn_block and m_new.

m_new updates m with the maximum value from the attention scores for numerical stability, and exp_attn computes the exponential of the adjusted attention scores.

        m_new = torch.maximum(m, attn_block.max(dim=-1)[0])  
        exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))          

Next, we accumulate the exponentials of the attention scores.

l_new updates l to accumulate the sum of the exponentials, and output_block calculates the output block by multiplying the exponential attention scores with the value block. This allows us to keep track of the running maximum values and row sums.

        l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)  
        output_block = torch.matmul(exp_attn, v_block)          

We update the output tensor with the newly calculated values. This involves adding the new output block to the relevant section of the output tensor, rescaled by the ratio of the old and new accumulators, and then updating m and l for the next iteration. Once the inner loop has finished, we normalize the output block by the final accumulated sum l.

        output[:, i:i_end] += (output_block - output[:, i:i_end]) * (l / l_new).unsqueeze(-1)
        m, l = m_new, l_new

    # once the inner loop finishes, normalize by the final accumulated sum
    output[:, i:i_end] /= l.unsqueeze(-1)

Finally, we merge the heads back together using rearrange (from the einops library) and project the output back to the original embedding dimension using another linear layer.

from einops import rearrange

output = rearrange(output, '(b h) t d -> b t (h d)', h=n_head)

proj = nn.Linear(n_embd, n_embd)
output = proj(output)

So that was flash attention. I have written the overall code below.

import torch
import torch.nn as nn
from einops import rearrange

class SimpleFlashAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout_p = config.dropout
        self.causal = True  # assuming causal for GPT-like model
        self.block_size = config.block_size  # size of blocks for tiling, now configurable

    def forward(self, x):
        b, t, c = x.size()
        qkv = self.qkv(x).view(b, t, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = [rearrange(x, 'b t h d -> (b h) t d') for x in (q, k, v)]

        output = torch.zeros_like(q)
        
        for i in range(0, t, self.block_size):
            i_end = min(i + self.block_size, t)
            q_block = q[:, i:i_end]
            
            m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)
            l = torch.zeros((q.shape[0], i_end - i), device=q.device)
            
            for j in range(0, t, self.block_size):
                j_end = min(j + self.block_size, t)
                k_block = k[:, j:j_end]
                v_block = v[:, j:j_end]
                
                attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * self.scale
                
                if self.causal and j > i:
                    attn_block.fill_(float('-inf'))
                elif self.causal:
                    causal_mask = torch.triu(torch.ones(i_end - i, j_end - j, dtype=torch.bool, device=attn_block.device), diagonal=j - i + 1)
                    attn_block.masked_fill_(causal_mask, float('-inf'))
                
                m_new = torch.maximum(m, attn_block.max(dim=-1)[0])
                exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))
                
                l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)
                output_block = torch.matmul(exp_attn, v_block)
                
                output[:, i:i_end] += (output_block - output[:, i:i_end]) * (l / l_new).unsqueeze(-1)
                
                m, l = m_new, l_new
            
            output[:, i:i_end] /= l.unsqueeze(-1)
        
        output = rearrange(output, '(b h) t d -> b t (h d)', h=self.n_head)
        return self.proj(output)        

Keep in mind that this implementation is a simplified version of Flash Attention. I wrote a basic and minimal version to teach you the core concepts.
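As a quick sanity check, the module can be called the same way as the causal version. Using the same hypothetical config as before, only the output shape is compared here, since the two modules have independently initialized weights:

flash_attn = SimpleFlashAttention(config)   # same hypothetical config as above
out = flash_attn(torch.randn(1, 64, 512))

print(out.shape)
>>> torch.Size([1, 64, 512])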

The Core Ideas Behind Flash Attention

  1. Memory Efficiency: The key idea behind FlashAttention is its innovative approach to handling memory. Instead of storing the full attention matrix, which can become enormous for long sequences, FlashAttention processes the input in small blocks and uses accumulator variables to keep track of the necessary information. This method drastically reduces memory usage and enhances performance.
  2. Numerical Stability: FlashAttention ensures numerical stability through two main accumulators. m tracks the maximum value in each row of the attention matrix; by using the "max trick," it keeps the softmax computation numerically stable. l accumulates the sum of exponentials for each query, effectively acting as the denominator of the softmax function (see the sketch after this list).
  3. Incremental Computation: These accumulators enable incremental computation of the attention output. As we process each block of keys and values, we can compute the attention output on the go without needing to wait until all keys and values have been processed. This approach significantly speeds up the computation.
  4. Handling Long Sequences: By leveraging these accumulators, FlashAttention can handle very long sequences that would otherwise be unmanageable if we attempted to compute the full attention matrix at once. This capability makes it particularly suited for applications requiring efficient processing of extensive data sequences.
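To make the role of m and l concrete, here is a small, self-contained numeric check that the blockwise running-max/running-sum accumulation reproduces ordinary softmax attention. It is non-causal and single-head, and all sizes and variable names below are mine, not part of the implementation above:

import torch

torch.manual_seed(0)

# Toy sizes, single head, non-causal -- illustrative only.
t, d, block = 8, 4, 4
q, k, v = torch.randn(1, t, d), torch.randn(1, t, d), torch.randn(1, t, d)
scale = d ** -0.5

# Reference: ordinary softmax attention over the full sequence.
reference = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v

# Blockwise version with a running maximum m and running sum l per query row.
output = torch.zeros_like(q)
m = torch.full((1, t), float('-inf'))
l = torch.zeros((1, t))
for j in range(0, t, block):
    s = (q @ k[:, j:j + block].transpose(-2, -1)) * scale   # scores against this key block
    m_new = torch.maximum(m, s.max(dim=-1)[0])              # updated row maximum
    exp_s = torch.exp(s - m_new.unsqueeze(-1))              # stabilized exponentials
    l_new = l * torch.exp(m - m_new) + exp_s.sum(dim=-1)    # updated row sum
    # Rescale what has been accumulated so far, then add this block's contribution.
    output = output * (l * torch.exp(m - m_new) / l_new).unsqueeze(-1) \
             + (exp_s @ v[:, j:j + block]) / l_new.unsqueeze(-1)
    m, l = m_new, l_new

print(torch.allclose(output, reference, atol=1e-5))
>>> True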

What is the advantage of using flash attention over causal attention?

Let's say you have a long document and you need to summarize it. Instead of reading the entire document at once, you read and summarize it paragraph by paragraph. This is similar to how Flash Attention works – it breaks down a long sequence into smaller blocks (paragraphs), processes each block independently, and then combines the results. This approach is more efficient and manageable, especially for very long documents (or sequences).

Conclusion

Flash Attention exemplifies the cutting-edge in efficiency and performance, particularly for handling long sequences. By breaking down these sequences into smaller, manageable blocks and leveraging parallel processing, Flash Attention ensures both numerical stability and computational efficiency. This approach is akin to summarizing a lengthy document paragraph by paragraph, providing a more scalable and practical solution for real-world applications.

In contrast, while causal attention lays the foundation by introducing the basic concepts and operations, Flash Attention goes a step further by optimizing and refining these processes. Both mechanisms have their unique advantages, and understanding their differences equips you with the knowledge to choose the right tool for your specific needs.

If you have come this far, stay tuned for the next edition of Peceive, where I simplify complex AI concepts.
