Learning to Learn

Intro

As we enter the Age of Information, a new resource has arisen: data. Every website you visit, everything you buy, and every video you watch is recorded. You, along with everyone else who uses the internet, have a digital footprint. But that's not even the whole story: data is becoming even more valuable now that huge amounts of it are required to train neural networks. This raises a big issue: a sufficiently large neural network needs a correspondingly large dataset if you train it from scratch. Sometimes that data is easy to find; ChatGPT, for instance, was trained largely on text gathered from across the internet. In other cases, it can be extremely hard, expensive, or downright impossible to collect enough data to train a good model.

Fine-Tuning

One of the most obvious ways to reduce the amount of data needed is to fine-tune a larger pretrained model. When fine-tuning ChatGPT, I needed only around 100 high-quality samples to get good results on my own, more specific tasks, even though the underlying model was trained on billions of samples. Technically this doesn't require fewer samples overall, since GPT was already pretrained on a massive dataset, but it drastically reduces the amount of data needed to build many specialized models from one pretrained model.

However, this approach has clear limits. It assumes you already have a huge pretrained network that you can easily specialize for future tasks, which isn't the case most of the time. In the medical field, for example, the amount of data available for a diagnostic model depends on how rare the disease is and how many people choose to take part in a study or share their data. It isn't feasible to build a huge model for every single disease and then fine-tune it on that specific disease either.

Transfer Learning

Another method is called transfer learning. Like fine-tuning, it requires an already trained model. Unlike fine-tuning, it can repurpose the network for a different task than the one it was built for. To understand this, let's take a look at an example.

Let's say you have a simple classifier model that was trained on pictures of cats, dogs, flamingos, and humans. It has 4 classes, one for each species.

Now we want to change our model a little. Instead of classifying the 4 species, we want to classify mammals versus birds.

A quick side note: yes, this is a trivial test case, since you could just post-process the existing network's output. If it predicts a cat, dog, or human, it's a mammal; if it predicts a flamingo, it's a bird. But for the sake of the example, let's say we want to change the model itself from 4 classes to 2.
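The post-processing shortcut from the side note can be sketched in a few lines. The class order and probability values below are hypothetical. The idea is to sum the predicted probabilities within each group rather than only relabelling the top class.

```python
# Hypothetical class order of the original 4-class classifier (an assumption for illustration).
CLASSES = ["cat", "dog", "flamingo", "human"]
MAMMALS = {"cat", "dog", "human"}


def mammal_vs_bird(probs):
    """Collapse 4-class probabilities into (p_mammal, p_bird) by summing each group."""
    if len(probs) != len(CLASSES):
        raise ValueError("expected one probability per class")
    p_mammal = sum(p for name, p in zip(CLASSES, probs) if name in MAMMALS)
    p_bird = sum(p for name, p in zip(CLASSES, probs) if name not in MAMMALS)
    return p_mammal, p_bird


def predict_group(probs):
    """Return 'mammal' or 'bird', whichever group has more total probability."""
    p_mammal, p_bird = mammal_vs_bird(probs)
    return "mammal" if p_mammal >= p_bird else "bird"


print(predict_group([0.1, 0.2, 0.6, 0.1]))  # bird
```

Summing can disagree with mapping only the argmax. For example, [0.3, 0.3, 0.4, 0.0] has "flamingo" as its top class, but 60% of the probability mass is on mammals.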

To do this, we first need new data labeled as mammal or bird. Then we need to train the model. However, we don't actually need to retrain the entire network, and this is where transfer learning comes in. Ideally, the network has already learned general features (edges, textures, simple shapes) in its first few layers and only combines them into class-specific decisions in its last layers. This means we can reuse the first few layers by freezing their weights and train only the last one or two layers. This reduces both the number of training samples and the training time needed.
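Here is a minimal, framework-free sketch of the freeze-and-retrain idea. Everything in it is hypothetical: the "pretrained" layer FROZEN_W, the tiny dataset, and the helper names. The frozen layer is used to compute features but is never updated; only a new 2-class head (logistic regression) is trained. In PyTorch, the usual equivalent is to set requires_grad = False on the early layers' parameters, replace the final layer with a new nn.Linear that has 2 outputs, and pass only the trainable parameters to the optimizer.

```python
import math

# Hypothetical "pretrained" first layer. Its weights are frozen: nothing below updates them.
FROZEN_W = [[1.0, 0.0, 0.5],
            [0.0, 1.0, -0.5]]


def features(x):
    """Frozen feature extractor: ReLU(FROZEN_W @ x)."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in FROZEN_W]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(head_w, head_b, x):
    """Probability of 'mammal' from the new 2-class head on top of the frozen features."""
    return sigmoid(sum(w * h for w, h in zip(head_w, features(x))) + head_b)


def train_head(samples, lr=0.5, steps=200):
    """Gradient descent on the logistic loss, updating ONLY the new head's weights and bias."""
    head_w, head_b = [0.0] * len(FROZEN_W), 0.0
    n = len(samples)
    for _ in range(steps):
        grad_w, grad_b = [0.0] * len(head_w), 0.0
        for x, y in samples:
            h = features(x)  # frozen layer: used for features, never differentiated
            err = predict(head_w, head_b, x) - y
            grad_w = [g + err * hj / n for g, hj in zip(grad_w, h)]
            grad_b += err / n
        head_w = [w - lr * g for w, g in zip(head_w, grad_w)]
        head_b -= lr * grad_b
    return head_w, head_b


# Hypothetical tiny labelled set: 1 = mammal, 0 = bird
SAMPLES = [([2.0, 0.0, 0.0], 1), ([1.5, 0.0, 0.2], 1),
           ([0.0, 2.0, 0.0], 0), ([0.0, 1.5, -0.2], 0)]

head_w, head_b = train_head(SAMPLES)
print(predict(head_w, head_b, [1.8, 0.1, 0.0]) > 0.5)  # True
```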

This graphic illustrates how the important lines and shapes learned by the old classifier can be reused. The new mammal-vs-bird classifier can probably rely on coarse cues, such as overall body shape or fur versus feathers, to decide that something is a mammal. The old classifier, on the other hand, needed more specific facial details to tell cats from dogs, which the new one doesn't need.

In general, no method lets you simply use less data for free. As these first examples show, we first learn from somewhat related data so that we can use less data to learn our actual task.

iMAML

One of the more involved methods is called MAML: Model-Agnostic Meta-Learning. With fine-tuning and transfer learning, you need a general model trained on a large amount of data before you can train on another task. In many cases, however, gathering a large amount of data for that general model is infeasible, as in medicine.

So, instead of training on one large dataset, we use data from a variety of related tasks. The idea is to learn a set of general initial model parameters that can be fine-tuned quickly on any of these tasks.

To show this more formally, let's introduce some symbols to keep track of everything. Say you have a set of tasks T = {t1, t2, …, tn}. For each task ti, you want to find parameters Φi that minimize that task's loss, Loss_i(Φi). However, each task has only a little data associated with it. Instead of learning each Φi from scratch, we train a so-called "meta-learner" that learns the best starting parameters θ, so that a short training run starting from θ makes each Loss_i(Φi) as small as possible.
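Written out, the MAML objective looks like this. For readability, it assumes a single inner SGD step with learning rate α and a split of each task's data into a training part and a held-out validation part. The α symbol and the train/val superscripts are notation introduced here; in practice, the inner loop can take several steps.

```latex
\Phi_i(\theta) = \theta - \alpha \, \nabla_\theta \mathrm{Loss}_i^{\mathrm{train}}(\theta),
\qquad
\theta^{*} = \arg\min_{\theta} \sum_{i=1}^{n} \mathrm{Loss}_i^{\mathrm{val}}\big(\Phi_i(\theta)\big)
```

The meta-learner never scores θ directly. It scores θ by how well the adapted parameters Φi(θ) do on each task's validation data.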

The above image summarizes the MAML procedure.

1. A few tasks (ideally the ones with the most data) are chosen for training.
2. For each training task, stochastic gradient descent (SGD) starts from the shared parameters θ. It runs for a fixed number of steps, or until a target loss is reached, and produces task-specific parameters Φi.
3. Each adapted model is evaluated on held-out data from its task.
4. These losses are combined and fed into the meta-learner's optimizer, which adjusts θ to lower the overall loss across all tasks.

Once training is finished, we can use θ as the initialization for a new task. That task should then need less data and time to reach a low loss.
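To make the procedure concrete, here is a toy sketch in pure Python (no framework) under heavy simplifying assumptions:
- Every task is a hypothetical 1D regression y = slope · x, fit by the scalar model ŷ = w · x.
- The inner loop is a single SGD step.
- The meta-learner uses plain gradient descent.

Because the model is tiny, the derivative through the SGD step can be written out by hand. This is exactly the "backpropagate through SGD" part that makes MAML expensive for real networks.

```python
def loss_and_derivs(w, xs, ys):
    """MSE of the scalar model y_hat = w * x, with its first and second derivative in w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / n
    hess = sum(2 * x * x for x in xs) / n
    return loss, grad, hess


def make_task(slope, x_train=(1.0, 2.0), x_val=(1.5,)):
    """Hypothetical toy task: fit y = slope * x on two training points, evaluate on one held-out point."""
    return ((list(x_train), [slope * x for x in x_train]),
            (list(x_val), [slope * x for x in x_val]))


def maml_meta_grad(theta, task, alpha):
    """Validation loss after ONE inner SGD step, and its exact derivative with respect to theta."""
    (x_tr, y_tr), (x_val, y_val) = task
    _, g_tr, h_tr = loss_and_derivs(theta, x_tr, y_tr)
    phi = theta - alpha * g_tr                 # inner SGD step starting from theta
    val_loss, g_val, _ = loss_and_derivs(phi, x_val, y_val)
    dphi_dtheta = 1.0 - alpha * h_tr           # chain rule through the SGD step
    return val_loss, g_val * dphi_dtheta


def maml(tasks, theta=0.0, alpha=0.1, beta=0.05, outer_steps=200):
    """Meta-train the shared initialization theta with plain gradient descent (step size beta)."""
    for _ in range(outer_steps):
        meta_grad = sum(maml_meta_grad(theta, t, alpha)[1] for t in tasks) / len(tasks)
        theta -= beta * meta_grad              # outer (meta) update of the initialization
    return theta


tasks = [make_task(a) for a in (1.0, 2.0, 3.0)]
theta = maml(tasks)
print(round(theta, 3))  # 2.0
```

With slopes 1, 2 and 3, the learned initialization ends up at the average slope of 2.0. That is the starting point from which one SGD step lands, on average, closest to every task.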

Of course, there are some pitfalls. Notice that the title of this section is iMAML, but so far we have only talked about MAML. That's because regular MAML has a big drawback: to decide which direction to tweak θ, it has to backpropagate through every step of the inner SGD run. This is slow and memory-hungry, so the number of SGD steps has to be capped to keep training feasible. This is where iMAML, or implicit model-agnostic meta-learning, comes in. I won't go through the math here, but it gives a way to estimate the direction needed to tweak θ without backpropagating through SGD at all. In fact, iMAML doesn't even require SGD for the inner loop: any optimizer that gets close enough to the inner solution will do!
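Here is the same kind of toy setting for iMAML, as a sketch under similar assumptions: a hypothetical scalar model ŷ = w · x and one hand-made task.

1. The inner problem adds the penalty λ/2 · (Φ − θ)² and is solved with plain gradient descent.
2. The meta-gradient is then computed from the final Φ alone as L_val'(Φ) / (1 + L_train''(Φ)/λ). This is the one-parameter form of iMAML's implicit gradient.

Nothing differentiates through the loop in inner_solve, so that loop could be swapped for any solver that gets close enough to the minimizer.

```python
def mse_derivs(w, xs, ys):
    """Loss, gradient and second derivative of mean((w*x - y)^2) for the scalar model y_hat = w*x."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / n
    hess = sum(2 * x * x for x in xs) / n
    return loss, grad, hess


def inner_solve(theta, xs, ys, lam, lr=0.05, steps=500):
    """Approximately minimize L_train(phi) + lam/2 * (phi - theta)^2 with plain gradient descent.
    lr must satisfy lr * (curvature + lam) < 2 for this to converge (true for the task below)."""
    phi = theta
    for _ in range(steps):
        _, g, _ = mse_derivs(phi, xs, ys)
        phi -= lr * (g + lam * (phi - theta))
    return phi


def imaml_meta_grad(theta, task, lam):
    """Implicit meta-gradient: only the final phi is needed, not the path the solver took."""
    (x_tr, y_tr), (x_val, y_val) = task
    phi = inner_solve(theta, x_tr, y_tr, lam)
    val_loss, g_val, _ = mse_derivs(phi, x_val, y_val)
    _, _, h_tr = mse_derivs(phi, x_tr, y_tr)
    return val_loss, g_val / (1.0 + h_tr / lam)


# Hypothetical task: y = 3x, with two training points and one validation point
TASK = (([1.0, 2.0], [3.0, 6.0]), ([1.5], [4.5]))
val_loss, meta_grad = imaml_meta_grad(0.7, TASK, lam=2.0)
print(round(meta_grad, 4))  # -0.8449
```

A finite-difference check, which reruns the inner solver at θ ± ε, agrees with this value. This illustrates that the implicit formula recovers the true derivative of the regularized problem without unrolling the solver.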

Don't worry if you didn't quite follow the last paragraph. All you need to know is that iMAML estimates the meta-gradient without storing or backpropagating through the inner optimization path. Its memory use therefore does not grow with the number of inner steps, so it can afford many more inner steps than MAML. iMAML is meant for cases where multiple related tasks all have insufficient data, and cancer studies are a great example. As you can imagine, it's difficult to find a lot of data about cancer patients: it's expensive, there are privacy concerns, you need a doctor to run tests, and so on. There are also many different cancer types, which would be the different tasks in iMAML. By training a meta-learner on a few cancer types that have sufficient data, you may be able to train a model that also works well on less common cancer types.

The Code

The actual iMAML algorithm is a little more complicated than described above. It relies on a few optimization techniques. The most notable is the conjugate gradient (CG) algorithm, which lets it compute the meta-gradient without ever building or inverting the Hessian matrix. Apart from that, the algorithm is fairly simple.
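Conjugate gradient only ever needs Hessian-vector products Hv, never H itself. To illustrate why that is cheap, here is a sketch that approximates Hv from just two gradient evaluations, using central finite differences. This is a swapped-in technique for illustration only; the hv_prod method later in this article computes Hv exactly with a second autograd pass. The function f and the names below are hypothetical.

```python
def grad_f(w):
    """Gradient of the example function f(w) = w0**2 * w1 + 3 * w1**2."""
    w0, w1 = w
    return [2 * w0 * w1, w0 ** 2 + 6 * w1]


def hvp_finite_diff(grad_fn, w, v, eps=1e-4):
    """Approximate H(w) @ v from two gradient evaluations; H itself is never built."""
    plus = grad_fn([wi + eps * vi for wi, vi in zip(w, v)])
    minus = grad_fn([wi - eps * vi for wi, vi in zip(w, v)])
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]


# At w = (1, 2) the Hessian is [[4, 2], [2, 6]], so H @ (1, -1) = (2, -4)
print(hvp_finite_diff(grad_f, [1.0, 2.0], [1.0, -1.0]))  # close to [2.0, -4.0]
```

For a network with millions of parameters, H would have millions squared entries. Each Hv, however, costs only about as much as a gradient computation.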

To start, let's define our task. We will train on the popular Omniglot dataset, a computer vision benchmark of 1,623 handwritten characters from 50 different alphabets. Our goal is to quickly train a new classifier that can distinguish between 5 different characters after seeing only one example of each. The lingo for this is 5-way, 1-shot. Let's define the convolutional neural network.
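A 5-way, 1-shot task can be sketched as follows. Pick 5 characters at random, take 1 "support" example of each to learn from, and hold out "query" examples to evaluate on. The support and query sets roughly correspond to the x_train/y_train and x_val/y_val entries of the task dictionaries used in the training code later in this article. The dataset dict and the helper name sample_episode are hypothetical; the real code uses the official iMAML dataset helpers instead.

```python
import random


def sample_episode(labels_to_examples, n_way=5, k_shot=1, n_query=1, rng=random):
    """Build one N-way K-shot task from a dict {class_id: [examples, ...]}.

    Picks n_way classes, then k_shot support and n_query query examples per class,
    and relabels the chosen classes as 0..n_way-1 (labels only mean something within a task).
    """
    classes = rng.sample(sorted(labels_to_examples), n_way)
    support, query = [], []
    for new_label, cls in enumerate(classes):
        picked = rng.sample(labels_to_examples[cls], k_shot + n_query)
        support += [(ex, new_label) for ex in picked[:k_shot]]
        query += [(ex, new_label) for ex in picked[k_shot:]]
    return support, query


# Hypothetical stand-in for Omniglot: 1623 characters with 20 drawings each (ids instead of images)
data = {c: [f"char{c}_img{i}" for i in range(20)] for c in range(1623)}
support, query = sample_episode(data, rng=random.Random(0))
print(len(support), len(query))  # 5 5
```

Relabelling to 0..4 is what lets a single 5-output network handle every task, even though each task draws a different set of characters.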

import numpy as np
import torch
import torch.nn as nn


class Flatten(nn.Module):
    # (batch, channels, height, width) -> (batch, channels * height * width)
    def forward(self, x):
        return x.view(x.size(0), -1)


def make_conv_network(in_channels, out_dim):
    model = nn.Sequential()

    num_filters = 64
    conv_stride = 2

    # Four blocks of 3x3 conv (stride 2) -> BatchNorm -> ReLU.
    # track_running_stats=False makes BatchNorm always use the current batch's statistics.
    model.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=num_filters, kernel_size=3, stride=conv_stride, padding=1))
    model.add_module('BN1', nn.BatchNorm2d(num_filters, track_running_stats=False))
    model.add_module('relu1', nn.ReLU())
    model.add_module('conv2', nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=conv_stride, padding=1))
    model.add_module('BN2', nn.BatchNorm2d(num_filters, track_running_stats=False))
    model.add_module('relu2', nn.ReLU())
    # pads one extra row (bottom) and column (right) of zeros
    model.add_module('pad2', nn.ZeroPad2d((0, 1, 0, 1)))
    model.add_module('conv3', nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=conv_stride, padding=1))
    model.add_module('BN3', nn.BatchNorm2d(num_filters, track_running_stats=False))
    model.add_module('relu3', nn.ReLU())
    model.add_module('conv4', nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=conv_stride, padding=1))
    model.add_module('BN4', nn.BatchNorm2d(num_filters, track_running_stats=False))
    model.add_module('relu4', nn.ReLU())
    model.add_module('flatten', Flatten())
    # For 28x28 inputs, the four stride-2 convolutions leave a 2x2 feature map
    model.add_module('fc1', nn.Linear(2*2*num_filters, out_dim))

    # Xavier-uniform weights and small positive biases
    for layer in [model.conv1, model.conv2, model.conv3, model.conv4, model.fc1]:
        nn.init.xavier_uniform_(layer.weight, gain=1.73)
        if layer.bias is not None:
            nn.init.uniform_(layer.bias, a=0.0, b=0.05)

    return model
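The 2*2*num_filters input size of fc1 follows from the convolution arithmetic. Here is a small sketch that traces the spatial size through the network. It assumes 28x28 single-channel inputs, a common Omniglot resolution in few-shot work; the article itself does not state the input size. It uses the output-size formula from the PyTorch Conv2d documentation with dilation 1, and conv_out and trace_shapes are illustrative helper names.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output height/width of a Conv2d layer: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(size=28):
    """Follow a square input of `size` pixels through conv1..conv4 of make_conv_network."""
    sizes = [size]
    size = conv_out(size)   # conv1
    sizes.append(size)
    size = conv_out(size)   # conv2
    sizes.append(size)
    size += 1               # pad2: ZeroPad2d((0, 1, 0, 1)) adds one row and one column
    sizes.append(size)
    size = conv_out(size)   # conv3
    sizes.append(size)
    size = conv_out(size)   # conv4
    sizes.append(size)
    return sizes


print(trace_shapes(28))  # [28, 14, 7, 8, 4, 2]
```

The final 2x2 map with 64 filters gives 2 * 2 * 64 = 256 features, which matches the in_features of fc1.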
        

Let's also quickly define the arguments for this network; you will see how they are used later.

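class Args:
    def __init__(self):
        self.N_way = 5                # classes per task
        self.K_shot = 1               # training examples per class
        self.inner_lr = 0.1           # learning rate of the inner (task) optimizer
        self.outer_lr = 0.001         # learning rate of the outer (meta) optimizer
        self.outer_train_steps = 500  # meta-training steps
        self.inner_train_steps = 16   # optimizer steps per task adaptation
        self.task_mb_size = 16        # tasks sampled per meta-training step
        self.num_tasks = 50           # size of the pool of training tasks
        self.lamb = 100               # regularization strength lambda
        self.cg_iters = 5             # conjugate gradient iterations per task
        self.save_dir = '.'
        self.use_gpu = True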

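Now we can start defining the iMAML algorithm itself. There are two networks at any given time: the "outer network" and the "inner network". The outer network is the meta-learner, which holds the initialization θ. The inner network is the actual learner, which is adapted to one task at a time.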

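class iMAML:
    def __init__(self, args, dataset, model, loss_function):
        self.args = args
        self.device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")

        # model is a factory (e.g. a lambda), so both networks share the same architecture
        self.outer_network = model().to(self.device)  # meta-learner: holds the initialization
        self.inner_network = model().to(self.device)  # learner: adapted to one task at a time

        self.outer_network.train()
        self.inner_network.train()

        # TODO: add other optimizers
        self.inner_optimizer = torch.optim.Adam(self.inner_network.parameters(), lr=self.args.inner_lr, betas=(0.0, 0.9))
        # Applies the meta-gradient to the meta-learner; Adam is one reasonable choice here
        self.outer_optimizer = torch.optim.Adam(self.outer_network.parameters(), lr=self.args.outer_lr)

        # Halves the outer learning rate every 10 calls to self.lr_scheduler.step()
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.outer_optimizer, step_size=10, gamma=.5)

        self.dataset = dataset
        self.loss_function = loss_function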

To train, we need two nested loops. The outer loop goes over the meta-learner's training steps, and the inner loop goes over the actual learner's training steps. At each outer step, we randomly select a mini-batch of tasks and adapt the inner learner to each of them in turn.

Each time we start adapting to a new task, we first copy the meta-learner's current parameters into the inner network.

def train(self):
    # lambda as a tensor on the same device as the networks
    device = next(self.outer_network.parameters()).device
    lam = torch.tensor(float(self.args.lamb), device=device)

    losses = []

    for outstep in range(self.args.outer_train_steps):
        # sample a mini-batch of task indices (with replacement)
        task_mb = np.random.choice(self.args.num_tasks, size=self.args.task_mb_size)
        # flatten the parameters of the outer network into a single vector w_k
        w_k = torch.cat([param.data.view(-1) for param in self.outer_network.parameters()], 0).clone()

        meta_grad = 0
        losses.append(0)

        for idx in task_mb:
            # Train the inner network, starting from a copy of the outer parameters
            offset = 0
            for param in self.inner_network.parameters():
                param.data.copy_(w_k[offset:offset + param.nelement()].view(param.size()))
                offset += param.nelement()

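Then we create a fresh optimizer for the inner network and train it on the task's training data. Finally, we take one extra step on the regularization term alone, which pulls the adapted weights back toward the outer parameters w_k.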

# fresh optimizer state for every task
self.inner_optimizer = torch.optim.Adam(self.inner_network.parameters(), lr=self.args.inner_lr, betas=(0.0, 0.9))
task = self.dataset[idx]
self.train_inner(task)
# one extra step on the regularization term only, pulling the weights back toward w_k
self.inner_optimizer.zero_grad()
reg_loss = self.regularization_loss(w_k, lam)
reg_loss.backward()
self.inner_optimizer.step()


def get_inner_loss(self, x, y):
    yhat = self.inner_network(x)
    return self.loss_function(yhat, y)

@torch.enable_grad()
def train_inner(self, task):
    # inner_train_steps optimizer steps on the task's training split
    x, y = task['x_train'], task['y_train']

    for _ in range(self.args.inner_train_steps):
        self.inner_optimizer.zero_grad()
        loss = self.get_inner_loss(x, y)
        loss.backward()
        self.inner_optimizer.step()

def regularization_loss(self, w_0, lam=0):
    # lam/2 * ||phi - w_0||^2, where phi are the inner network's current parameters
    offset = 0
    reg_loss = 0
    for param in self.inner_network.parameters():
        # offset is required since w_0 has been flattened
        delta = param.view(-1) - w_0[offset:offset + param.nelement()].view(-1)
        reg_loss += 0.5 * lam * torch.sum(torch.pow(delta, 2))
        offset += param.nelement()
    return reg_loss

The regularization loss, λ/2 · ||Φ − θ||², pulls the inner network's parameters back toward the outer network's parameters. This matters for iMAML because it ties the inner solution to θ in a well-defined way. That connection allows some neat math tricks (implicit differentiation) for estimating the gradient. You can see how that works in the paper, since the math is a little beyond the scope of this article.
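For reference, here is a sketch of that idea from the iMAML paper, written in this article's notation (train/val mark the task's training and validation splits).

1. The inner problem includes the regularization term.
2. Setting its gradient to zero and differentiating that condition with respect to θ gives the Jacobian of Φi.
3. The chain rule then gives the meta-gradient.

```latex
\begin{aligned}
\Phi_i(\theta) &= \arg\min_{\Phi}\; \mathrm{Loss}_i^{\mathrm{train}}(\Phi) + \frac{\lambda}{2}\lVert \Phi - \theta \rVert^2 \\
0 &= \nabla \mathrm{Loss}_i^{\mathrm{train}}(\Phi_i) + \lambda\,(\Phi_i - \theta)
  \;\Rightarrow\; \frac{d\Phi_i}{d\theta} = \Big(I + \tfrac{1}{\lambda}\nabla^2 \mathrm{Loss}_i^{\mathrm{train}}(\Phi_i)\Big)^{-1} \\
\nabla_\theta \mathrm{Loss}_i^{\mathrm{val}}\big(\Phi_i(\theta)\big) &= \Big(I + \tfrac{1}{\lambda}\nabla^2 \mathrm{Loss}_i^{\mathrm{train}}(\Phi_i)\Big)^{-1} \nabla \mathrm{Loss}_i^{\mathrm{val}}(\Phi_i)
\end{aligned}
```

In the code, λ is args.lamb and regularization_loss implements the λ/2 term. The cg_solve method approximately solves a linear system of this form instead of inverting the matrix.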

Now that we have a trained inner network, we want to use the loss of its validation set in order to train the outer network.

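# Compute the implicit meta-gradient via the CG solver.
# Right-hand side: gradient of the validation loss at the adapted weights
valid_loss = self.get_inner_loss(task['x_val'], task['y_val'])
valid_grad = torch.autograd.grad(valid_loss, self.inner_network.parameters())
flat_valid_grad = torch.cat([g.contiguous().view(-1) for g in valid_grad])
# The Hessian in the iMAML system is that of the *training* loss, so keep the graph
# of the training gradient (create_graph=True) for Hessian-vector products
train_loss = self.get_inner_loss(task['x_train'], task['y_train'])
train_grad = torch.autograd.grad(train_loss, self.inner_network.parameters(), create_graph=True)
flat_train_grad = torch.cat([g.contiguous().view(-1) for g in train_grad])
outer_grad = self.cg_solve(flat_valid_grad, flat_train_grad).detach()

# take the average of the meta gradients
meta_grad += outer_grad / self.args.task_mb_size
losses[outstep] += valid_loss.item() / self.args.task_mb_size

Here is where the conjugate gradient algorithm comes in. The iMAML math produces a neat expression for the outer network's gradient. However, evaluating it directly means inverting a matrix built from the Hessian, which is far too expensive for a network with many parameters. CG avoids this by approximately solving the corresponding linear system using only Hessian-vector products, each of which costs about as much as an extra backward pass. A handful of iterations (cg_iters) is enough, which keeps iMAML quick and efficient.

@torch.enable_grad()
def hv_prod(self, train_grad, v):
    # Hessian-vector product of the training loss: differentiate (train_grad . v)
    # instead of forming the Hessian itself
    hv = torch.autograd.grad(torch.sum(train_grad * v), self.inner_network.parameters(), retain_graph=True)
    hv = torch.cat([g.contiguous().view(-1) for g in hv])
    # The iMAML matrix is (I + H / lambda); the extra v (2 * v in total) adds damping,
    # similar to the regularization coefficient in the official iMAML code
    return hv / self.args.lamb + v * 2

@torch.no_grad()
def cg_solve(self, b, train_grad, residual_tol=1e-10):
    # Approximately solve (2I + H / lambda) x = b using only Hessian-vector products
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the starting point x = 0
    p = r.clone()
    rdotr = r.dot(r)

    # Code adapted from the official iMAML paper code
    for _ in range(self.args.cg_iters):
        Ap = self.hv_prod(train_grad, p)
        alpha = rdotr / p.dot(Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        newrdotr = r.dot(r)
        if newrdotr < residual_tol:
            break  # early termination: the residual is already small
        beta = newrdotr / rdotr
        p = r + beta * p
        rdotr = newrdotr
    return x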
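To see the conjugate gradient loop in isolation, here is a framework-free sketch with the same structure as cg_solve. The solver touches the matrix only through a matvec(v) callback, which plays the role of hv_prod. To keep the example checkable, the "Hessian" H is a small explicit matrix; H, lam and the helper names are hypothetical. The system matrix follows the plain I + H/λ form, without the extra damping used in the article's hv_prod.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def conjugate_gradient(matvec, b, iters=25, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, touching A only through matvec(v) = A @ v."""
    x = [0.0] * len(b)
    r = list(b)                      # residual b - A x with x = 0
    p = list(r)
    rdotr = dot(r, r)
    for _ in range(iters):
        if rdotr < tol:
            break
        Ap = matvec(p)
        alpha = rdotr / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        new_rdotr = dot(r, r)
        p = [ri + (new_rdotr / rdotr) * pi for ri, pi in zip(r, p)]
        rdotr = new_rdotr
    return x


H = [[4.0, 1.0], [1.0, 3.0]]         # hypothetical stand-in for a Hessian that is never materialised
lam = 2.0


def imaml_matvec(v):
    """(I + H / lam) @ v, the system matrix from the implicit gradient."""
    Hv = [dot(row, v) for row in H]
    return [vi + hvi / lam for vi, hvi in zip(v, Hv)]


x = conjugate_gradient(imaml_matvec, [1.0, 2.0])
print([round(xi, 4) for xi in x])  # [0.2069, 0.7586]
```

For an n-by-n system, CG converges in at most n iterations in exact arithmetic. Here that is 2 iterations. With millions of parameters, iMAML instead stops after a few iterations (cg_iters) and accepts an approximate solution.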

Finally, we update the outer network with our meta-gradient. Note that this part is now OUTSIDE the task loop. At this point, the official iMAML code runs a dummy backward pass whose job is to make sure every outer parameter has a .grad tensor to copy the meta-gradient into. Allocating that tensor directly when p.grad is None does the same job, so no dummy loss is needed here. Once the gradients are in place, the outer optimizer takes one step.

# Outside the task loop: install the averaged meta-gradient and take an outer step
device = next(self.outer_network.parameters()).device
meta_grad = meta_grad.to(device)
offset = 0
for p in self.outer_network.parameters():
    # meta_grad has been flattened, so slice out this parameter's part
    this_grad = meta_grad[offset:offset + p.nelement()].view(p.size())
    if p.grad is None:
        p.grad = torch.zeros_like(p)
    # set the outer network's gradient to the meta-gradient
    p.grad.copy_(this_grad)
    offset += p.nelement()

self.outer_optimizer.step()

print(f"Step {outstep}")
print("Average Loss", losses[outstep])

And that's pretty much it! For the dataset, I used the same dataset helper classes (OmniglotTask and OmniglotFewShotDataset) as the official iMAML code. We can test the training like so:

from tqdm import tqdm

args = Args()

# this part is from the iMAML code (OmniglotTask and OmniglotFewShotDataset come from its dataset helpers)
task_defs = [OmniglotTask(list(range(1623)), num_cls=args.N_way, num_inst=args.K_shot) for _ in tqdm(range(args.num_tasks))]
dataset = OmniglotFewShotDataset(task_defs=task_defs, GPU=args.use_gpu)

imaml = iMAML(args, dataset, lambda: make_conv_network(in_channels=1, out_dim=args.N_way), nn.CrossEntropyLoss())

imaml.train()

I was able to get it to learn pretty well, given how few classes and training steps I gave it. The paper used much more data than I did, but unfortunately I don't currently have a powerful enough machine to try that out.

I also want to point out that there are many other algorithms for similar tasks, such as Few-Shot Learning with Global Class Representations and Decoder Choice Network for Meta-Learning. I found iMAML to be the simplest yet most versatile of them. Still, if one of these algorithms fits your task, it may be worth a try, since they can achieve better scores.

Conclusion

While "learning to learn" is hard to do, it is becoming increasingly important because many crucial tasks don't have enough data to train a good model. As it stands, most state-of-the-art image and text models are only possible because of how much text and how many images are available on the internet. Most tasks, however, especially niche ones, won't have nearly enough data. These methods, and hopefully new ones in the future, are first steps toward remedying this. In fact, they may even be among the first steps toward Artificial General Intelligence itself. As humans, we can rapidly learn new topics based on everything we have experienced before. If I saw a new animal today, I could easily distinguish it from every other animal I have seen so far. That is what today's AI is missing: the ability to learn something from minimal data with all of its past experience behind it. And that is only possible if it can first learn how to learn.


More articles by Mannan Bhardwaj

  • TinyR1: Recreating DeepSeek R1 at Home!
  • Man VS Machine—A Battle Of Intelligence
  • Agentic AI
  • ChatGPT is obsolete
  • New Open-source LLM: Google Gemma
  • Mixture Of Experts: The Future of LLMs
  • Virtual Cloning
  • A Journey Through Neural Compression
  • Prompting LLMs with LLMs
  • Unlocking the true power of LLMs with Vector Embeddings
