Proxy-Tuning: Efficient and Customizable Adaptation of Language Models
Introduction:
Proxy-tuning addresses the challenge of efficiently customizing large pretrained language models (LMs) to better suit specific needs and tasks. Despite the impressive capabilities of these large LMs, further adaptation or tuning is often necessary to achieve desired behaviors, such as improved performance on specific tasks or better alignment with domain-specific requirements.
However, directly tuning these models is resource-intensive, and it becomes impossible when the model weights are private and inaccessible. This poses a significant obstacle for organizations seeking to tailor these models to their needs without compromising privacy or proprietary information.
Proxy-tuning introduces a lightweight decoding-time algorithm that operates on top of black-box LMs. It aims to achieve the results of direct model tuning without accessing the internal weights of the model, relying only on the model's predictive distributions over the output vocabulary. By leveraging a smaller pretrained LM as a proxy, proxy-tuning contrasts the predictions of the tuned and untuned versions of the smaller model to guide the larger base model. This process allows the base model to be steered in the direction of tuning while retaining the benefits of its larger-scale pretraining.
Proxy-Tuning Algorithm:
The proxy-tuning algorithm is a decoding-time method that steers a large pretrained language model (LM) to behave like a tuned model without directly accessing or modifying its internal parameters.
Here is how it works:
1. Tuning a Smaller LM (Expert and Anti-Expert):
- A smaller LM, often called the expert (M+), is tuned on a specific task or domain using standard fine-tuning techniques.
- An anti-expert (M-) is created by taking the original, untuned version of the smaller LM.
2. Decoding-Time Logit Offset:
- At each decoding step, the proxy-tuning algorithm calculates a logit offset based on the difference between the expert's (M+) and anti-expert's (M-) predictions.
- This logit offset is added to the logits of the large pretrained LM (M) to guide its predictions (a minimal code sketch follows this list).
3. Shifting Base Model Predictions:
- The logit offset effectively shifts the base model's predictions towards the expert's predictions.
- This guidance encourages the base model to generate text that aligns with the desired behavior or knowledge learned by the expert.
4. Preserving Pretraining Benefits:
- Proxy-tuning retains the benefits of large-scale pretraining by leveraging the base model's extensive knowledge and language capabilities.
- It combines the advantages of tuning with the efficiency and scalability of decoding-time guidance.
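Putting these steps together, here is a minimal sketch of a single proxy-tuned decoding step. It assumes all three models share the same tokenizer and vocabulary (a requirement of the method) and expose Hugging Face-style .logits outputs; the function and variable names are illustrative:

import torch

def proxy_tuned_next_token_probs(base_lm, expert_lm, antiexpert_lm, input_ids):
    # One decoding step: shift the large base model's logits (M) by the
    # difference between expert (M+) and anti-expert (M-) logits.
    with torch.no_grad():
        logits_m = base_lm(input_ids).logits[:, -1, :]             # large base model M
        logits_plus = expert_lm(input_ids).logits[:, -1, :]        # tuned expert M+
        logits_minus = antiexpert_lm(input_ids).logits[:, -1, :]   # untuned anti-expert M-
    # Proxy-tuned next-token distribution: softmax(s_M + s_M+ - s_M-)
    return torch.softmax(logits_m + logits_plus - logits_minus, dim=-1)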
Decoding-Time Experts (DEXPERTS):
Decoding-Time Experts (DEXPERTS) is a technique that allows a language model to incorporate knowledge or guidance from an external source during the decoding process.
It combines the predictions of a base language model with those of two auxiliary models: an "expert" and an "anti-expert".
The expert model is assumed to have knowledge or expertise in a specific domain or task, while the anti-expert is the same model without that expertise.
In proxy-tuning, DEXPERTS is used to steer a large pretrained language model (the base model) towards the behavior of a tuned model (the expert) without directly modifying the base model's weights. The anti-expert is the original, untuned version of the smaller expert model, not the large base model itself.
During decoding, the base model generates a distribution over the next token. The expert and anti-expert models generate their own distributions over the next token as well. The DEXPERTS equation is then applied to combine these distributions: a logit offset is calculated as the difference between the expert's and anti-expert's logits, and this offset is added to the base model's logits, shifting the distribution towards the expert's predictions.
By applying DEXPERTS, proxy-tuning enables the base model to benefit from the knowledge or expertise of the expert model during decoding, without the need for direct tuning of the base model's weights. This approach allows for efficient and effective customization of large language models to specific tasks or domains, while preserving the benefits of the larger-scale pretraining.
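In code, the DEXPERTS combination reduces to a single line of logit arithmetic. The sketch below is illustrative (the names are not from any library); alpha is the optional guidance-strength hyperparameter, with alpha = 1.0 recovering plain proxy-tuning:

def dexperts_logits(logits_base, logits_expert, logits_antiexpert, alpha=1.0):
    # Shapes: [batch, vocab_size]. alpha scales how strongly the
    # expert/anti-expert contrast steers the base model's predictions.
    return logits_base + alpha * (logits_expert - logits_antiexpert)

Larger values of alpha push generations harder towards the expert's behavior, while smaller values stay closer to the base model.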
Advantages of Proxy-Tuning:
- Efficient: Proxy-tuning is a lightweight decoding-time algorithm that operates on top of black-box LMs, making it more efficient than directly tuning the model's parameters.
- No Access to Model Weights: Proxy-tuning does not require access to the model's internal weights, only its predictive distributions over the output vocabulary. This is particularly useful when the model weights are private or proprietary.
- Customizable: Proxy-tuning allows users to customize large pretrained LMs for specific tasks or domains without having to retrain the entire model. This enables efficient adaptation to diverse user needs and applications.
- Preserves Pre-training Benefits: Proxy-tuning combines the benefits of tuning with larger pre-training scale, as it retains the knowledge and capabilities learned during pre-training while adapting the model to new tasks or domains.
- Controllable: Proxy-tuning provides an optional hyperparameter that allows users to control the amount of guidance exerted at runtime (the alpha factor in the DEXPERTS sketch above), enabling a trade-off between different desired attributes of generations.
Proxy-tuning offers a lightweight, efficient, and customizable approach to adapting large pretrained LMs without compromising their pretraining benefits.
Experimental Results:
Proxy-tuning demonstrates promising results across a range of experiments: steering a large base model with small tuned proxies recovers much of the benefit of directly tuning the large model, while requiring only the small models to be trained.
Limitations and Drawbacks:
While proxy-tuning offers several advantages, it also comes with limitations and drawbacks.
Firstly, since proxy-tuning operates based on the predictions of smaller expert and anti-expert models, its effectiveness heavily relies on the quality and relevance of these models to the task or domain of interest. Choosing appropriate expert and anti-expert models can be challenging and may require additional resources for model selection and fine-tuning.
Proxy-tuning may not fully capture nuanced domain-specific knowledge or behaviors, especially in complex tasks where fine-grained adjustments are necessary. Furthermore, the effectiveness of proxy-tuning may vary across different tasks, datasets, and languages, requiring careful evaluation and tuning for optimal performance.
Lastly, while proxy-tuning preserves the benefits of pre-training, it may not fully overcome the limitations inherent in the original pre-trained model, such as biases or ethical concerns encoded in the model's parameters. Therefore, careful consideration and evaluation are essential when applying proxy-tuning to ensure its suitability and effectiveness for specific use cases.
Code Example:
import torch

def generate_proxy_tuning(model_base, model_tuned, model_target, tokenizer, input_text, max_length):
    """Greedy proxy-tuned decoding.

    model_base   : the small untuned anti-expert (M-)
    model_tuned  : the small tuned expert (M+)
    model_target : the large pretrained model (M) being steered
    All three models must share the same tokenizer and vocabulary.
    """
    # Get the device of each model
    device_base = next(model_base.parameters()).device
    device_tuned = next(model_tuned.parameters()).device
    device_target = next(model_target.parameters()).device
    # Encode the input text and place it on the anti-expert's device
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device_base)
    generated_tokens = []
    with torch.no_grad():
        # Generate tokens iteratively until reaching the maximum length or an end-of-sequence token
        for _ in range(max_length):
            # Move the current sequence to each model's device
            input_ids_tuned = input_ids.to(device_tuned)
            input_ids_target = input_ids.to(device_target)
            # Proxy-tuning: get logits from each model
            logits_base = model_base(input_ids).logits             # anti-expert (M-)
            logits_tuned = model_tuned(input_ids_tuned).logits     # expert (M+)
            logits_target = model_target(input_ids_target).logits  # large base model (M)
            # Shift the large model's logits by the expert/anti-expert difference:
            # s = s_M + (s_M+ - s_M-)
            logits = (
                logits_target.to(device_base)
                + (logits_tuned.to(device_base) - logits_base.to(device_base))
            )
            # Softmax over the last position gives the next-token probabilities
            predictions = torch.softmax(logits[:, -1, :], dim=-1)
            # Greedy decoding: select the token with the highest probability
            next_token_id = torch.argmax(predictions, dim=-1)
            # Record the token and append it to the sequence for the next iteration
            generated_tokens.append(next_token_id.item())
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
            # Stop at the end-of-sequence token
            if next_token_id.item() == tokenizer.eos_token_id:
                break
    # Decode the generated tokens into text, skipping special tokens
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text
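A hypothetical usage sketch, assuming the Hugging Face transformers library and Llama-2 checkpoints purely for illustration (any model family whose sizes share one tokenizer works the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints: a large base model plus a small tuned/untuned pair.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")        # anti-expert (M-)
model_tuned = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # expert (M+)
model_target = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")     # large base (M)

output = generate_proxy_tuning(model_base, model_tuned, model_target, tokenizer,
                               "Explain proxy-tuning in one sentence.", max_length=64)
print(output)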
By Kirouane Ayoub