BxD Primer Series: Attention Mechanism
Hey there!
Welcome to the BxD Primer Series, where we cover topics such as machine learning models, neural nets, GPT, ensemble models, and hyper-automation in a 'one-post-one-topic' format. Today's post is on Attention Neural Networks. Let's get started:
Introduction to Attention Mechanism:
The attention mechanism is a key component of many modern deep learning models, particularly in natural language processing and in image and video processing. In tasks such as machine translation, text summarization, question answering, and image captioning, the input often contains long sequences or long-range dependencies that require the model to capture context.
Traditional sequence-to-sequence models, such as RNNs, LSTMs, or GRUs, tend to struggle with long-range dependencies: relevant information becomes diluted or is lost over time. The attention mechanism addresses this by allowing the model to weigh the importance of different parts of the input dynamically.
The attention mechanism works by creating a context vector for each step or token in the input sequence. This context vector is a weighted sum of the input elements, with the weights indicating the importance or relevance of each element. The high-level steps are covered in detail in 'The How' section below.
Attention Neural Networks:
Attention Neural Networks (ANNs) are designed to enhance the ability of a model to selectively focus on specific parts of an input sequence or image. They have been successfully applied to natural language, image, and speech related tasks.
The basic idea is that the model learns to assign varying degrees of importance, or "attention", to different parts of the input.
In traditional neural networks, the input is processed in a fixed order, and each element of the input is treated equally. But in most tasks, some parts of the input are more relevant than others. For example, in language translation, certain words or phrases may be more important for translating a particular sentence accurately.
Attention Neural Networks address this by allowing the model to selectively attend to the most relevant parts of the input. This is achieved with a learnable weighting scheme that assigns varying degrees of attention to different parts of the input. The weights represent the degree of relevance or importance of each input feature and are learned during the training process.
Attention mechanisms are commonly categorized along two main axes, both covered below: soft vs. hard attention, and local vs. global attention.
One popular type of ANN is the Transformer model, which uses a self-attention mechanism that allows the model to attend to different parts of the input sequence at different times.
In this edition we cover the basics of the attention mechanism; we move on to transformer models in the next edition.
Soft v/s Hard Attention:
Soft Attention calculates a weighted sum of the input features, where the weights represent the degree of attention given to each input feature. It is called "soft" because it produces a continuous distribution of weights over the input features, which allows the model to attend to multiple parts of the input simultaneously.
Hard Attention, on the other hand, selects a subset of the input features to attend to and discards the rest. The selection can be made using either a learned or a fixed rule. It is called "hard" because it produces a discrete distribution of weights over the input features, which lets the model attend to only one part of the input at a time.
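To make the contrast concrete, here is a minimal NumPy sketch (the scores, values, and the use of argmax for the hard case are illustrative assumptions): soft attention keeps the full softmax distribution and mixes all inputs, while hard attention commits to a single input.

import numpy as np

scores = np.array([2.0, 0.5, -1.0])              # similarity scores for 3 input features
values = np.array([[1.0, 0.0],                   # one value vector per input feature
                   [0.0, 1.0],
                   [1.0, 1.0]])

# Soft attention: continuous, differentiable weights over all inputs
weights = np.exp(scores) / np.exp(scores).sum()  # softmax
soft_context = weights @ values                  # weighted sum of all value vectors

# Hard attention: a discrete choice of one input (argmax here; often sampled in practice)
hard_context = values[np.argmax(scores)]

print(weights)        # e.g. [0.79, 0.18, 0.04] - sums to 1
print(soft_context)   # a mixture of all rows of `values`
print(hard_context)   # exactly the first row of `values`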
Local v/s Global Attention:
Local attention is used when the input data is sequential. It focuses on a limited set of neighboring inputs rather than the entire input sequence. For example, in machine translation the attention can focus on a few words before and after the current word being translated. It is faster than global attention, as it only needs to consider a small subset of the input at any given time.
Global attention can attend to any part of the input sequence. It is used when the input data is not sequential, as in image recognition tasks. Global attention provides more comprehensive information to the model because it can access the entire input.
There is also a hybrid approach, called 'local-global attention', that combines the two. In this mechanism the model attends to both a small subset of neighboring inputs and the entire input sequence, allowing it to balance computational efficiency with comprehensive information access.
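As a rough illustration (the window size, shapes, and random scores are assumptions), local attention can be implemented by masking the similarity scores so each position only sees a small window of neighbors, while global attention leaves the scores unmasked:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

n, window = 6, 1                                        # 6 positions, attend to +/- 1 neighbor
scores = np.random.default_rng(0).normal(size=(n, n))   # pairwise similarity scores

# Local: block out every position outside the window with -inf before the softmax
mask = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]) <= window
local_weights = softmax(np.where(mask, scores, -np.inf))

# Global: every position may attend anywhere
global_weights = softmax(scores)

print(local_weights[0])   # only the first two entries are non-zero for position 0
print(global_weights[0])  # all six entries are non-zero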
The How:
Generating attention-based context vectors from the input typically involves the steps below (a runnable sketch of the full pipeline follows the list):
• Input Encoding: Assume we have an input sequence consisting of tokens or elements denoted as (x1, x2, …, xn).
This input sequence is encoded into a set of representations using an encoding function E(·): h1, h2, …, hn = E(x1, x2, …, xn)
• Query, Key, and Value: The encoded representations are projected into query (Q), key (K), and value (V) vectors through linear transformations:
Q_i = W_q × h_i - captures the current context or the information that needs attention.
K_j = W_k × h_j - represents the encoded information or features of each token.
V_j = W_v × h_j - contains the values, i.e. the actual information associated with each token.
where W_q, W_k, and W_v are learnable weight matrices.
• Similarity Scores: Compute a similarity score between the query and each key using a similarity function f:
s_ij = f(Q_i, K_j)
where s_ij represents the similarity score between the query at position i and the key at position j. The similarity function can be a dot product, scaled dot product, concatenation, etc.
• Attention Weights: Apply a softmax function to the similarity scores to obtain attention weights. The softmax ensures that the weights sum to 1 and form a valid probability distribution:
α_ij = exp(s_ij) / Σ_k exp(s_ik)
• Weighted Sum: The attention weights are used to calculate a weighted sum of the value vectors. This weighted sum is known as the context vector:
C_i = Σ_j α_ij × V_j
• Context Vector and Output: The context vector (C) is concatenated with the encoded representation at the current step and passed through a linear transformation and an activation function to generate the output y_i at the current step:
y_i = Activation(W_o × [C, h_i])
where W_o is a learnable weight matrix.
• Repeat the above steps for each step or token in the input sequence, allowing the model to attend to different parts of the input at each step.
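Putting the steps together, here is a minimal, self-contained NumPy sketch of the attention computation. The dimensions, the random stand-ins for the learnable matrices, and the choice of scaled dot product as the similarity function are illustrative assumptions, not a reference implementation.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))   # subtract max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def attention(h, d_k=16, seed=0):
    """h: encoded inputs of shape (n, d_model). Returns context vectors and attention weights."""
    n, d_model = h.shape
    rng = np.random.default_rng(seed)
    # Learnable projection matrices W_q, W_k, W_v (random stand-ins here, trained in practice)
    W_q = rng.normal(size=(d_model, d_k))
    W_k = rng.normal(size=(d_model, d_k))
    W_v = rng.normal(size=(d_model, d_k))

    Q, K, V = h @ W_q, h @ W_k, h @ W_v      # queries, keys, values
    scores = Q @ K.T / np.sqrt(d_k)          # similarity scores s_ij (scaled dot product)
    weights = softmax(scores)                # attention weights α_ij, each row sums to 1
    context = weights @ V                    # context vectors C_i = Σ_j α_ij × V_j
    return context, weights

h = np.random.default_rng(1).normal(size=(5, 32))   # 5 tokens with 32-dimensional encodings
C, A = attention(h)
print(C.shape, A.shape)                             # (5, 16) (5, 5)
print(A.sum(axis=-1))                               # each row of weights sums to 1

Because the queries, keys, and values here are all derived from the same encoded sequence h, this sketch also doubles as an example of the self-attention discussed below.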
Common Similarity Functions:
Similarity functions are used to calculate the compatibility between the query and key vectors. Common similarity functions, with a short code sketch after the list:
• Dot Product measures similarity in terms of the magnitude and direction of the vectors: s_ij = Q_i · K_j
• Scaled Dot Product is similar to the dot product, but incorporates a scaling factor to control the magnitude of the similarity scores, which helps stabilize the gradients during training: s_ij = (Q_i · K_j) / √d
where d is the dimensionality of the query and key vectors.
• Concatenation concatenates the query and key vectors, followed by a linear transformation using a weight matrix. This approach allows the model to capture more complex interactions between the query and key vectors.
• General is a flexible approach that applies a linear transformation to the key vectors before computing the similarity score, which also aligns the dimensions of the query and key vectors.
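A compact sketch of these scoring options for a single query q and key k. The parameter names (W_a, W_c, v) and the tanh non-linearity in the concatenation form are assumptions about how such scores are commonly parameterized, not a prescribed formulation:

import numpy as np

d = 8
rng = np.random.default_rng(0)
q, k = rng.normal(size=d), rng.normal(size=d)
W_a = rng.normal(size=(d, d))        # learnable matrix for the "general" form
W_c = rng.normal(size=(d, 2 * d))    # learnable matrix for the concatenation form
v = rng.normal(size=d)               # learnable vector reducing the concat features to a scalar

dot        = q @ k                                        # dot product
scaled_dot = q @ k / np.sqrt(d)                           # scaled dot product
general    = q @ (W_a @ k)                                # key is linearly transformed first
concat     = v @ np.tanh(W_c @ np.concatenate([q, k]))    # concatenate, transform, reduce to scalar

print(dot, scaled_dot, general, concat)                   # four alternative similarity scores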
Self-Attention:
Self-Attention is a type of attention mechanism that allows the model to attend to different parts of the input sequence to produce a better representation of that input. It is called "Self-Attention" because the sequence the model attends to is the input sequence itself.
Self-Attention operates by computing a weighted sum over the input sequence, where the weights are computed by the model. These weights represent the importance, or attention, given to each element of the input sequence relative to the other elements. In this way, the model learns to weigh each element of the input sequence based on its relevance to the current context.
The query, key, and value projections that produce these weights are learnable parameters, trained through back-propagation to minimize a loss function. The attention weights themselves are computed at each step from the dot products between the query and key vectors derived from the input sequence.
The self-attention mechanism is commonly used in natural language processing tasks, where it allows the model to attend to different parts of the input sentence to produce an accurate representation of the context.
Attention is a dynamic operation that can adaptively attend to any part of the input, based on its relevance.
Multi-Head Attention:
Multi-head attention is a variant of the attention mechanism that enables the model to attend to different parts of the input sequence simultaneously. This enhances the model's ability to capture different types of dependencies and provides more expressive representations. It differs slightly from basic (single-head) attention: the queries, keys, and values are projected h separate times with different learned matrices, attention is computed independently for each head, and the h outputs are concatenated and projected back to the model dimension.
Note: The number of attention heads, h, is a hyper-parameter that can be tuned for the specific task and dataset. Increasing h provides more capacity to capture diverse patterns but also increases the computational complexity of the model.
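A minimal sketch of that idea (the dimensions, random stand-in weights, and the scaled dot product are assumptions): the model runs h independent attention computations on lower-dimensional projections and concatenates the results.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, num_heads=4, seed=0):
    """x: (n, d_model) token encodings; d_model must be divisible by num_heads."""
    n, d_model = x.shape
    d_head = d_model // num_heads
    rng = np.random.default_rng(seed)
    heads = []
    for _ in range(num_heads):                           # each head has its own Q/K/V projections
        W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
        Q, K, V = x @ W_q, x @ W_k, x @ W_v
        weights = softmax(Q @ K.T / np.sqrt(d_head))     # per-head attention weights
        heads.append(weights @ V)                        # per-head context, shape (n, d_head)
    W_o = rng.normal(size=(d_model, d_model))            # output projection after concatenation
    return np.concatenate(heads, axis=-1) @ W_o          # back to shape (n, d_model)

out = multi_head_attention(np.random.default_rng(1).normal(size=(5, 32)))
print(out.shape)   # (5, 32)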
The Why:
Reasons for using Attention Neural Networks:
The Why Not:
Reasons for not using Attention Neural Networks:
Time for you to support:
In next edition, we will cover Stable Diffusion Models.
Let us know your feedback!
Until then,
Have a great time!