Don't use PyTorch to Learn Deep Learning

It is a click-baity title if you read it as "don’t use PyTorch for deep learning". But in this post, I want to convince you not to use PyTorch in an educational setting, that is, to learn the concepts of deep learning. I’ll show you why it will hinder understanding and leave you feeling even more confused.

Simple PyTorch Training Loop

The code below shows part of a canonical PyTorch training loop. The full code can be found in basic deep-learning-with-PyTorch tutorials like this one.

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: compute gradients, take one optimizer step, then reset the gradients
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Ask yourself some questions after looking at this code (a rough sketch of what the loop hides follows the list):

  • Exactly how is the neural network getting trained?
  • Where’s gradient descent? Exactly what does it compute? What is the shape of its input and output?
  • Why do you have to call `optimizer.zero_grad()`? Does the placement of that line in this code matter? Why or why not?
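
If you cannot answer these from the code alone, that is precisely the point. For reference, here is a rough sketch of what loss.backward() and optimizer.step() hide for plain SGD; it reuses model from the loop above, learning_rate is an assumed hyperparameter, and this is illustrative rather than PyTorch's actual implementation.

# loss.backward() populates p.grad for every parameter p of the model.
# optimizer.step() then performs (roughly) the update below for plain SGD.
with torch.no_grad():
    for p in model.parameters():
        p -= learning_rate * p.grad   # one gradient-descent step
# optimizer.zero_grad() resets each p.grad so gradients from the next
# batch do not accumulate on top of the current ones.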

The core ideas of Deep Learning

Deep learning using neural networks is built on some fundamental ideas. Don’t worry if you don’t understand these yet. The whole point is to wrap your head around these principles.

  • What are neural networks? (a lego-like stacking of layers: parameterised linear functions and non-linearities like ReLU)
  • How do neural networks learn? (by finding optimal values of the parameters that make up the lego blocks)
  • How do we find the optimal values of the parameters? (by using gradient based methods)
  • What is the data type of the gradient? (the gradient is an operator, or higher-order function: it eats one function and spits out another function; see the short sketch after this list)
  • Gradient of what? How? (the gradient of a cost function that represents how well our model fits the data, evaluated at the current parameters of the lego blocks)
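
To make the last two points concrete, here is a minimal sketch using jax.grad (JAX is introduced below); the toy function f and the evaluation point are made up for illustration.

import jax

def f(w):
    return w ** 2.0       # a toy "cost" function of a single parameter

df = jax.grad(f)          # grad eats a function and spits out another function
print(df(3.0))            # the returned function is evaluated at a point: 2 * 3.0 = 6.0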

The problem with the PyTorch API in an educational setting

PyTorch implements an object-oriented API for deep learning, hiding a lot of state inside objects (tensors). The gradient of a parameter tensor is state: it is populated when you call backward on the loss and consumed by the optimizer. Learning this API is akin to learning to press buttons on a black-box machine without understanding what it does. That’s a recipe for disaster. The only way to get better at understanding is repeated exposure to the core ideas. PyTorch, unfortunately, hides these very ideas from you by giving you a simpler interface instead.

  • The gradients are calculated when loss.backward is called. It is somewhat opaque how calling .backward on the loss connects to the parameters of the model.
  • Zeroing out the currently accumulated gradients is needed because this is a stateful calculation. It is easy to miss this step or misplace it (see the short sketch after this list)!
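
Here is a minimal sketch of that hidden state in action (the tensor value is made up); note how .grad accumulates across backward() calls until it is explicitly zeroed:

import torch

w = torch.tensor(2.0, requires_grad=True)

(w * w).backward()
print(w.grad)        # tensor(4.)  -- d(w^2)/dw evaluated at w = 2

(w * w).backward()
print(w.grad)        # tensor(8.)  -- the new gradient was added on top of the old one

w.grad.zero_()       # roughly what optimizer.zero_grad() does for each parameter
print(w.grad)        # tensor(0.)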

Functional programming enthusiasts will quickly recognise these as general challenges with object-oriented programming, but I won’t digress along that line.

What can you use instead?

PyTorch has been around for some time. Two other frameworks had the benefit of hindsight and made things explicit: JAX from Google and MLX from Apple. Since JAX works on a wider variety of devices, let us look at a canonical JAX training loop. Again, I'm reproducing the core part of the code below, without the imports, setup, etc.

def predict(params, x):
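    # Forward pass: params is a list of (w, b) pairs, one per layer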
    activations = x
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

def cross_entropy_loss(params, x, y):
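    # Scalar loss: cross-entropy between one-hot labels y and the model's predictions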
    logits = predict(params, x)
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=-1))

# Create a function that computes the gradient of the loss
# with respect to the parameters
grad_fn = grad(cross_entropy_loss)

def update(params, x, y, learning_rate=0.01):
    # Evaluate the gradient of the loss at the current parameter values
    grads = grad_fn(params, x, y)
    
    # Update the parameters using the computed gradients
    return [(w - learning_rate * dw, b - learning_rate * db) for (w, b), (dw, db) in zip(params, grads)]        

The code looks different from the PyTorch version and immediately highlights a few things:

  1. We’re computing the gradient of the loss function with respect to the params of the network.
  2. The loss function must have a scalar output (hence the `jnp.mean`).
  3. Taking the gradient involves first feeding the loss function to `grad` to get back another function (`grad_fn`). You then evaluate that function at the current values of the parameters (`grad_fn(params, x, y)`).
  4. Using the model params and their gradients, the update step moves the params to better values.

You have to set this up every time you write a small network, instead of relying on the loss.backward and optimizer.step magic. It's like building everything from first principles each time: tedious when you do it many times in the long run, but invaluable as a learner.
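
For completeness, here is a minimal sketch of how you might drive the update function above; the init_params helper, the layer sizes, and the dummy data are assumptions made up for illustration, and relu stands in for the definition omitted from the setup.

import jax
import jax.numpy as jnp

relu = jax.nn.relu   # stand-in for the relu defined in the omitted setup

def init_params(layer_sizes, key):
    # One (w, b) pair per layer, matching the structure predict() expects
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, w_key = jax.random.split(key)
        params.append((0.01 * jax.random.normal(w_key, (n_out, n_in)), jnp.zeros(n_out)))
    return params

params = init_params([784, 128, 10], jax.random.PRNGKey(0))
x = jnp.ones(784)              # a dummy flattened "image"
y = jax.nn.one_hot(3, 10)      # a dummy one-hot label

for step in range(100):
    # update is a pure function: old params in, new params out, no hidden state
    params = update(params, x, y)

print(cross_entropy_loss(params, x, y))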

Summary

You, like most newcomers, will struggle with the maths of #deeplearning anyway, because school teaches #calculus in one dimension for a long time. The idea of directions and gradients is introduced only by the time most people have already given up on “understanding” maths.

Don’t make it harder for yourself by using a production-ready but magical black box like #PyTorch on top. It will double the confusion in your head. Instead, use the simpler #JAX or #MLX early in your journey. They will expose you to the core principles repeatedly and help you internalise them!

