Multi-Headed Attention and Backpropagation in Transformer Model

The key to understanding an act is its effect, not just the act itself. We often get stuck in the process and lose the bigger picture. I’ll simplify the effect of each step in the Transformer training process.

In the last article, we saw how a single attention head is built. Now, let’s move to multi-headed attention and see how it adds more depth and context.

A Quick Recap on Attention:

What we did was calculate the row-wise connection strength between the Q and K matrices (the dot product Q·Kᵀ). Then, we scaled it down by the square root of the key dimension. At this stage, the matrix doesn’t hold token embeddings. Instead, it represents connections, or context, between rows (tokens).

SoftMax is applied row-wise to produce a probability distribution across the connections (rows sum to 1). This transforms the context matrix into a probability-based representation of each token's connection to other tokens.

We multiply the SoftMax-encoded probability distribution matrix by the V (value) matrix. The V (value) matrix is the input embeddings multiplied by learned weights (initially random).

This multiplication combines the connection information (from the SoftMax matrix) with the positional and contextual information in the V (value) matrix. The result is a matrix where connection strengths between tokens are distributed and aggregated across all dimensions of the token embeddings.

The result is a matrix holding positional, connection-strength and contextual weights for the input embeddings. It refines the representation and shows how strong the connection between two tokens (words) is.
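
To make the recap concrete, here is a minimal NumPy sketch of a single attention head. The sizes (4 tokens, 64 dimensions) and the weight names W_q, W_k, W_v are illustrative stand-ins for the learned weights described above, not the article’s actual values:

```python
import numpy as np

def softmax(x, axis=-1):
    # Row-wise SoftMax; subtracting the max keeps the exponentials stable.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n_tokens, d_k = 4, 64
rng = np.random.default_rng(0)

X = rng.normal(size=(n_tokens, d_k))    # input embeddings (plus positional encodings)
W_q = rng.normal(size=(d_k, d_k))       # learned weights, random at the start
W_k = rng.normal(size=(d_k, d_k))
W_v = rng.normal(size=(d_k, d_k))

Q, K, V = X @ W_q, X @ W_k, X @ W_v

scores = Q @ K.T / np.sqrt(d_k)         # row-wise connection strengths, scaled by sqrt(d_k)
weights = softmax(scores, axis=-1)      # each row becomes a probability distribution (sums to 1)
head_output = weights @ V               # connection strengths aggregated over the value vectors
```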

Multi-Headed Attention:

I’ll explain this in two ways: one using imagination to make it simpler, and the other describing how it actually works.

Let’s imagine a large embedding matrix of tokens is split into smaller parts. It’s like breaking a problem into smaller sections for easier processing. In multi-headed attention, this is like slicing a cube in 3D space. Each slice represents a smaller dimension of the original embedding.

Each attention head processes a unique slice of the input embedding, allowing the model to capture diverse relationships between tokens. Think of it like someone wanting to understand society better by interacting with multiple social groups rather than sticking to just one. This person gains a broader and deeper understanding by engaging with different groups.

Each attention head generates its own Q, K, and V matrices with unique weights. Each head computes attention independently, like in a single head. The outputs from all heads are concatenated into one matrix. A final linear layer refines this combined matrix into a single output. This allows multiple heads to focus on different aspects of the input.

Before Multi-Head: The input embedding (1×512) is split into 8 heads, each of size 1×64.

After Multi-Head: Each head computes attention, producing 8 outputs (1×64), which are concatenated back into 1×512. This combined output is then multiplied by a weight matrix of the same dimension to unify the outputs from all heads.
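
As a hedged sketch of those steps in NumPy (all sizes are illustrative; the reshape-based split and the output weight matrix W_o mirror the conceptual slicing described above):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n_tokens, d_model, n_heads = 4, 512, 8
d_head = d_model // n_heads                  # 512 / 8 = 64 dimensions per head
rng = np.random.default_rng(0)

X = rng.normal(size=(n_tokens, d_model))     # input embeddings (plus positional encodings)
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_model))
W_v = rng.normal(size=(d_model, d_model))
W_o = rng.normal(size=(d_model, d_model))    # final linear layer that unifies the heads

# Project, then view each 512-dim vector as 8 slices of 64 dimensions.
Q = (X @ W_q).reshape(n_tokens, n_heads, d_head)
K = (X @ W_k).reshape(n_tokens, n_heads, d_head)
V = (X @ W_v).reshape(n_tokens, n_heads, d_head)

head_outputs = []
for h in range(n_heads):                     # each head attends over its own 64-dim slice
    scores = Q[:, h] @ K[:, h].T / np.sqrt(d_head)
    head_outputs.append(softmax(scores) @ V[:, h])

concatenated = np.concatenate(head_outputs, axis=-1)    # back to n_tokens x 512
multi_head_output = concatenated @ W_o                  # unified output from all heads
```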

In reality, these operations happen in place. The splitting is conceptual, but imagining it this way helps with understanding.

We now have the final attention output matrix. What does it represent? Just like in a single head, it represents the deep contextual connections between the input embeddings derived from the tokens.

Now, we refine it further to focus on what matters. We discard what doesn’t work by setting negative values to zero with ReLU. This refinement helps us move closer to the desired output direction.

FFN (Feed-Forward Network):

FFN refines the attention head output using a few transformations. One of these is ReLU. As I mentioned before, ReLU clamps values to the range 0 to infinity: positive values pass through unchanged, and negative values are set to 0. It keeps what works and zeroes out what doesn’t.
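
In code, ReLU is just a clamp at zero; a tiny NumPy example:

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])
relu = np.maximum(0.0, x)    # negatives become 0, positives pass through unchanged
print(relu)                  # [0.  0.  0.  1.5 3. ]
```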

You can imagine this process as editing a picture. You zoom into the pixels, enhance the important ones, and downscale the others. Once done, you scale the image back to its normal size. FFN works the same way: it projects the data to a higher dimension, makes updates, and scales it back to the original dimension.

The up-projection is done by multiplying the input matrix with a weight matrix of a larger size (512×2048 in this case), expanding each token’s representation from 512 to 2048 dimensions.

Biases are then added to the up-projected matrix, and ReLU is applied to refine it further.

Once you have the ReLU output, you scale it back to the original size by taking a dot product with a weight matrix of the original dimension. The resulting matrix is more refined and contextual. You add this refined output to the attention head, enriching it further. Finally, you normalize everything to balance the values. In pictures, it’s like superimposing the refined image on the original to make it sharper and clearer.
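
Here is a minimal NumPy sketch of that whole step, assuming the 512→2048→512 sizes used above (the weight names and the small initialization factor are illustrative):

```python
import numpy as np

n_tokens, d_model, d_ff = 4, 512, 2048
rng = np.random.default_rng(0)

attn_output = rng.normal(size=(n_tokens, d_model))   # stand-in for the multi-head attention output

W1 = rng.normal(size=(d_model, d_ff)) * 0.02         # up-projection to 2048 dimensions
b1 = np.zeros(d_ff)
W2 = rng.normal(size=(d_ff, d_model)) * 0.02         # down-projection back to 512
b2 = np.zeros(d_model)

hidden = np.maximum(0.0, attn_output @ W1 + b1)      # expand, add bias, apply ReLU
ffn_output = hidden @ W2 + b2                        # scale back to the original dimension

x = attn_output + ffn_output                         # residual: superimpose the refined output on the original
mean = x.mean(axis=-1, keepdims=True)
std = x.std(axis=-1, keepdims=True)
normalized = (x - mean) / (std + 1e-5)               # layer normalization balances the values
```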

The normalized values are then multiplied by an additional weight matrix, with a bias term added, to produce the final logits (the Z function in a neural network).
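
Continuing the sketch (the vocabulary size of 1000 is made up for illustration):

```python
import numpy as np

n_tokens, d_model, vocab_size = 4, 512, 1000
rng = np.random.default_rng(0)

normalized = rng.normal(size=(n_tokens, d_model))     # normalized output from the previous step
W_out = rng.normal(size=(d_model, vocab_size)) * 0.02
b_out = np.zeros(vocab_size)

logits = normalized @ W_out + b_out                   # Z = XW + b: one raw score per vocabulary token
```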

What about the loss?

Logits calculated above are passed through the SoftMax function to produce a probability distribution. This distribution is used to calculate the loss for each element in the matrix.

The SoftMax probability output is then fed into the loss function for comparison against the expected output (the true label). The loss is calculated element-wise between the predicted output and the expected output. These individual losses are summed and then averaged to compute the final combined loss.

Here is what the loss function looks like:
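
In standard notation, a sketch of that function, with N token positions and V vocabulary entries:

$$
\text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{V} y_{ij} \, \log\left(\hat{y}_{ij}\right)
$$

Here $y_{ij}$ is 1 for the true token at position $i$ (and 0 otherwise), and $\hat{y}_{ij}$ is the SoftMax probability the model assigned to vocabulary entry $j$ at that position.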

It’s the same as what we discussed in the Cross-Entropy article. The function calculates the loss for each element in the matrix, sums them up, and averages them. The process involves collecting loss row-wise first, followed by column-wise.
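
A minimal NumPy sketch of that computation, assuming we already have logits and true labels for a tiny 4-token, 10-word vocabulary example:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n_tokens, vocab_size = 4, 10
rng = np.random.default_rng(0)

logits = rng.normal(size=(n_tokens, vocab_size))
true_ids = np.array([3, 1, 7, 2])                # expected (true) token at each position
one_hot = np.eye(vocab_size)[true_ids]

probs = softmax(logits)                          # probability distribution per row
elementwise = -one_hot * np.log(probs + 1e-12)   # element-wise loss against the true labels
loss = elementwise.sum(axis=1).mean()            # sum row-wise, then average over rows
print(loss)
```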

We now have a measure of the combined loss. How do we transmit it back to update all the weights and biases?

Gradients are calculated from the combined loss, and weights and biases are updated layer by layer.

This includes the FFN, attention heads, and Q, K, and V matrices. Updated weights are used in the next forward pass, while input embeddings and positional encodings remain constant. Updated weights refine the context and improve predictions.

This process repeats until the overall loss is significantly reduced. By using multi-headed attention, FFN refinement and backpropagation, the Transformer model keeps learning and gets better at making predictions over time.
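
As a hedged illustration of the full loop, here is a small PyTorch sketch that leans on the library’s built-in encoder layer and autograd rather than the hand-built matrices above (all sizes and names are illustrative):

```python
import torch
import torch.nn as nn

d_model, n_heads, vocab_size, seq_len = 512, 8, 1000, 10

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                   dim_feedforward=2048, batch_first=True)
to_logits = nn.Linear(d_model, vocab_size)
params = list(layer.parameters()) + list(to_logits.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(1, seq_len, d_model)                  # stand-in for embeddings + positional encodings
targets = torch.randint(0, vocab_size, (1, seq_len))  # expected tokens

for step in range(10):
    logits = to_logits(layer(x))                      # forward pass: attention + FFN + final projection
    loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()        # gradients of the combined loss w.r.t. every weight and bias
    optimizer.step()       # layer-by-layer update of Q, K, V, FFN and output weights
```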
