Tackling Quadratic Attention Complexity: Methods to Optimize Attention in Transformers. Part 1
In this post I provide a brief review of papers dedicated to reducing the quadratic attention complexity in transformers, specifically the quadratic dependency on sequence length. This limitation becomes significant when working with long contexts, so there is plenty of research addressing it: some papers focus on reducing time complexity, some on memory, and some tackle both. Before reading, it is recommended to understand how the classic transformer works; a great description can be found in Jay Alammar's article The Illustrated Transformer.
Linformer: Let's recall the root of the quadratic complexity — it arises from multiplying Q by K^T, which produces an n x n score matrix. The Linformer's approach is to project the K and V matrices along the sequence-length dimension down to a fixed size k. The score matrix then becomes n x k, so its second dimension no longer depends on context length. The result? Linear dependency in both time and memory.
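To make the idea concrete, here is a minimal single-head sketch in PyTorch (the projection size k=256 and the scaling of E and F are my assumptions, not values from the paper):

```python
import torch
import torch.nn as nn

class LinformerSelfAttention(nn.Module):
    """Minimal single-head sketch of Linformer-style attention.

    K and V are projected along the sequence dimension from n down to a
    fixed k, so the score matrix is n x k instead of n x n.
    """
    def __init__(self, seq_len: int, d_model: int, k: int = 256):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        # learned projections E, F of shape (k, n), as in the paper
        self.E = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.F = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.scale = d_model ** -0.5

    def forward(self, x):                                 # x: (batch, n, d_model)
        Q, K, V = self.q(x), self.k(x), self.v(x)
        K = torch.einsum('kn,bnd->bkd', self.E, K)        # (batch, k, d)
        V = torch.einsum('kn,bnd->bkd', self.F, V)        # (batch, k, d)
        scores = torch.einsum('bnd,bkd->bnk', Q, K) * self.scale  # (batch, n, k)
        attn = scores.softmax(dim=-1)
        return torch.einsum('bnk,bkd->bnd', attn, V)      # (batch, n, d)
```

Since k is a constant, both the score matrix and the softmax cost grow linearly with n.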
AFT (Attention Free Transformer): Here, the authors move away from the traditional concept of attention, where scores between Q and K are computed. Instead, they introduce a matrix w of size TxT (T being the maximum context size) whose elements are learned pair-wise position biases added to the K matrix. Essentially, the element at position (1,3) indicates the weight with which position 1 should pay attention to position 3. Together with a sigmoid over Q, this serves as a gating mechanism, highlighting which positions deserve more attention. Here we have linear memory dependency (no need to store a large attention-score matrix) but quadratic time complexity. However, modifications exist to achieve linear time complexity, such as AFT-local.
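A minimal sketch of AFT-full, assuming a single head, no causal masking, and no factorization of w (the paper additionally factorizes w to save parameters):

```python
import torch
import torch.nn as nn

class AFTFull(nn.Module):
    """Minimal sketch of AFT-full (single head, no masking).

    w is a learned T x T matrix of pair-wise position biases; it is added
    to the keys inside an exponential instead of computing Q.K^T scores.
    """
    def __init__(self, max_seq_len: int, d_model: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.w = nn.Parameter(torch.zeros(max_seq_len, max_seq_len))

    def forward(self, x):                         # x: (batch, T, d_model)
        B, T, _ = x.shape
        Q, K, V = self.q(x), self.k(x), self.v(x)
        w = self.w[:T, :T]                        # pair-wise position biases
        # exp(K_{t'} + w_{t,t'}) = exp(w_{t,t'}) * exp(K_{t'}), so the
        # (B, T, T, d) tensor never has to be materialized
        exp_k = torch.exp(K)                      # (B, T, d)
        exp_w = torch.exp(w)                      # (T, T)
        num = torch.einsum('ts,bsd->btd', exp_w, exp_k * V)
        den = torch.einsum('ts,bsd->btd', exp_w, exp_k)
        return torch.sigmoid(Q) * num / den       # gated, (B, T, d)
```

The einsum over positions still costs O(T^2 d) time, but no T x T attention-score tensor is stored per batch element, which is where the linear memory claim comes from.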
Reformer: This paper employs the well-known LSH (locality-sensitive hashing) algorithm. By setting Q=K, LSH is used to group similar embeddings into buckets, so attention scores only need to be calculated within each bucket. On the one hand, complexity is reduced to log-linear, O(n log n); on the other hand, in practice there is a large constant (the authors mention 128^2) hidden inside the big-O notation. Interestingly, a similar trick is used in approximate nearest neighbor (ANN) search algorithms.
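Below is a small sketch of just the hashing step (angular LSH via random rotations); the full LSH attention additionally sorts positions by bucket and attends within fixed-size chunks, which I omit here:

```python
import torch

def lsh_buckets(x: torch.Tensor, n_buckets: int, n_rounds: int = 1):
    """Angular LSH as in Reformer: random rotations assign each shared
    Q=K vector to a bucket, and nearby vectors tend to collide.

    x: (seq_len, d) shared query/key vectors
    returns: (n_rounds, seq_len) integer bucket ids
    """
    d = x.shape[-1]
    R = torch.randn(n_rounds, d, n_buckets // 2)          # one rotation per round
    rotated = torch.einsum('nd,rdb->rnb', x, R)           # (rounds, n, buckets/2)
    # bucket id = argmax over the concatenation [rotated, -rotated]
    return torch.cat([rotated, -rotated], dim=-1).argmax(dim=-1)

# toy usage: two nearly identical vectors should land in the same bucket
x = torch.randn(8, 64)
x[1] = x[0] + 0.01 * torch.randn(64)
print(lsh_buckets(x, n_buckets=4))
```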
Performer: The authors demonstrate that the attention matrix softmax(Q x K^T) can be approximated by the product of two matrices built from Q and K with random feature maps. The proof is intricate, but the idea revolves around kernel methods. With this approximation (softmax(Q x K^T) ~ Q1 x K1^T), we can first multiply K1^T by V, and then multiply Q1 by the result. Changing the order of these multiplications changes the complexity, and we get linear dependency in both time and memory.
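Here is a simplified sketch of the idea, assuming plain Gaussian random features (the paper's FAVOR+ additionally orthogonalizes and periodically redraws them) and a feature count n_features=256 that I chose for illustration:

```python
import torch

def favor_features(x, omega):
    """Positive random features approximating the softmax kernel:
    exp(q.k) ~ phi(q).phi(k) with phi(x) = exp(omega.x - |x|^2 / 2) / sqrt(m).
    x: (n, d), omega: (m, d)  ->  (n, m)
    """
    m = omega.shape[0]
    proj = x @ omega.T                               # (n, m)
    sq_norm = (x ** 2).sum(-1, keepdim=True) / 2     # (n, 1)
    return torch.exp(proj - sq_norm) / m ** 0.5

def linear_attention(Q, K, V, n_features: int = 256):
    """Performer-style attention: with phi(Q) phi(K)^T ~ exp(Q K^T / sqrt(d)),
    compute phi(K)^T V first, so the cost is linear in sequence length."""
    d = Q.shape[-1]
    omega = torch.randn(n_features, d)               # random, not learned
    q = favor_features(Q / d ** 0.25, omega)         # scaling absorbs 1/sqrt(d)
    k = favor_features(K / d ** 0.25, omega)
    kv = k.T @ V                                     # (m, d_v), linear in n
    z = 1.0 / (q @ k.sum(0, keepdim=True).T)         # softmax normalization, (n, 1)
    return (q @ kv) * z                              # (n, d_v)
```

Because k.T @ V is computed before touching Q, no n x n matrix ever appears, which is exactly the change of multiplication order described above.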
It's essential to remember that attention complexity depends not only on context length but also on the embedding size (more precisely, on the head embedding size, which is usually at most 128, with computations parallelized across heads). I haven't provided metrics; they can be found in the papers if you're interested. This post doesn't cover all the works on this topic — there are others like MEGA, Sparse Transformers, Longformer, MQA, GQA, and more. Perhaps they'll be topics for future posts if you find this one useful.