Flash Attention 2 in Large Language Models
Frank Morales Aguilera, BEng, MEng, SMIEEE
Boeing Associate Technical Fellow / Engineer / Scientist / Inventor / Cloud Solution Architect / Software Developer @ Boeing Global Services
Introduction
Large Language Models (LLMs) such as GPT-3/4, Falcon, and Llama are rapidly advancing in tackling human-centric tasks[1,1b]. However, deploying these models in real-world applications remains challenging because of their extensive memory demands and the need to handle very long input sequences[1].
The Need for Flash Attention
To tackle these challenges, a variation of the attention algorithm called Flash Attention was introduced[1]. Flash Attention computes exact attention while using GPU memory far more efficiently, which makes it both faster and less memory-hungry than the standard implementation[1,1b].
Flash Attention 2: An Evolution
Flash Attention 2 is an evolution of the original Flash Attention. It exploits the asymmetric GPU memory hierarchy to deliver significant memory savings (linear rather than quadratic in sequence length) and runtime speedups (2-4× over optimized baselines), with no approximation[2].
How Flash Attention 2 Works
Flash Attention 2 reorders the attention computation and harnesses classical techniques such as tiling and recomputation to achieve a remarkable speedup and a substantial reduction in memory usage[3, 3b]. It moves from a quadratic to a linear memory footprint with respect to sequence length[3].
Flash Attention 2 applies tiling to every attention head to minimize memory reads and writes: it shuttles query, key, and value blocks from the GPU's HBM (main memory) into its fast on-chip SRAM, computes attention on those blocks there, and writes only the final output back[3, 3b].
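To make the tiling idea concrete, here is a minimal, illustrative PyTorch sketch of block-wise attention with an online softmax, in the spirit of Flash Attention. It is a single-head reference implementation, not the fused CUDA kernel the library ships; the block size, function name, and the omission of the causal mask and backward-pass recomputation are simplifications chosen for clarity.

import torch

def tiled_attention(q, k, v, block_size=128):
    # q, k, v: (seq_len, head_dim). Instead of materializing the full
    # (seq_len x seq_len) score matrix, stream key/value blocks and keep
    # running softmax statistics, so extra memory grows linearly in seq_len.
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq_len, 1), float("-inf"), device=q.device, dtype=q.dtype)  # running row-wise max
    row_sum = torch.zeros((seq_len, 1), device=q.device, dtype=q.dtype)                # running softmax normalizer

    for start in range(0, seq_len, block_size):
        kb = k[start:start + block_size]              # current key block
        vb = v[start:start + block_size]              # current value block
        scores = (q @ kb.T) * scale                   # (seq_len, block_size)
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)     # rescale what was accumulated so far
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max

    return out / row_sum

# Sanity check against naive attention (float32, small sizes).
q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)

The real kernel additionally runs this loop inside SRAM for a tile of queries at a time, never writes the intermediate score matrix to HBM, and recomputes it during the backward pass instead of storing it.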
Limitations and Future Directions
While Flash Attention 2 performs well in most scenarios, it has not been specifically tuned for exceptionally long sequences, where parallelism can still be lacking[3]. Future research may focus on optimizing Flash Attention 2 for these scenarios.
Difference between Flash Attention and Flash Attention 2
Flash Attention and Flash Attention 2 are advancements in attention mechanisms specifically designed to enhance the efficiency and speed of Large Language Models (LLMs). Here are the key differences between them:
- Parallelism: Flash Attention parallelizes the computation over the batch size and the number of heads; Flash Attention 2 also parallelizes over the sequence length, which helps when sequences are long and batches are small.
- Work partitioning: Flash Attention 2 partitions the work between the warps of each GPU thread block more carefully, reducing communication through shared memory.
- Fewer non-matmul operations: Flash Attention 2 reduces the number of non-matrix-multiply FLOPs, which run far more slowly than matmuls on modern GPU tensor cores.
In summary, while both Flash Attention and Flash Attention 2 aim to improve the efficiency of attention mechanisms in LLMs, Flash Attention 2 provides further enhancements in speed (roughly 2× over the original) and memory usage (see the back-of-the-envelope illustration below).
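As a rough, back-of-the-envelope illustration of why the linear memory footprint matters, the snippet below estimates the size of the attention-score matrix that standard attention would materialize for one layer, versus the working set of a tiled kernel in the spirit of the sketch above; the batch size, head count, sequence length, and block size are arbitrary example values, not measurements of any particular model.

# Rough estimate: memory for attention scores, standard vs. tiled attention.
batch, heads, seq_len, block = 1, 32, 32_768, 128
bytes_fp16 = 2

naive_scores = batch * heads * seq_len * seq_len * bytes_fp16   # full (N x N) matrix
tiled_scores = batch * heads * seq_len * block * bytes_fp16     # one (N x block) tile in flight

print(f"naive score matrix: {naive_scores / 2**30:.0f} GiB")    # 64 GiB
print(f"tiled working set:  {tiled_scores / 2**20:.0f} MiB")    # 256 MiB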
Using Flash Attention 2 with MISTRAL 7B
Mistral 7B[10] is a Large Language Model developed by Mistral AI[11]. It uses techniques such as Sliding Window Attention and Grouped Query Attention (GQA) for efficient inference[11].
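For intuition, here is a minimal sketch of those two ideas; the window size and head counts below are illustrative values chosen for readability, not Mistral 7B's actual configuration.

import torch

# Sliding Window Attention: each query position attends only to itself and
# the previous (window - 1) tokens, rather than the entire prefix.
def sliding_window_causal_mask(seq_len, window):
    i = torch.arange(seq_len).unsqueeze(1)   # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    return (j <= i) & (j > i - window)       # True where attention is allowed

print(sliding_window_causal_mask(6, window=3).int())

# Grouped Query Attention: several query heads share one key/value head,
# which shrinks the KV cache during inference.
num_query_heads, num_kv_heads = 8, 2
group_size = num_query_heads // num_kv_heads
kv = torch.randn(num_kv_heads, 16, 64)                    # (kv_heads, seq_len, head_dim)
kv_for_queries = kv.repeat_interleave(group_size, dim=0)  # (query_heads, seq_len, head_dim)
print(kv_for_queries.shape)                               # torch.Size([8, 16, 64])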
To use Flash Attention 2 with Mistral 7B, first make sure the latest version of the flash-attn package is installed (typically via pip install flash-attn --no-build-isolation)[11]. Here's an example of how to load and run Mistral 7B with Flash Attention 2:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to run the model on

# Load Mistral 7B in half-precision with the Flash Attention 2 kernels enabled.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "My favourite condiment is"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)

# Sample up to 100 new tokens and print the decoded continuation.
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
print(tokenizer.batch_decode(generated_ids)[0])
This script loads the Mistral 7B model with Flash Attention 2 and generates a continuation of the given prompt. Note that Flash Attention 2 requires compatible GPU hardware[11], and the model should be loaded in half-precision (e.g., torch.float16 or torch.bfloat16)[11].
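If you want to confirm which attention backend was actually selected and get a rough sense of peak GPU memory during generation, a check along the following lines can help. Note that _attn_implementation is an internal transformers attribute and may differ between library versions, so treat that part as an assumption rather than a stable API.

# Optional sanity checks after running the script above.
# NOTE: `_attn_implementation` is an internal transformers attribute; it may
# change between versions, so this is an assumption, not a documented API.
print(model.config._attn_implementation)   # expected: "flash_attention_2"

torch.cuda.reset_peak_memory_stats(device)
_ = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
print(f"peak GPU memory: {torch.cuda.max_memory_allocated(device) / 2**30:.2f} GiB")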
A complete notebook running the Mistral 7B model with Flash Attention 2 is available in [10].
Conclusion
Flash Attention[4] and Flash Attention 2 are two fundamental techniques for scaling the context length of LLMs[3]. They represent some of the most significant research breakthroughs in this area and are influencing new methods that can help increase the capacity of LLMs[3,3b].
References
7. Stanford CRFM. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.
8. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. Princeton NLP Group (princeton-nlp.github.io).