Adaptive LLM: Transformer²
I came across an interesting paper titled "Transformer-Squared: Self-Adaptive LLMs" (link), published by SakanaAI, a Japanese AI company. (I have previously written an article about their method for enhancing model capabilities without training; feel free to check it out if you're interested.) SakanaAI's team includes Llion Jones, one of the authors of "Attention Is All You Need." This paper continues their tradition of focusing on algorithmic innovation rather than relying on extensive computing power (as of mid-2024, they had reportedly only just acquired their first 8x H100 GPUs). Their approach is highly creative.
The paper primarily introduces a novel fine-tuning method called Singular Value Fine-tuning (SVF) to address challenges in traditional Supervised Fine-tuning (SFT), particularly approaches based on LoRA (Low-Rank Adaptation). The main issue with traditional SFT, including LoRA, is that a single fine-tuned adapter cannot cleanly separate different downstream tasks. Additionally, because injecting new knowledge means modifying the original weight matrices, fine-tuning can inadvertently degrade performance on other tasks.
To mitigate these issues, the paper proposes working with the Singular Value Decomposition (SVD) of each weight matrix. SVD factorizes a matrix into the product of three matrices:

W = UΣVᵀ

where:

- U is a matrix whose columns are the left singular vectors of W,
- Σ is a diagonal matrix whose entries are the singular values of W, and
- Vᵀ is the transpose of V, whose columns are the right singular vectors of W.
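As a quick illustration, here is a minimal PyTorch sketch with an arbitrary toy matrix (not code from the paper):

```python
import torch

# Toy weight matrix standing in for one transformer projection (shape is arbitrary)
W = torch.randn(512, 256)

# Reduced SVD: W = U @ diag(S) @ Vh, where Vh = V^T
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
print(U.shape, S.shape, Vh.shape)  # (512, 256), (256,), (256, 256)

# The factors reconstruct W up to floating-point error
print(torch.allclose(W, U @ torch.diag(S) @ Vh, atol=1e-4))  # True
```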
The purpose of this decomposition is to facilitate SVF (Singular Value Fine-tuning).
What is SVF?
Instead of directly modifying the weight matrix W, SVF learns a vector z ∈ ℝʳ, which is then used to adjust the singular values of W, modifying its behavior in a structured manner.
For each weight matrix W, SVF learns a vector z, which independently adjusts each singular component, producing a new weight matrix:

W′ = UΣ′Vᵀ
where:
Σ′ = Σ ⊗ diag(z)
Here, diag(z) is a diagonal matrix whose diagonal elements are given by z, and ⊗ denotes element-wise multiplication, so each singular value of W is rescaled by the corresponding entry of z.
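In code, the adaptation step is tiny. A minimal sketch, assuming the weight matrix has already been decomposed (`svf_update` is an illustrative name of my own, not from the paper's code):

```python
import torch

def svf_update(U: torch.Tensor, S: torch.Tensor, Vh: torch.Tensor,
               z: torch.Tensor) -> torch.Tensor:
    """Return W' = U @ diag(S * z) @ Vh, i.e. singular values rescaled by z."""
    return U @ torch.diag(S * z) @ Vh

# Example: decompose a toy weight matrix, then rescale its singular components
W = torch.randn(512, 256)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

z = torch.ones_like(S)   # z = 1 everywhere leaves W unchanged
z[:8] = 1.5              # amplify the 8 strongest singular components
z[-8:] = 0.5             # dampen the 8 weakest ones

W_adapted = svf_update(U, S, Vh, z)
print(W_adapted.shape)   # (512, 256)
```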
Why is this approach effective?
Rather than directly altering W, SVF enables fine-grained control by scaling singular values. This technique allows optimization via reinforcement learning (RL), tuning parameters based on task performance without requiring massive datasets with explicit task explanations.
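For readers who want the shape of that objective: based on the paper's description, it is a REINFORCE-style, reward-weighted log-likelihood with a KL penalty that keeps the adapted model close to the base model. The notation below is my paraphrase, not copied from the paper:

```latex
% Reward-weighted log-likelihood with KL regularization (paraphrased):
% pi_{W'} is the model with adapted weights W', pi_W the frozen base model
\max_{z}\; J(z) =
  \mathbb{E}\!\left[\log \pi_{W'}(\hat{y} \mid x)\, r(\hat{y}, y)\right]
  - \lambda\, D_{\mathrm{KL}}\!\left(\pi_{W'} \,\|\, \pi_{W}\right)
```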
Intuition Behind SVF:
Simply put, SVF "splits" the weight matrix W into finer components. Within W, certain values may control mathematical reasoning, others may handle language understanding, and some may be responsible for historical knowledge.
During training, SVF learns a set of vectors z, where each downstream task corresponds to a specific z vector. Since Σ′ is computed from z, z essentially acts as a signal amplifier. For instance, an entry of z greater than 1 amplifies the contribution of the corresponding singular component, while an entry smaller than 1 suppresses it.
SVF utilizes reinforcement learning (RL) to learn these z vectors across a predefined set of downstream tasks.
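To make the RL step concrete, here is a self-contained toy in PyTorch: a single frozen linear map whose z vector is trained with plain REINFORCE on an invented classification "task". The task, reward, and hyperparameters are placeholders of mine and not the paper's actual setup:

```python
import torch

torch.manual_seed(0)

# Frozen "base model": one linear map standing in for a transformer weight matrix
W = torch.randn(4, 16)                      # 4 "answers", 16 input features
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

# The only trainable parameters: one z entry per singular value
z = torch.nn.Parameter(torch.ones_like(S))
optimizer = torch.optim.Adam([z], lr=0.05)

# Invented "task": inputs with randomly assigned correct answers
X = torch.randn(256, 16)
y = torch.randint(0, 4, (256,))

for step in range(200):
    W_adapted = U @ torch.diag(S * z) @ Vh           # W' = U diag(S * z) V^T
    logits = X @ W_adapted.T
    dist = torch.distributions.Categorical(logits=logits)
    actions = dist.sample()                          # sampled "answers"
    reward = (actions == y).float() * 2 - 1          # +1 if correct, -1 otherwise

    loss = -(dist.log_prob(actions) * reward).mean() # REINFORCE estimator
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("accuracy:", (logits.argmax(-1) == y).float().mean().item())
```

Note that only the r entries of z receive gradients; the SVD factors of the base weights stay frozen throughout.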
How does it work at inference time?
Once trained, inference proceeds in two passes. In the first pass, the model examines the incoming prompt to identify the task at hand; the paper explores several strategies for this step, from a simple classification prompt to a trained task classifier and a few-shot search over weighted combinations of the learned z vectors. In the second pass, the selected (or combined) z vector is used to recompute the adapted weights W′, and the model generates its final answer with those weights.
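Here is a toy sketch of that two-pass flow. The task names, expert vectors, and keyword-based classifier below are all invented placeholders, not the paper's implementation:

```python
import torch

# Toy registry of expert vectors learned per task (values are illustrative)
expert_z = {
    "math":   torch.tensor([1.4, 1.1, 0.9, 0.7]),
    "coding": torch.tensor([0.8, 1.3, 1.2, 0.6]),
}

def classify_task(prompt: str) -> str:
    """Stand-in for the first pass; the paper uses the LLM itself or a
    trained classifier here. This keyword check is purely illustrative."""
    return "math" if any(tok in prompt for tok in ("solve", "integral")) else "coding"

def two_pass_inference(prompt: str, U, S, Vh) -> torch.Tensor:
    # First pass: identify the task and pick the matching expert vector
    z = expert_z[classify_task(prompt)]
    # Second pass: adapt the weights; a real model would now generate with W'
    return U @ torch.diag(S * z) @ Vh

W = torch.randn(4, 16)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
print(two_pass_inference("solve this integral", U, S, Vh).shape)  # (4, 16)
```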
Final Thoughts
This is an ingenious idea, and its effectiveness scales with model size: the larger the model, the better the results.