Flash Attention 2 in Large Language Models

Introduction

Large Language Models (LLMs) such as GPT-3/4, Falcon, and LLaMA are rapidly advancing in tackling human-centric tasks[1,1b]. However, deploying these models in real-world applications remains challenging because of their extensive memory demands and the need to handle very long input sequences[1].

The Need for Flash Attention

To tackle these challenges, a variation of the attention algorithm called Flash Attention was introduced[1]. Flash Attention computes exact attention while making better use of GPU memory, reducing both memory traffic and runtime[1,1b].

Flash Attention 2: An Evolution

Flash Attention 2 is an evolution of the original Flash Attention. It exploits the asymmetric GPU memory hierarchy to deliver significant memory savings (linear rather than quadratic in sequence length) and runtime speedups (2-4× over optimized baselines), with no approximation[2].
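To put the quadratic-versus-linear distinction in concrete terms, here is a quick back-of-the-envelope calculation in Python; the sequence length, head count, element size, and block size are illustrative assumptions, not values from the cited sources.

seq_len = 16_384        # tokens (illustrative)
n_heads = 32            # attention heads (illustrative)
bytes_per_elem = 2      # float16

# Standard attention materializes an N x N score matrix per head.
full_scores_gb = seq_len * seq_len * n_heads * bytes_per_elem / 1e9
print(f"Materialized attention scores: ~{full_scores_gb:.1f} GB")  # ~17.2 GB

# Flash Attention only ever holds small tiles of that matrix in fast on-chip memory.
block = 128
tile_kb = block * block * bytes_per_elem / 1e3
print(f"One score tile held in SRAM:   ~{tile_kb:.1f} KB")         # ~32.8 KB

The full score matrix grows quadratically with sequence length, while the working set of the tiled computation stays fixed; only the inputs and outputs, which are linear in sequence length, ever need to live in main GPU memory.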

How Flash Attention 2 Works

Flash Attention 2 reorders the attention computation and harnesses classical techniques such as tiling and recomputation to achieve a substantial boost in speed and a marked reduction in memory usage[3, 3b]. It moves from a quadratic to a linear memory footprint with respect to sequence length[3].

Flash Attention 2 applies classical tiling techniques within every attention head to minimize memory reads and writes: it shuttles query, key, and value blocks from the GPU's HBM (main memory) into its much faster on-chip SRAM, computes with them there, and writes only the results back[3, 3b].
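To make the tiling idea concrete, below is a minimal PyTorch sketch of attention computed block by block with an online softmax. It illustrates the algorithmic idea only: the single-head shapes, block size, and absence of masking are simplifying assumptions, and the real Flash Attention 2 kernels are fused CUDA code operating on SRAM tiles rather than Python loops.

import torch

def tiled_attention(q, k, v, block_size=128):
    # q, k, v: (seq_len, head_dim) tensors for a single attention head
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)                            # running (unnormalized) output
    row_max = torch.full_like(q[:, :1], float("-inf"))   # running row-wise max of scores
    row_sum = torch.zeros_like(q[:, :1])                 # running softmax denominator
    for start in range(0, seq_len, block_size):
        kb = k[start:start + block_size]                 # key tile kept in fast memory
        vb = v[start:start + block_size]                 # value tile kept in fast memory
        scores = (q @ kb.T) * scale                      # scores against this tile only
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)        # rescale previously accumulated results
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum                                 # apply softmax normalization once at the end

Up to floating-point error, the result matches the standard torch.softmax(q @ k.T * scale, dim=-1) @ v, but no seq_len × seq_len score matrix is ever materialized; only one block of scores exists at a time. A causal model would additionally mask each score tile before the exponential, and the real kernels also parallelize across batches, heads, and sequence blocks, which is where Flash Attention 2's improved work partitioning comes in.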

Limitations and Future Directions

While Flash Attention 2 performs well in most scenarios, it was not specifically optimized for exceptionally long sequences, where parallelism is still limited[3]. Future research may focus on optimizing Flash Attention 2 for these scenarios.

Difference between Flash Attention and Flash Attention 2

Flash Attention and Flash Attention 2 are advancements in attention mechanisms specifically designed to enhance the efficiency and speed of Large Language Models (LLMs). Here are the key differences between them:

  1. Flash Attention: Introduced in 2022 by researchers at Stanford University[5], Flash Attention leverages IO-awareness to produce fast and memory-efficient 'exact attention' [5]. It significantly improved over the standard attention mechanism but still had some limitations.
  2. Flash Attention 2: An evolution of Flash Attention, Flash Attention 2 exploits the asymmetric GPU memory hierarchy to bring significant memory savings and runtime speedups[5-6]. It reorders the attention computation and harnesses classical techniques such as tiling and recomputation to achieve a remarkable boost in speed and a substantial reduction in memory usage[1-2, 3b]. Flash Attention 2 is reported to be about 2× faster than Flash Attention[7-8], which means models can be trained with much longer context for roughly the same cost as previously training a shorter-context model[7-9]; a quick way to feel the effect of a flash-style kernel on your own GPU is sketched after this comparison.

In summary, while both Flash Attention and Flash Attention 2 aim to improve the efficiency of attention mechanisms in LLMs, Flash Attention 2 provides further enhancements in speed and memory usage.
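The speedup figures above come from the published benchmarks. As an informal way to see the effect of a fused flash-style kernel on your own hardware, the sketch below times PyTorch's scaled_dot_product_attention under its standard ("math") backend and under its Flash Attention backend. Note the caveats: this compares PyTorch's bundled flash kernel against the plain implementation, not Flash Attention 1 against Flash Attention 2, and the torch.backends.cuda.sdp_kernel context manager varies (and is deprecated) across PyTorch 2.x releases, so treat the exact API and the tensor shapes as assumptions about your environment.

import time
import torch
import torch.nn.functional as F

def timed_attention(enable_flash, iters=10):
    # (batch, heads, seq_len, head_dim) in half precision on the GPU
    q = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # Restrict scaled_dot_product_attention to a single backend for timing.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash,
        enable_math=not enable_flash,
        enable_mem_efficient=False,
    ):
        F.scaled_dot_product_attention(q, k, v, is_causal=True)  # warm-up
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        torch.cuda.synchronize()
    return time.time() - start

print(f"standard ('math') backend: {timed_attention(False):.3f} s")
print(f"flash backend:             {timed_attention(True):.3f} s")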

Using Flash Attention 2 with Mistral 7B

Mistral 7B[10] is a Large Language Model developed by Mistral AI[11]. It uses techniques such as Sliding Window Attention and Grouped Query Attention (GQA) for efficient inference[11].
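As a small illustration of the sliding-window idea, the sketch below builds the boolean mask such an attention pattern implies; the window size used here is arbitrary for readability (Mistral 7B's published window is reported to be 4,096 tokens).

import torch

def sliding_window_mask(seq_len, window):
    # True where attention is allowed: each position may attend to itself and
    # to at most the previous `window - 1` positions (causal sliding window).
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]                # no attending to future tokens
    in_window = (idx[:, None] - idx[None, :]) < window   # stay within the local window
    return causal & in_window

print(sliding_window_mask(6, 3).int())

Grouped Query Attention complements this by letting several query heads share a single key/value head, which shrinks the key/value cache during inference.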

To use Flash Attention 2 with Mistral 7B, you must ensure you have the latest version of Flash Attention 2 installed[11].
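Flash Attention 2 is distributed as the flash-attn package on PyPI; assuming a CUDA toolchain and a supported NVIDIA GPU are available (assumptions about your environment), a typical installation looks like this:

pip install -U flash-attn --no-build-isolation

With the package in place, here's an example of how to load and run Mistral 7B with Flash Attention 2: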

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to load the model onto

# Load Mistral 7B in half precision with the Flash Attention 2 implementation
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Tokenize the prompt and move both the inputs and the model to the GPU
prompt = "My favourite condiment is"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
model.to(device)

# Generate up to 100 new tokens and decode the result back to text
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
print(tokenizer.batch_decode(generated_ids)[0])

This script loads the Mistral 7B model with Flash Attention 2 and generates a completion for the given prompt. Keep in mind that you need compatible hardware to use Flash Attention 2[11], and the model should be loaded in half precision (e.g., torch.float16)[11].
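As a quick sanity check before enabling Flash Attention 2, you can inspect the GPU's compute capability; the Ampere-or-newer threshold (compute capability 8.0+) used below reflects the commonly stated requirement and is treated here as an assumption rather than something from the cited sources.

import torch

# Rough compatibility check: Flash Attention 2 targets recent NVIDIA GPUs.
if not torch.cuda.is_available():
    print("No CUDA device found; Flash Attention 2 requires an NVIDIA GPU.")
else:
    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    if major >= 8:
        print(f"{name} (compute capability {major}.{minor}) should support Flash Attention 2.")
    else:
        print(f"{name} (compute capability {major}.{minor}) is likely too old; "
              "fall back to the default attention implementation.")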

A full notebook using the Mistral 7B model with Flash Attention 2 can be found in [10].

Conclusion

Flash Attention[4] and Flash Attention 2 are two fundamental techniques used to scale the context length of LLMs[3]. They are among the most significant research breakthroughs in this area and are influencing new methods that can help increase the capacity of LLMs[3,3b].

References

1. Optimizing your LLM in production (huggingface.co)
1b. Optimizing your LLM in production (vuink.com)
2. Flash Attention 2 · MinWoo Park (dsdanielpark.github.io)
3. Understanding Flash-Attention and Flash-Attention-2: The Path to… – Towards AI
3b. Understanding Flash-Attention and Flash-Attention-2: The Path to… – Towards AI
4. FlashAttention: An Advancement in GPU Acceleration for Training LLMs - Part 2 | by Sachin Kalsi | Medium
5. FlashAttention vs FlashAttention-2 - an Analysis (e2enetworks.com)
6. Understanding Flash-Attention and Flash-Attention-2: The Path to Scale The Context Lenght of Language Models | by Jesus Rodriguez | Towards AI
7. Stanford CRFM - FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
8. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning | Princeton NLP Group (princeton-nlp.github.io)
9. FlashAttention with PyTorch Compile - Benchmarking FlashAttention and FlashAttention-2 on a Consumer GPU | Just Stir It Some More (benjaminwarner.dev)
10. Mistral LLM: A New Era in Language Models | by Frank Morales Aguilera | Feb 2024 | Medium
11. Mistral (huggingface.co)
