Fine-Tuning Image-to-Text algorithms with LoRA
Daniel Puente Viejo
Simple notebook for Fine-Tuning Image-to-Text conversion algorithms using LoRA
The goal of this article is to walk through a simple notebook example of how to apply LoRA to fine-tune an Image-to-Text algorithm. The notebook is built with the Hugging Face Transformers, Datasets and Peft libraries.
Let’s dive in!
1. What is LoRA?
In the field of large language models, the challenge of fine-tuning has long perplexed researchers. Microsoft, however, has unveiled an innovative solution called Low-Rank Adaptation (LoRA). With the emergence of behemoth models like GPT-3 boasting billions of parameters, the cost of fine-tuning them for specific tasks or domains has become exorbitant.
LoRA takes a different approach: it freezes the weights of the pre-trained model and injects trainable rank-decomposition matrices into each transformer block. This significantly reduces the number of trainable parameters and the GPU memory required, since gradients no longer need to be computed for the vast majority of the model's weights.
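To make the idea concrete, the snippet below is a minimal, illustrative sketch of a LoRA-style linear layer written from scratch (it is not the Peft implementation): the original weight matrix is frozen, and only the two small rank-decomposition matrices A and B are trained.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA layer: y = x W^T + (alpha / r) * x A^T B^T."""
    def __init__(self, in_features, out_features, r=16, alpha=32):
        super().__init__()
        # Frozen pre-trained weight (stands in for the original layer)
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # Trainable low-rank decomposition: A (r x in) and B (out x r), with B initialized to zero
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.normal_(self.lora_A, std=0.02)
        self.scaling = alpha / r

    def forward(self, x):
        base = x @ self.weight.T                      # frozen path
        update = (x @ self.lora_A.T) @ self.lora_B.T  # trainable low-rank path
        return base + self.scaling * update

layer = LoRALinear(768, 768)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable} / {total}")  # 2 * 768 * 16 parameters vs 768 * 768
Only the A and B matrices receive gradients, which is exactly the saving that Peft exploits at scale.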
2. Libraries
To start the notebook, the required libraries are imported below. For LoRA fine-tuning it is recommended to have a GPU available; the last line of the snippet selects the GPU if one is present.
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import ViTFeatureExtractor, VisionEncoderDecoderModel, AutoProcessor
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3. Dataset and model
The model used throughout this article is vit-gpt2-coco-en by ydshieh: a ViT encoder paired with a GPT-2 decoder that is fast to run and produces good captions.
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
model = VisionEncoderDecoderModel.from_pretrained(loc)
processor = AutoProcessor.from_pretrained(loc)
model = model.to(device)
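Before touching any LoRA code, it can be useful to check that the base model already produces reasonable captions. The snippet below is a quick sketch: it assumes you have some local image ("example.jpg" is just a placeholder path) and uses the repository's GPT-2 tokenizer to decode the generated ids.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(loc)

# "example.jpg" is a placeholder; any RGB image will do
image = Image.open("example.jpg").convert("RGB")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

with torch.no_grad():
    generated_ids = model.generate(pixel_values=pixel_values, max_length=16)
caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(caption)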
On the data side, the ybelkada/football-dataset is a handy resource for experimentation. Despite its modest size (6 images), it is very useful for running quick tests and validating that the whole pipeline works end to end before moving on to larger and more complex datasets.
dataset = load_dataset("ybelkada/football-dataset", split="train")
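A quick look at the loaded split shows what the examples look like. The column names assumed in the rest of the notebook are "image" and "text"; if your copy of the dataset uses different names, adjust the Dataset class below accordingly.
print(dataset)               # number of rows and column names
sample = dataset[0]
print(dataset.column_names)  # expected: an "image" column and a caption column
print(sample["text"])        # the caption paired with the first image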
4. Adapt the Dataset
LoRA itself is applied through the Peft library. However, Peft's ready-made examples are oriented towards text-only models, so an Image-to-Text problem needs a few adjustments: a custom PyTorch Dataset and a collator that prepare the image-caption pairs in the format the model expects.
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor, feature_extractor):
        self.dataset = dataset
        self.processor = processor
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Encode the image into pixel_values and drop the batch dimension
        encoding = self.feature_extractor(images=item["image"], return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        # The football dataset stores the caption in its "text" column
        encoding["text"] = item["text"]
        return encoding

train_dataset = ImageCaptioningDataset(dataset, processor, feature_extractor)
In this implementation, we rely on PyTorch's Dataset class, a standard practice for training that keeps the data pipeline efficient, flexible and reusable. The process is straightforward: for each item we load the image, encode it with the feature_extractor, and store the resulting tensors together with the corresponding caption in a dictionary.
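As a quick sanity check, a single item can be inspected before wiring up the dataloader. The shapes below assume the standard 224x224 ViT preprocessing used by this feature extractor.
item = train_dataset[0]
print(item.keys())                 # pixel_values plus the raw caption text
print(item["pixel_values"].shape)  # typically torch.Size([3, 224, 224])
print(item["text"])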
5. Adapt the model
The next step is to adapt the model for LoRA fine-tuning. To do so, we instantiate a LoRA configuration and choose its parameters; the main hyperparameters are described after the code so that they can be tuned for each use case.
In addition, a small helper function reports the percentage of weights that will actually be trained. This gives a clear picture of how much of the model is modified during the LoRA fine-tuning process.
# LoRA configuration
config = LoraConfig(
    r=16,                               # rank of the update matrices
    lora_alpha=32,                      # scaling factor applied to the LoRA update
    lora_dropout=0.05,                  # dropout on the LoRA layers
    bias="none",                        # do not train any bias terms
    target_modules=["query", "value"],  # attention projections to adapt
)
model = get_peft_model(model, config)
## Trainable Params
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable %: {100 * trainable_params / all_param}")

print_trainable_parameters(model)
## ------------------------------------------------------------------------------------
## trainable params: 589824 || all params: 239785728 || trainable %: 0.2459796105963404
## ------------------------------------------------------------------------------------
Next, the main hyperparameters of the LoRA configuration are described. r is the rank of the decomposition matrices: smaller values mean fewer trainable parameters, larger values give the adapter more capacity. lora_alpha is a scaling factor applied to the LoRA update (the effective scale is lora_alpha / r). lora_dropout applies dropout to the LoRA layers as a light regularizer. bias controls whether bias terms are trained ("none", "all" or "lora_only"). Finally, target_modules lists the names of the modules to which the rank-decomposition matrices are attached; here, the query and value projections of the attention layers. Note that in this ViT-GPT2 model those names match the ViT encoder's attention projections, while the GPT-2 decoder uses a fused c_attn projection, so it would need to be listed explicitly to adapt the decoder as well. The sketch below shows how the choice of r directly drives the trainable-parameter percentage.
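As an illustration (this is just a sketch, and the base model is reloaded for each rank only to keep the comparison clean), the rank can be swept and the resulting trainable percentage printed with the helper defined above:
for rank in [4, 8, 16, 32]:
    base = VisionEncoderDecoderModel.from_pretrained(loc)
    lora_cfg = LoraConfig(
        r=rank,
        lora_alpha=2 * rank,
        lora_dropout=0.05,
        bias="none",
        target_modules=["query", "value"],
    )
    peft_model = get_peft_model(base, lora_cfg)
    print(f"r={rank}")
    print_trainable_parameters(peft_model)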
6. Dataloader
Before starting the training, a few small adjustments are needed. The collator function turns a list of dataset items into a single training batch: the image tensors are stacked, and the captions are tokenized with the processor, which produces the input_ids and attention_mask. Parameters such as batch_size and shuffle are then set on the DataLoader itself.
def collator(batch):
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            # Stack the image tensors (pixel_values) into a single batch tensor
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            # Tokenize the captions with padding so they share the same length
            text_inputs = processor([example["text"] for example in batch], padding=True, return_tensors="pt")
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=3, collate_fn=collator)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 100
In this case, a small batch_size is used because of the small number of images. The chosen optimizer is AdamW, and the model is trained for 100 epochs; on a GPU the training should take about a minute.
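To make sure the collator behaves as expected before the full run, one batch can be pulled from the dataloader and its shapes inspected (the exact sequence length depends on the captions in the batch):
batch = next(iter(train_dataloader))
print(batch["pixel_values"].shape)   # e.g. torch.Size([3, 3, 224, 224])
print(batch["input_ids"].shape)      # (batch_size, padded caption length)
print(batch["attention_mask"].shape)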
7. Train
Finally, the model is trained, recording the loss at the end of each epoch. Tracking the loss makes it easy to monitor the model's progress and spot problems early.
loss_list = []
model.train()

for epoch in range(1, epochs + 1):
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        # Keep pixel_values in the model's dtype (float32 here); casting them to
        # float16 would mismatch the weights, since the model was not loaded in half precision
        pixel_values = batch.pop("pixel_values").to(device)

        outputs = model(pixel_values=pixel_values, labels=input_ids)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    if epoch % 10 == 0:
        print(f"Epoch {epoch} done!, Loss: {loss.item()}")
    loss_list.append(loss.item())
With this last piece of code, the notebook is finished. Remember that the code should be tailored to each specific use case, and that a proper setup should also include validation and test splits to verify that the model generalizes.
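Once training finishes, only the small LoRA adapter needs to be stored, and it can be reattached to the base model later. The sketch below assumes a local output directory of your choosing; "vit-gpt2-football-lora" is just a placeholder name.
# Save only the LoRA adapter weights (a few MB, not the full model)
model.save_pretrained("vit-gpt2-football-lora")

# Later: reload the base model and attach the trained adapter
from peft import PeftModel
from transformers import AutoTokenizer

base_model = VisionEncoderDecoderModel.from_pretrained(loc)
finetuned = PeftModel.from_pretrained(base_model, "vit-gpt2-football-lora").to(device)
tokenizer = AutoTokenizer.from_pretrained(loc)

# Generate a caption with the fine-tuned model
finetuned.eval()
image = dataset[0]["image"]
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
with torch.no_grad():
    generated_ids = finetuned.generate(pixel_values=pixel_values, max_length=16)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])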
It is a simple notebook that can be used in many contexts. If you want to learn in more detail how to apply LORA to other algorithms, don’t hesitate to visit the Peft repository.
Thanks for Reading!
Thank you very much for reading the article. If you liked it, don't hesitate to follow me on LinkedIn.