Training-Free Long-Context Scaling of Large Language Models

Introduction

The ability of Large Language Models (LLMs) to process and generate coherent text diminishes once inputs exceed their pretraining length, and finetuning models for longer sequences is costly. Dual Chunk Attention (DCA) was therefore introduced to extend the context window of Llama2 70B to more than 100k tokens without any continual training. DCA decomposes the attention computation into chunk-based modules, effectively capturing both intra-chunk and inter-chunk positional information, and integrates seamlessly with Flash Attention. The approach maintains performance comparable to finetuned models and offers an open-source alternative that reaches 94% of the performance of GPT-3.5-Turbo-16k.

Background

Positional Encoding

Positional embeddings are critical for Transformers, which are otherwise insensitive to token order; classic schemes map absolute position indices into a feature space and add them at the input layer. A prominent alternative, Rotary Positional Encoding (RoPE), injects positional information directly in the attention layer by rotating query and key vectors, so that attention scores depend only on the relative distance between tokens.
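
For intuition, here is a minimal PyTorch sketch of RoPE for a single head, using the standard base-10000 frequencies and an interleaved pairing of dimensions (production implementations such as the Hugging Face Llama code apply an equivalent rotation with a different memory layout):

```python
import torch

def rotary_embedding(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate pairs of feature dimensions of x (seq_len, head_dim) by position-dependent angles."""
    head_dim = x.shape[-1]
    # One frequency per pair of dimensions: theta_i = base^(-2i / d)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]      # (seq_len, head_dim // 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]                          # interleave dimensions into pairs
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)

# Because both queries and keys are rotated, q_m . k_n depends only on the offset m - n.
q = rotary_embedding(torch.randn(8, 64), torch.arange(8))
k = rotary_embedding(torch.randn(8, 64), torch.arange(8))
scores = q @ k.T / 64 ** 0.5
```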

Extrapolation of RoPE

RoPE on its own extrapolates poorly: perplexity degrades sharply once sequences grow beyond the pretraining length. Methods such as Position Interpolation (PI) and NTK-Aware RoPE redesign the relative position matrix to squeeze longer inputs into the trained positional range, but they still fall short when input lengths far exceed the training length.
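
As a concrete illustration of the interpolation idea (a sketch of the general PI recipe, not the authors' implementation), the code below linearly rescales position indices so that a longer input is squeezed back into Llama2's 4,096-token trained range; NTK-Aware RoPE instead rescales the rotary base:

```python
import torch

def interpolated_positions(seq_len: int, train_len: int = 4096) -> torch.Tensor:
    """Position Interpolation: linearly squeeze position indices into [0, train_len)."""
    scale = min(1.0, train_len / seq_len)      # only shrink positions, never stretch them
    return torch.arange(seq_len).float() * scale

# A 16k-token input is mapped onto the 0..4095 range the model was trained on,
# so no rotation angle exceeds what was seen during pretraining.
positions = interpolated_positions(seq_len=16384, train_len=4096)
print(positions[:4], positions[-1])            # tensor([0.00, 0.25, 0.50, 0.75]) tensor(4095.75)
```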

Method

Dual Chunk Attention (DCA)

DCA introduces a training-free approach to extending the context window of LLMs. It segments the self-attention computation for a long sequence into chunk-based attention patterns whose position indices never exceed the pretraining window. DCA comprises three components:

  1. Intra-Chunk Attention: Processes tokens within the same chunk.
  2. Inter-Chunk Attention: Manages attention between distinct chunks.
  3. Successive-Chunk Attention: Ensures locality by focusing on adjacent chunks.

Intra-Chunk Attention

This mechanism calculates the inner product of queries and keys within the same chunk, ensuring position indices are within chunk size limits, effectively managing short-range dependencies.

Inter-Chunk Attention

This component aggregates information from other chunks. Query position indices are set to a constant maximum value, larger than any key index within a chunk, so that relative distances to keys in earlier chunks remain valid and reflect the left-to-right information flow, enabling efficient handling of long-range dependencies.

Successive-Chunk Attention

This component maintains the precise relative positions between neighboring tokens that fall on either side of a chunk boundary, preserving locality and thereby keeping retrieval accuracy high and perplexity low.
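
To make the three patterns concrete, the sketch below shows one plausible way to construct the per-pattern position indices described above. It is an illustration rather than the authors' code; the chunk size s, local window w, and maximum index c = s + w are assumed parameters.

```python
import torch

def dca_position_ids(seq_len: int, s: int, w: int):
    """Sketch of Dual Chunk Attention position indices.

    Keys in every chunk reuse indices 0..s-1, so no index ever exceeds the
    pretraining window. Queries get one index vector per attention pattern.
    """
    c = s + w                                    # largest position index queries may take
    k_pos = torch.arange(seq_len) % s            # key indices, repeated chunk by chunk

    # 1) Intra-chunk: queries share the keys' indices -> ordinary causal attention inside a chunk.
    q_intra = k_pos.clone()

    # 2) Inter-chunk: queries take the maximum index, so every key in an earlier
    #    chunk sits at a large but bounded relative distance to the left.
    q_inter = torch.full((seq_len,), c - 1)

    # 3) Successive-chunk: the first w tokens of each chunk keep small relative
    #    distances to the tail of the previous chunk; the rest fall back to c - 1.
    offset = torch.arange(seq_len) % s
    q_succ = torch.where(offset < w, s + offset, torch.full_like(k_pos, c - 1))

    return k_pos, q_intra, q_inter, q_succ

# Example: a 12-token sequence with chunk size 4 and local window 2.
k_pos, q_intra, q_inter, q_succ = dca_position_ids(12, s=4, w=2)
print(k_pos.tolist())    # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(q_succ.tolist())   # [4, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5]
```

The three index vectors are then combined with the appropriate causal masks, so that each query uses q_intra against keys in its own chunk, q_succ against the immediately preceding chunk, and q_inter against all earlier chunks.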

Experiments

Experimental Setup

DCA is implemented via a monkey patch to replace the original LlamaAttention’s inference code, integrated with Flash Attention for efficiency. Evaluations are conducted on Llama2 models (7B, 13B, 70B) and popular long-context models like Together-32k and CodeLlama.
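
The general monkey-patching pattern looks roughly like the sketch below; the actual patch function, its signature, and the Flash-Attention-aware kernels live in the ChunkLlama repository, so the names and the delegating forward here are purely illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import modeling_llama

_original_forward = modeling_llama.LlamaAttention.forward

def dca_forward(self, *args, **kwargs):
    # A real DCA patch would compute the intra-, inter- and successive-chunk
    # attention patterns here; this sketch only delegates to the original
    # forward to show where the replacement hooks in.
    return _original_forward(self, *args, **kwargs)

# Swap the forward method before instantiating the model so that every
# decoder layer created afterwards picks up the patched attention.
modeling_llama.LlamaAttention.forward = dca_forward

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```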

Long-Sequence Language Modeling

CHUNKLLAMA2 outperforms other training-free methods on extended context windows, maintaining low perplexity (PPL) and high passkey retrieval accuracy.

Practical Tasks

DCA-enhanced models demonstrate competitive performance in few-shot and zero-shot settings, achieving results comparable to finetuned baselines on benchmarks like NarrativeQA, QMSum, and QuALITY.

Analysis

Efficiency

The efficiency of the Dual Chunk Attention (DCA) framework is crucial for its practical adoption in real-world applications. By decomposing the attention computation into chunk-based modules and integrating with Flash Attention, DCA keeps the computational and memory overhead of long sequences minimal.

To evaluate the efficiency of DCA, inference time and GPU memory usage were measured across various prompt lengths using the Llama2 7B model on a single NVIDIA A100 GPU with 80GB of memory. The input prompts were drawn from the NarrativeQA dataset, known for its long contexts. Each measurement was repeated 20 times and the results averaged.
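
Measurements of this kind can be reproduced with standard PyTorch utilities; the sketch below (the model and input_ids are placeholders, loaded as in the setup above) reports average prefill latency and peak GPU memory:

```python
import time
import torch

def profile_prefill(model, input_ids: torch.Tensor, n_runs: int = 20):
    """Return average forward-pass latency (s) and peak GPU memory (GiB) over n_runs."""
    model.eval()
    latencies = []
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        for _ in range(n_runs):
            torch.cuda.synchronize()
            start = time.perf_counter()
            model(input_ids)                     # single forward pass over the whole prompt
            torch.cuda.synchronize()
            latencies.append(time.perf_counter() - start)
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    return sum(latencies) / len(latencies), peak_gib

# Usage (model loaded as above, prompt tokenized from a NarrativeQA passage):
# avg_s, peak_gib = profile_prefill(model, input_ids.cuda())
```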

The results show that without Flash Attention, the maximum input length manageable by a single GPU is roughly between 12k and 16k tokens. In contrast, DCA maintains similar GPU memory consumption and inference speed to the original Flash Attention, without introducing significant overhead. This efficiency is crucial for practical deployment, as it allows handling extensive context windows without requiring additional hardware resources.

Ablation Study

An ablation study was conducted to validate the contributions of the three attention mechanisms proposed in DCA: intra-chunk attention, inter-chunk attention, and successive-chunk attention. The study focused on language modeling and passkey retrieval tasks.

  1. Intra-Chunk Attention: This mechanism calculates the inner product of queries and keys within the same chunk, ensuring position indices are within chunk size limits. It maintains a very low perplexity (PPL) by processing tokens within the same chunk but hinders the model’s ability to retrieve passkeys from other chunks as it discards information from previous chunks.
  2. Inter-Chunk Attention: This component aggregates information from other chunks by incorporating attention calculations between different chunks. While it improves passkey retrieval performance at longer input lengths, the loss of locality results in a significant increase in PPL.
  3. Successive-Chunk Attention: This mechanism ensures the locality of attention by focusing on adjacent chunks, specifically addressing the relative position between neighboring tokens. By maintaining the precise relative positions for adjacent chunks, it achieves a balance between low PPL and high retrieval accuracy.

The ablation study, illustrated in Figure 4, highlights the importance of combining all three attention mechanisms. Intra-chunk attention alone yields low PPL but poor passkey retrieval, since information from earlier chunks is discarded; adding inter-chunk attention recovers retrieval accuracy at longer inputs but raises PPL; integrating all three, with successive-chunk attention restoring locality at chunk boundaries, achieves both low PPL and high retrieval accuracy.

Conclusion

Dual Chunk Attention (DCA) offers an efficient, training-free solution to extend the context window of LLMs, maintaining performance within the training length and excelling beyond it. This approach is compatible with existing long-context finetuning methods, presenting a cost-effective alternative for managing long-context scenarios in LLM applications.

  • Figure 3: Inference time and GPU memory comparison between original self-attention, Flash Attention, and DCA.

Impact Statement

DCA’s compatibility with Flash Attention and its training-free nature make it a cost-effective solution for long-context scenarios in LLM applications, potentially changing how the industry approaches extensive text sequences.

Acknowledgements

We thank Yukang Chen and Hang Yan for their helpful comments and open-source code. This research was supported by the joint research scheme of the NSFC and the RGC.

Reference: https://arxiv.org/html/2402.17463v2

Code: https://github.com/HKUNLP/ChunkLlama
