How to adapt autoregressive (AR) models to diffusion models?
TuringPost
Diffusion models allow flexible, parallel text generation, but they are typically smaller and less extensively trained than AR models. So researchers proposed a training approach to easily convert AR models into diffusion models.
1. Understand the two generation paradigms:
- AR models generate text sequentially, predicting one token at a time based on the previous tokens, proceeding from left to right.
- Diffusion models add noise to data and "denoise" it to retrieve the original data, predicting multiple tokens at once.
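To make the contrast concrete, here is a minimal PyTorch sketch of one generation step in each style. `ar_model` and `diff_model` are hypothetical stand-ins assumed to return per-position vocabulary logits.

```python
import torch

def ar_step(ar_model, tokens):
    # AR: one forward pass yields a single new token, conditioned on the left context.
    logits = ar_model(tokens)                      # (batch, seq_len, vocab)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    return torch.cat([tokens, next_token], dim=1)  # sequence grows by one

def diffusion_step(diff_model, tokens, mask_id):
    # Diffusion: one forward pass predicts every masked position in parallel.
    logits = diff_model(tokens)                    # (batch, seq_len, vocab)
    preds = logits.argmax(dim=-1)
    return torch.where(tokens == mask_id, preds, tokens)  # fill only the gaps
```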
2. Aligning the model objectives:
You need to train the AR model to predict masked tokens in their original positions, not just the next token. The key here is a reweighted objective function: it weights denoising steps with fewer masked tokens more heavily, guiding the model from heavily noised text toward clear text (sketched below).
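A minimal sketch of such a reweighted loss, assuming `model(x_t)` returns per-position vocabulary logits and `mask_id` marks masked tokens. The 1/t weighting follows the general form of masked-diffusion objectives, not necessarily the paper's exact formula.

```python
import torch
import torch.nn.functional as F

def reweighted_diffusion_loss(model, x0, mask_id):
    batch, seq_len = x0.shape
    # Sample a masking level t ~ U(0, 1] per sequence.
    t = torch.rand(batch, 1, device=x0.device).clamp(min=1e-3)
    # Mask each token independently with probability t.
    is_masked = torch.rand(batch, seq_len, device=x0.device) < t
    x_t = torch.where(is_masked, torch.full_like(x0, mask_id), x0)
    logits = model(x_t)                            # (batch, seq_len, vocab)
    # Cross-entropy on masked positions only, weighted by 1/t so steps with
    # few masked tokens (early denoising) contribute more.
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
    return (ce * is_masked / t).sum() / is_masked.sum().clamp(min=1)
```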
3. Attention mask annealing:
AR models use causal masks, which prevent them from "seeing" future tokens. For diffusion, you gradually relax this masking during adaptation until the model sees the full bidirectional context. This process is called attention mask annealing.
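One plausible way to implement the annealing, as a sketch: randomly unblock a growing fraction of the future (above-diagonal) positions as training progresses. The additive-mask convention (0 = attend, -inf = blocked) is an assumption, not the paper's exact scheme.

```python
import torch

def annealed_attention_mask(seq_len, progress):
    # progress = 0.0 -> fully causal; progress = 1.0 -> fully bidirectional.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    still_blocked = future & (torch.rand(seq_len, seq_len) > progress)
    mask = torch.zeros(seq_len, seq_len)
    mask[still_blocked] = float("-inf")
    return mask  # add to attention scores before softmax
```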
4. Shift operation:
In AR models, the output at each position predicts the next token. Add a shift operation that moves each prediction one position over, so predictions line up with the positions they should fill in the diffusion process.
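A sketch of that re-alignment, assuming logits of shape (batch, seq_len, vocab): the prediction emitted at position i (which targets token i+1) is moved to position i+1.

```python
import torch

def shift_logits(logits):
    batch, _, vocab = logits.shape
    # Position 0 has no left context in a pure AR model; pad it with zeros.
    pad = torch.zeros(batch, 1, vocab, device=logits.device, dtype=logits.dtype)
    return torch.cat([pad, logits[:, :-1]], dim=1)
```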
5. No time embedding:
Diffusion models often use time embeddings to track how much noise has been added at each step, but an adapted AR model can infer the noise level from how many tokens are masked. So you can skip time embeddings to keep the model simpler and faster.
6. Train with noise and masking (combined into the sketch after these bullets):
- Add varying levels of noise to the text sequence, which teaches the model to restore or predict the original text accurately.
- Mask tokens to create gaps in the sequence, training the model to fill them in from context.
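Putting steps 2–6 together, a sketch of one adaptation training step. It reuses `annealed_attention_mask` and `shift_logits` from the sketches above, and the `model(x_t, attn_mask=...)` signature is an assumption.

```python
import torch
import torch.nn.functional as F

def training_step(model, x0, mask_id, progress, optimizer):
    batch, seq_len = x0.shape
    # Corrupt: sample a noise level t per sequence, mask that fraction of tokens.
    t = torch.rand(batch, 1, device=x0.device).clamp(min=1e-3)
    is_masked = torch.rand(batch, seq_len, device=x0.device) < t
    x_t = torch.where(is_masked, torch.full_like(x0, mask_id), x0)
    # Forward pass under the partially annealed mask, then re-align predictions.
    attn_mask = annealed_attention_mask(seq_len, progress)
    logits = shift_logits(model(x_t, attn_mask=attn_mask))
    # Reweighted loss on masked positions only (step 2).
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
    loss = (ce * is_masked / t).sum() / is_masked.sum().clamp(min=1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```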
7. To generate text with the adapted model (see the sampling sketch below):
- Start with a fully noisy (masked) sequence.
- Denoise iteratively, predicting tokens at each step and reducing the noise until the full sequence is generated.
- Use top-k or nucleus sampling for coherent and relevant text.
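A sketch of that sampling loop under the same assumptions as above (a mask-initialized sequence, a model returning per-position logits), with top-k filtering and a simple confidence-based schedule for deciding which predictions to commit each step:

```python
import torch

@torch.no_grad()
def generate(model, seq_len, mask_id, steps=16, top_k=50):
    x = torch.full((1, seq_len), mask_id, dtype=torch.long)
    for step in range(steps):
        still_masked = x == mask_id
        if still_masked.sum() == 0:
            break
        logits = model(x)                              # (1, seq_len, vocab)
        # Top-k filtering, then sample a candidate token for every position.
        topk_vals, topk_idx = logits.topk(top_k, dim=-1)
        probs = torch.softmax(topk_vals, dim=-1)
        choice = torch.multinomial(probs.view(-1, top_k), 1).view(1, seq_len, 1)
        sampled = topk_idx.gather(-1, choice).squeeze(-1)
        # Commit the most confident predictions among still-masked positions.
        conf = probs.max(dim=-1).values.masked_fill(~still_masked, -1.0)
        n_fill = max(1, int(still_masked.sum().item() / (steps - step)))
        fill_idx = conf.view(-1).topk(n_fill).indices
        x.view(-1)[fill_idx] = sampled.view(-1)[fill_idx]
    return x
```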
These steps let an AR model work like a diffusion model, generating text in any order and predicting tokens at their original positions, combining the strengths of both model types.
Using this approach, researchers trained diffusion versions of GPT-2 and LLaMA2 on large datasets, creating DiffuGPT and DiffuLLaMA, which rival their AR counterparts.
Original paper: https://arxiv.org/pdf/2410.17891