Don't use PyTorch to Learn Deep Learning
It is a click-baity title if you read it as "don’t use PyTorch for deep learning". That is not my point. In this post, I want to convince you not to use PyTorch in an educational setting, when you are learning the concepts of deep learning. I’ll show you why it can hinder understanding and leave you even more confused.
Simple PyTorch Training Loop
The code below is part of the canonical training loop found in introductory PyTorch tutorials, where you can also find the full code.
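import torch

# `device` comes from setup code omitted in the excerpt; this is one common choice
device = "cuda" if torch.cuda.is_available() else "cpu"

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # unused in this excerpt; the full tutorial uses it to log progress
    model.train()  # switch layers such as dropout and batch norm to training mode
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: the gradient is never passed around explicitly
        loss.backward()        # writes gradients into each parameter's .grad
        optimizer.step()       # reads .grad and updates the parameters in place
        optimizer.zero_grad()  # clears .grad before the next batch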
Ask yourself a few questions after looking at this code:
The core ideas of Deep Learning
Deep learning using neural networks is built on some fundamental ideas. Don’t worry if you don’t understand these yet. The whole point is to wrap your head around these principles.
The problem with the PyTorch API in an educational setting
PyTorch implements an object-oriented API for deep learning that hides a lot of state inside objects (tensors). The gradient of a parameter tensor is one such piece of state. It lives in the tensor's .grad attribute: loss.backward() writes it, optimizer.step() reads it, and optimizer.zero_grad() clears it. Yet none of these calls takes or returns the gradient explicitly. Learning this API is akin to learning to press buttons on a black-box machine without understanding what it does. That’s a recipe for disaster. The only way to get better at understanding is repeated exposure to the core ideas. PyTorch, unfortunately, hides these very ideas from you behind a simpler interface.
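To make that hidden state concrete, here is a toy imitation of the pattern in plain Python. Param, ToyLoss and ToySGD are hypothetical classes written for illustration; they are not PyTorch's implementation. Like PyTorch, the toy backward() adds into .grad rather than overwriting it. The gradient therefore reaches the optimizer only through a shared mutable attribute, and forgetting zero_grad() silently changes the result.

```python
# Toy, pure-Python imitation of the stateful pattern. Param, ToyLoss and
# ToySGD are hypothetical classes, NOT PyTorch's implementation; they only
# mimic who writes the gradient and who reads it.


class Param:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0  # hidden state shared by the loss and the optimizer


class ToyLoss:
    """Squared error (w * x - y) ** 2 for a single scalar parameter w."""

    def __init__(self, w, x, y):
        self.w, self.x, self.y = w, x, y
        self.value = (w.value * x - y) ** 2

    def backward(self):
        # d/dw (w*x - y)^2 = 2 * (w*x - y) * x, ADDED to w.grad (accumulates)
        self.w.grad += 2 * (self.w.value * self.x - self.y) * self.x


class ToySGD:
    def __init__(self, params, lr):
        self.params, self.lr = params, lr

    def step(self):
        # Reads .grad, which another object wrote; nothing is passed in
        for p in self.params:
            p.value -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0


def train(steps, x, y, lr, zero_grad=True):
    w = Param(0.0)
    opt = ToySGD([w], lr)
    for _ in range(steps):
        loss = ToyLoss(w, x, y)
        loss.backward()
        opt.step()
        if zero_grad:
            opt.zero_grad()
    return w.value


print(train(50, x=1.0, y=2.0, lr=0.1))                   # close to 2.0
print(train(50, x=1.0, y=2.0, lr=0.1, zero_grad=False))  # stale gradients pile up
```

Nothing in train() passes the gradient from backward() to step(); you simply have to know that both objects touch w.grad. With zero_grad=False, every old gradient keeps acting on later steps, and the parameter oscillates around 2.0 instead of settling on it.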
Functional programming enthusiasts will quickly recognize these as general challenges with object-oriented programming, but I won’t digress down that path.
What can you use instead?
PyTorch has been around for some time. Two newer frameworks had the benefit of hindsight and made things explicit: JAX from Google and MLX from Apple. Since JAX runs on a wider variety of hardware (MLX targets Apple silicon), let us look at a canonical JAX training loop. Again, I'm reproducing only the core part of the code below, leaving out data loading and other setup.
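import jax
import jax.numpy as jnp
from jax import grad
from jax.nn import relu

def predict(params, x):
    # Forward pass for a single example x of shape (n_in,)
    activations = x
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

# vmap maps the per-example predict over axis 0 of a batch x
batched_predict = jax.vmap(predict, in_axes=(None, 0))

def cross_entropy_loss(params, x, y):
    # x: (batch, n_in); y: one-hot labels of shape (batch, n_classes)
    logits = batched_predict(params, x)
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=-1))

# Create a function that computes the gradient of the loss with
# respect to the parameters (its first argument)
grad_fn = grad(cross_entropy_loss)

def update(params, x, y, learning_rate=0.01):
    grads = grad_fn(params, x, y)
    # Update the parameters using the computed gradients; returns new params
    return [(w - learning_rate * dw, b - learning_rate * db)
            for (w, b), (dw, db) in zip(params, grads)]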
The code looks different from PyTorch and immediately highlights a few things:
You have to set this up every time you write a small network, instead of relying on the loss.backward() and optimizer.step() magic. It's like building everything from first principles every time: tedious in the long run, but invaluable when you are learning.
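For comparison with both loops, here is the same explicit style sketched in plain Python for a one-feature linear model, so that it runs without any framework. The model, the data (generated from y = 2x + 1) and the learning rate are assumptions for illustration. grad_loss is derived by hand, which is the step jax.grad automates in the JAX code above. numerical_grad is a finite-difference check of that derivation.

```python
# A minimal, dependency-free sketch of the explicit (functional) style.
# The model, data and hyperparameters are illustrative assumptions.


def predict(params, x):
    w, b = params
    return w * x + b


def loss(params, xs, ys):
    # Mean squared error over the dataset
    return sum((predict(params, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad_loss(params, xs, ys):
    # Hand-derived gradient of the MSE w.r.t. (w, b); jax.grad would derive it for us
    w, b = params
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db


def numerical_grad(f, params, eps=1e-6):
    # Central finite differences, one parameter at a time
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((f(tuple(up)) - f(tuple(down))) / (2 * eps))
    return tuple(grads)


def update(params, xs, ys, learning_rate=0.1):
    # Returns NEW parameters; nothing is mutated and there is no hidden .grad
    dw, db = grad_loss(params, xs, ys)
    w, b = params
    return (w - learning_rate * dw, b - learning_rate * db)


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1
params = (0.0, 0.0)
for _ in range(500):
    params = update(params, xs, ys)
print(params)  # approaches (2.0, 1.0)
```

Every value the update depends on is an argument, and every result is a return value. There is no .grad attribute and no optimizer object to keep in sync.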
Summary
Like most newcomers, you will probably struggle with the maths of #deeplearning anyway, because school spent years teaching us #calculus in one dimension. The ideas of direction and gradient arrive only after most people have given up on “understanding” maths.
Don’t make it harder for yourself by adding a production-ready but magical black box like #PyTorch on top. It will only double the confusion in your head. Instead, use the simpler, more explicit #JAX or #MLX early in your journey. They will expose you to the core principles repeatedly and help you internalise them!