Paper Review: Training Language Models to Self-Correct via Reinforcement Learning
Andrey Lukyanenko
Senior Data Scientist @ Careem. Kaggle Competition Master, Notebooks Top-1.
SCoRe is a new approach to improving LLMs' self-correction ability through multi-turn online RL. Existing methods rely on multiple models or external supervision, whereas SCoRe uses only self-generated data. The authors found that previous supervised fine-tuning methods are insufficient due to distribution mismatch and ineffective correction behavior. SCoRe addresses this by training the model on its own correction traces and using regularization to steer it toward effective self-correction. Applied to Gemini 1.0 Pro and 1.5 Flash, SCoRe improves self-correction performance by 15.6% and 9.1% on the MATH and HumanEval benchmarks, respectively.
Preliminaries and Problem Setup
The goal is to train LLMs to improve their own predictions by relying solely on self-generated data. In this self-correction setting, models try to correct their responses without external feedback. The approach uses a dataset of problems and oracle responses, training a policy that generates multiple attempts to solve a problem, with a reward function evaluating the correctness of each response. The model learns to detect and correct mistakes without access to the reward function at test time. The objective is to optimize a policy over multiple turns using policy-gradient reinforcement learning, with a KL-divergence penalty that keeps the policy close to the reference model. Key performance metrics are accuracy on the first and second attempts, the change in accuracy between attempts, the fraction of initially incorrect problems that become correct after self-correction, and the fraction of initially correct responses that become incorrect.
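To make the setup concrete, the two-attempt objective can be written roughly as below. This is my own shorthand under the assumptions stated in the comments, not the paper's exact formulation.

```latex
% Two-attempt self-correction objective with a KL penalty (my own shorthand notation).
% x: problem, y_1: first attempt, y_2: revised attempt, y^*: oracle response,
% \hat{r}: correctness reward, \pi_{\mathrm{ref}}: reference (base) policy, \beta: KL weight.
\max_{\theta}\;
\mathbb{E}_{x \sim \mathcal{D},\;
            y_1 \sim \pi_{\theta}(\cdot \mid x),\;
            y_2 \sim \pi_{\theta}(\cdot \mid x, y_1)}
\Big[\, \hat{r}(y_2, y^{*})
   \;-\; \beta\, D_{\mathrm{KL}}\!\big(\pi_{\theta}(\cdot \mid \cdot)\,\big\|\,\pi_{\mathrm{ref}}(\cdot \mid \cdot)\big) \Big]
```

The metrics listed above then correspond to Accuracy@t1, Accuracy@t2, the self-correction delta Δ(t1, t2) = Accuracy@t2 − Accuracy@t1, and the transition fractions Δ^{i→c}(t1, t2) and Δ^{c→i}(t1, t2).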
Supervised Fine-Tuning on Self-Generated Data is Insufficient for Self-Correction
An empirical study evaluates whether SFT approaches can improve large language models' self-correction abilities. Two methods were tested: STaR and Pair-SFT, a variant that trains a single model rather than a separate corrector. Although both improve on the base model, they still fail to achieve a meaningfully positive self-correction rate and often produce worse second attempts. The failures stem from SFT amplifying the initial biases of the base model, resulting in only minor changes in responses. Adjusting the distribution of initial responses helps but does not fully solve the problem, as learning remains hindered by distribution shift or bias amplification.
These methods were tested for improving self-correction in large language models using the MATH dataset. The STaR approach filters model-generated traces to retain only those where incorrect responses were successfully revised, then applies SFT on this filtered dataset. The second method, called Pair-SFT, pairs incorrect responses with correct ones to create “synthetic” repair traces without training a separate corrector model or using multi-turn traces. The datasets for each method were constructed from Gemini 1.5 Flash’s outputs, and SFT was applied: STaR underwent three iterations of data collection and fine-tuning, while Pair-SFT had only one epoch due to the large dataset size.
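As a rough illustration of how such synthetic repair pairs could be assembled, here is a minimal sketch; the function and field names are hypothetical and not taken from the paper's code.

```python
import random

def build_pair_sft_dataset(sampled, add_correct_to_correct=False, seed=0):
    """Pair incorrect first attempts with correct revisions to form synthetic repair traces.

    `sampled` maps each problem string to a list of (response, is_correct) tuples
    drawn from the base model; names and structure here are illustrative only.
    """
    rng = random.Random(seed)
    dataset = []
    for problem, responses in sampled.items():
        correct = [r for r, ok in responses if ok]
        incorrect = [r for r, ok in responses if not ok]
        if not correct:
            continue  # no correct sample to serve as the revision target
        for bad in incorrect:
            # Random pairing of an incorrect attempt with a correct one gives
            # the broad but off-policy coverage described above.
            dataset.append({"problem": problem,
                            "first_attempt": bad,
                            "revision": rng.choice(correct)})
        if add_correct_to_correct:
            for good in correct:
                # Optional correct-to-correct pairs teach the model to keep an
                # already-correct answer, at the risk of biasing it toward no edits.
                dataset.append({"problem": problem,
                                "first_attempt": good,
                                "revision": good})
    return dataset
```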
Pair-SFT showed a small 1.8% gain in self-correction compared to the base model, mainly by reducing the number of correct responses that were mistakenly revised to incorrect ones. However, it did not significantly improve the correction of incorrect first attempts. STaR, in contrast, did not reduce the number of incorrect revisions, indicating a lack of understanding of when to make changes. The discrepancy is attributed to differences in data distributions: Pair-SFT's random pairing covered a broader range of revision scenarios, while STaR had a narrower focus. Adding "correct-to-correct" data improved STaR slightly but still yielded minimal self-correction. For Pair-SFT, however, adding this data overly biased the model, causing it to avoid making changes entirely.
The authors analyzed self-correction behavior by measuring the edit distance ratio, which quantifies how much models modify their first-attempt responses. While the base model sometimes makes substantial edits, the fine-tuned models are overly conservative and often make no edits at all. STaR behaves similarly on training and validation data, whereas Pair-SFT shows a gap in edit distance ratios between the two, indicating poor generalization. Furthermore, although Pair-SFT keeps improving repair accuracy on its fixed training pairs and holds up on held-out pairs, its self-correction accuracy on the model's own freshly generated first attempts degrades as training progresses. SCoRe, on the other hand, avoids the minimal-edit bias without any explicit training signal for controlling the degree of response changes.
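A simple way to compute such an edit distance ratio, using Python's standard library as a stand-in for whatever the authors actually used, could look like this:

```python
import difflib

def edit_distance_ratio(first_attempt: str, second_attempt: str) -> float:
    """Fraction of the response that changed between attempts: 0 means identical, 1 means fully rewritten.

    Uses difflib's similarity ratio as a proxy for a token-level edit distance;
    the exact measure used in the paper may differ.
    """
    similarity = difflib.SequenceMatcher(None, first_attempt, second_attempt).ratio()
    return 1.0 - similarity

# A model that only tweaks one character yields a small ratio,
# while a full rewrite yields a ratio close to 1.
print(edit_distance_ratio("The answer is 42.", "The answer is 43."))
print(edit_distance_ratio("The answer is 42.", "Let me redo this from scratch."))
```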
The takeaway is that there are two key failure modes of SFT methods for self-correction: STaR latches onto a single correction strategy that makes only minor changes, while Pair-SFT, despite covering a broader range of data, suffers a degradation in self-correction performance due to distribution shift. These findings highlight two important criteria for an effective approach: it should train on the model's own (on-policy) first attempts so the training distribution matches what it will see at test time, and it should avoid collapsing into a non-correcting mode that makes only minor or no edits.
SCoRe: Self-Correction via Multi-Turn Reinforcement Learning
To develop an effective method for teaching LLMs to self-correct using only self-generated data, SCoRe leverages on-policy RL to address distribution mismatch and prevent mode collapse. A key challenge is that naive multi-turn RL, while addressing distribution shift, often produces models that do not self-correct and instead simply keep their first-attempt response: during training, producing the best possible first attempt and improving on a weaker one can appear equally optimal.
SCoRe addresses this challenge in two stages. In Stage I, the model is trained to produce high-reward second attempts while its first-attempt distribution is constrained (via a KL penalty) to stay close to the base model, yielding an initialization that is less prone to collapsing both attempts into the same answer. In Stage II, multi-turn RL jointly optimizes both attempts, with reward shaping that adds a bonus for turning an incorrect first attempt into a correct second attempt, so that genuine self-correction, rather than merely repeating a good first answer, becomes the preferred strategy; a sketch of this shaping idea follows below.
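The snippet below is a minimal sketch of the reward-shaping idea, not the paper's exact formula; the function name and the scaling knob `alpha` are my own placeholders.

```python
def shaped_second_attempt_reward(r_first: float, r_second: float, alpha: float = 2.0) -> float:
    """Shaped reward for the revised attempt (a sketch of the idea, not the paper's exact formula).

    r_first and r_second are correctness rewards (e.g. 0 or 1) for the two attempts.
    The bonus pays extra for flipping an incorrect answer into a correct one and
    penalizes flipping a correct answer into an incorrect one, discouraging the
    collapse to "just repeat the first answer". alpha is a hypothetical scaling knob.
    """
    progress_bonus = alpha * (r_second - r_first)
    return r_second + progress_bonus

# Incorrect -> correct earns well beyond plain correctness ...
print(shaped_second_attempt_reward(0.0, 1.0))  # 1.0 + 2.0 * 1.0 = 3.0
# ... while correct -> incorrect is actively penalized, not just unrewarded.
print(shaped_second_attempt_reward(1.0, 0.0))  # 0.0 + 2.0 * (-1.0) = -2.0
```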
Experiments
On the MATH benchmark, SCoRe achieves a positive self-correction delta of 4.4%, with overall accuracy@t2 improving by 23.0% over the base model; its self-correction delta exceeds the base model's by 15.6% and Pair-SFT's by 10.2%. SCoRe also raises the rate of fixing incorrect responses and reduces the number of correct answers that get changed to incorrect ones.
For the code generation task, SCoRe boosts performance from 47.3% to 60.6% on MBPP-R, a gap similar to that between GPT-3.5 and GPT-4. It also generalizes well to the HumanEval dataset, achieving a 12.2% self-correction delta, outperforming the base model by 9%. While Pair-SFT performs well in static repair tasks, it degrades self-correction performance, highlighting the importance of on-policy sampling.
Additionally, SCoRe is effective when combined with inference-time compute scaling strategies like self-consistency decoding (majority voting). Combining parallel sampling with sequential self-correction yields a 10.5% improvement, compared to 7.4% from parallel sampling alone.
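The combined strategy can be pictured as: sample several first attempts in parallel, let the model revise each one, then majority-vote over the revised answers. Below is a minimal sketch where generate_answer and revise_answer are hypothetical stand-ins for calls to the trained model, not real APIs.

```python
from collections import Counter

def correct_then_vote(problem, n_samples, generate_answer, revise_answer):
    """Parallel sampling combined with one round of sequential self-correction.

    generate_answer(problem) and revise_answer(problem, first_attempt) are
    placeholder callables representing the model; any real integration would differ.
    """
    first_attempts = [generate_answer(problem) for _ in range(n_samples)]
    revised = [revise_answer(problem, attempt) for attempt in first_attempts]
    # Self-consistency: majority vote over the revised final answers.
    return Counter(revised).most_common(1)[0][0]

# Toy usage with dummy model stubs, just to show the control flow.
print(correct_then_vote("1 + 1 = ?", 3,
                        generate_answer=lambda p: "3",
                        revise_answer=lambda p, a: "2"))
```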
Ablation studies: removing multi-turn training, the Stage I initialization, reward shaping, or on-policy sampling each degrades self-correction performance, indicating that all components of SCoRe contribute.