Understanding LoRA (Low-Rank Adaptation) with a Simple Example in PyTorch

In deep learning, fine-tuning pre-trained models for specific tasks has become a common practice. However, traditional fine-tuning methods come with significant drawbacks, such as high computational costs and storage requirements. This article explores LoRA, an innovative approach that addresses these challenges while enabling effective model adaptation.

What is LoRA?

LoRA, or Low-Rank Adaptation, is a technique that allows models to be fine-tuned with significantly fewer parameters than traditional methods. By freezing the pre-trained weights and adding trainable low-rank matrices, LoRA effectively reduces the computational burden and memory requirements of fine-tuning while maintaining performance.

Problems with Fine-Tuning

  1. Computational Expense: Fine-tuning large models can be computationally expensive. For instance, fine-tuning a BERT model with millions of parameters requires substantial GPU resources and time, especially if done multiple times for different tasks.
  2. Storage Requirements: Each fine-tuned model consumes storage space, leading to bloated model repositories. For instance, storing a few different fine-tuned versions of a model can take up several gigabytes.
  3. Switching Between Models: When working on multiple tasks, switching between various fine-tuned models can be cumbersome. The process of loading and unloading models is inefficient and memory-intensive.

How LoRA Works

LoRA addresses these issues by leveraging the concept of frozen pretrained weights and low-rank updates. Here's a breakdown of the mechanism:

Frozen Pretrained Weights

In LoRA, we keep the pretrained weights W fixed and introduce two new low-rank matrices A and B. The adjustment to the model can be expressed mathematically as:

W′ = W + A·B

Where:

  • W′ is the modified weight
  • A has dimensions d×r
  • B has dimensions r×k
  • r ≪ min(d, k), meaning r is significantly smaller than the original dimensions d and k

This ensures that the product A·B has the same dimensions (d×k) as the frozen weight matrix W, so it can be added directly to W.
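
To make this concrete, here is a minimal sketch (not the official implementation) of a linear layer augmented with a LoRA update. The class name LoRALinear and the initialization choices are illustrative; following the paper's convention, one factor is initialized to zero so that training starts exactly from the pretrained weights.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Sketch: frozen weight W plus a trainable low-rank update A @ B."""
    def __init__(self, d, k, r):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d, k), requires_grad=False)  # frozen pretrained weight (d x k)
        self.A = nn.Parameter(torch.randn(d, r) * 0.01)                # trainable factor A (d x r)
        self.B = nn.Parameter(torch.zeros(r, k))                       # trainable factor B (r x k), zero init

    def forward(self, x):
        # Effective weight W' = W + A @ B; x has shape (batch, d)
        return x @ (self.W + self.A @ self.B)

layer = LoRALinear(d=1024, k=1024, r=16)
out = layer(torch.randn(8, 1024))
print(out.shape)  # torch.Size([8, 1024])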

Maintaining Information

Choosing r much smaller than d and k keeps the number of new parameters small, while the low intrinsic rank of the weight updates (discussed below) means little useful information is lost. For example, if W is a 1024×1024 weight matrix, we might choose r=16:

  • A: 1024×16
  • B: 16×1024

This results in A·B producing a 1024×1024 matrix while requiring only 16 × (1024 + 1024) = 32,768 trainable parameters, compared to the 1,048,576 parameters of W.
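
A quick sanity check of that arithmetic in plain Python:

d, k, r = 1024, 1024, 16
full_params = d * k              # parameters in W
lora_params = r * (d + k)        # parameters in A and B combined
print(full_params)               # 1048576
print(lora_params)               # 32768
print(lora_params / full_params) # 0.03125, i.e. about 3% of the original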

Backpropagation Through New Matrices

During backpropagation, the gradients are only computed for the low-rank matrices A and B, minimizing computational load and speeding up the training process.
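
The standalone sketch below (with illustrative sizes) shows this: after backward(), the frozen weight has no gradient, while A and B do.

import torch

W = torch.randn(64, 64, requires_grad=False)  # frozen pretrained weight
A = torch.randn(64, 4, requires_grad=True)    # trainable low-rank factor A
B = torch.randn(4, 64, requires_grad=True)    # trainable low-rank factor B

x = torch.randn(8, 64)
loss = (x @ (W + A @ B)).sum()
loss.backward()

print(W.grad)        # None: no gradient is computed for the frozen weight
print(A.grad.shape)  # torch.Size([64, 4])
print(B.grad.shape)  # torch.Size([4, 64])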

Benefits of LoRA

  1. Fewer Parameters to Train: As shown in the previous example, the number of parameters to train is drastically reduced, which makes LoRA particularly efficient for large models.
  2. Reduced Storage: Since we only need to store the low-rank matrices A and B, the overall storage requirement is much less than storing multiple fine-tuned models.
  3. Faster Backpropagation: The reduced number of parameters results in faster gradient computations, enabling quicker training cycles.
  4. Easier Switching Between Models: Switching tasks becomes straightforward as we only need to load different sets of A and B matrices instead of entire models.
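
For example, each task can keep its own small pair of A and B matrices on disk, and switching tasks amounts to loading a different pair. The LoRAAdapter class below is a hypothetical container used only to illustrate the idea.

import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    # Hypothetical container holding one task's low-rank factors A and B
    def __init__(self, d, k, r):
        super().__init__()
        self.A = nn.Parameter(torch.randn(d, r) * 0.01)
        self.B = nn.Parameter(torch.zeros(r, k))

adapter_task_a = LoRAAdapter(1024, 1024, 16)
torch.save(adapter_task_a.state_dict(), "task_a_lora.pt")  # only ~32k parameters saved

# Switching tasks later = loading a different (tiny) set of A/B matrices
adapter = LoRAAdapter(1024, 1024, 16)
adapter.load_state_dict(torch.load("task_a_lora.pt"))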

Why Does LoRA Work?

According to the LoRA paper, pre-trained models have a low intrinsic dimension, and the weight updates needed during adaptation have a low intrinsic rank. By approximating these changes with low-rank matrices, LoRA effectively captures the essential features needed for specific tasks without requiring full model adjustments.

Singular Value Decomposition (SVD)

SVD is a key mathematical idea behind LoRA's low-rank intuition, as it provides an efficient representation of matrices. It decomposes a matrix M into three components:

M = U·S·Vᵀ (note that Vᵀ is the transpose of V)

Where:

  • U and V are orthogonal matrices.
  • S is a diagonal matrix containing singular values.

This decomposition enables a rank-efficient representation, ensuring that crucial information is preserved even when dimensionality is reduced. For instance, if we can represent M accurately with a lower rank r (by keeping only the largest singular values), we can effectively reduce the complexity of our model without significant information loss.
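
As a small illustration of the mechanics (a random matrix is used here only to show the shapes; a real weight matrix whose information is concentrated in its top singular values would be approximated much more faithfully):

import torch

M = torch.randn(1024, 1024)
U, S, Vh = torch.linalg.svd(M, full_matrices=False)  # M = U @ diag(S) @ Vh

r = 16
M_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]       # rank-r approximation of M

print(M_r.shape)                      # torch.Size([1024, 1024])
print(torch.linalg.matrix_rank(M_r))  # tensor(16)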

Weights Update and Bias

It is important to note that LoRA focuses on updating weights while keeping biases unchanged. This strategy allows for retaining the learned knowledge embedded in the biases of the pretrained model, while still enabling significant adaptations via the low-rank matrices.
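
As a rough sketch of what this looks like at merge time (values here are illustrative): the trained low-rank update is folded into a layer's weight for inference, while the pretrained bias is left untouched. Note that nn.Linear stores its weight as (out_features, in_features), hence the transpose.

import torch
import torch.nn as nn

linear = nn.Linear(1024, 1024)       # stands in for a pretrained layer
A = torch.randn(1024, 16) * 0.01     # trained low-rank factors (illustrative values)
B = torch.randn(16, 1024) * 0.01

with torch.no_grad():
    linear.weight.add_((A @ B).T)    # weight becomes W' = W + A @ B
    # linear.bias is deliberately not modified

print(linear.bias.shape)             # torch.Size([1024]) -- pretrained bias, unchanged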

Conclusion

LoRA provides a powerful framework for adapting large models with minimal computational and storage costs. By leveraging low-rank updates while keeping the core pretrained weights fixed, it streamlines the fine-tuning process and facilitates easier transitions between tasks. The principles behind LoRA—such as low intrinsic dimension and SVD—further enhance its effectiveness, making it a valuable tool in the deep learning toolkit.

PyTorch Example: note that this is for illustration purposes only and not the actual LoRA implementation.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the LoRA layer
# Note: for simplicity this toy layer uses A @ B directly as its weight;
# in actual LoRA, A @ B is added as a low-rank update to a frozen pretrained weight.
class LoRALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank):
        super(LoRALayer, self).__init__()
        self.A = nn.Parameter(torch.randn(input_dim, rank))  # Trainable low-rank factor A: (input_dim, rank)
        self.B = nn.Parameter(torch.randn(rank, output_dim))  # Trainable low-rank factor B: (rank, output_dim)
        
        # Print the dimensions of A and B
        print(f"LoRALayer initialized with A shape: {self.A.shape} and B shape: {self.B.shape}")

    def forward(self, x):
        if not hasattr(self, 'has_printed_shapes'):
            # Print input dimensions only once
            print(f"Input shape to LoRALayer: {x.shape}")
            self.has_printed_shapes = True
        
        result = x @ (self.A @ self.B)
        
        if not hasattr(self, 'has_printed_output_shape'):
            # Print output dimensions only once
            print(f"Output shape from LoRALayer: {result.shape}")
            self.has_printed_output_shape = True
            
        return result

# Simple feedforward network with LoRA
class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, rank):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)  # First linear layer
        self.lora = LoRALayer(hidden_size, output_size, rank)  # LoRA layer

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # Activation after the first layer
        return self.lora(x)  # Pass through LoRA layer

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Training loop
def train_model(model, train_loader, epochs=5):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        for batch_index, (data, target) in enumerate(train_loader):
            # Print initial matrix only once for the first batch
            if batch_index == 0 and epoch == 0:
                print(f"Initial data shape: {data.shape}")  # Print initial matrix shape
                #print(f"Initial data (first batch): {data[0]}")  # Print the first image data
            
            optimizer.zero_grad()
            output = model(data.view(data.size(0), -1))  # Flatten the input
            loss = criterion(output, target)  # Compute the loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update weights
        print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

# Initialize and train the model
input_size = 784  # 28*28
hidden_size = 128
output_size = 10   # 10 classes (digits 0-9)
rank = 16          # Low rank

model = SimpleNet(input_size, hidden_size, output_size, rank)
train_model(model, train_loader)
        