From Overfit to Optimal: Refine Smart Models with Regularization.
Dushan Jalath
3rd Year AI Undergraduate and a Problem Solver | Passionate about shaping the future of technology
Overfitting is a common problem in machine learning models; in my experience, many university students’ ML projects suffer from it. In this article, we’ll take a closer look at what overfitting is and how to prevent it.
What is Overfitting?
Overfitting occurs when a machine learning model learns not only the underlying patterns in the training data but also its noise and random fluctuations. The model then performs very well on the training data but poorly on new, unseen data, because it has become too specialized in the specific details of the training set. In other words, the model is too complex for the problem, and that excess complexity reduces its ability to generalize to new data.
The primary sign of an overfitted model is high accuracy on training data but poor performance on test data. Complex models with too many parameters are especially prone to overfitting.
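To make that symptom concrete, here is a minimal sketch (plain NumPy; the synthetic sine data and the polynomial degrees are my own choices, not from the original article) that fits a low-degree and a high-degree polynomial to a few noisy points. The high-degree fit typically shows a much lower error on the training points than on held-out test points, which is exactly the overfitting signature described above.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_data(n):
    # Noisy samples from a simple underlying signal
    x = rng.uniform(-1.0, 1.0, n)
    y = np.sin(2 * np.pi * x) + rng.normal(0.0, 0.2, n)
    return x, y

x_train, y_train = make_data(10)   # small training set
x_test, y_test = make_data(200)    # unseen data

for degree in (3, 9):
    coeffs = np.polyfit(x_train, y_train, degree)  # fit polynomial of given degree
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    test_mse = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"degree {degree}: train MSE {train_mse:.3f}, test MSE {test_mse:.3f}")
```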
How to Prevent Overfitting
There are several techniques to prevent overfitting. Let’s explore one of the most common methods: regularization.
Regularization
If you suspect your ML model is overfitting, one of the first techniques you should try is regularization.
In regularization, we add a term (λ/2m) * ||W||^2 to the loss function.
Here, ||W||^2 = Σ(W_i^2), where i runs from 1 to n and n is the number of weights in the model (m in the λ/2m factor is the number of training examples).
Equivalently, ||W||^2 = W · W^T, the dot product of the weight vector with itself.
The parameter λ is the regularization parameter, a hyperparameter that you can tune using a validation (or development) set. This technique is known as L2 regularization.
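As a quick illustration of the formula above, here is a minimal sketch (plain NumPy; the mean-squared-error base loss and all variable names are my own assumptions) of adding the (λ/2m) * ||W||^2 penalty to an already-computed loss:

```python
import numpy as np

def l2_regularized_loss(base_loss, W, lam, m):
    """Add the L2 penalty (lam / (2 * m)) * ||W||^2 to a base loss.

    base_loss : unregularized loss (e.g. mean squared error) on the batch
    W         : weight vector (bias terms are usually left out of the penalty)
    lam       : the regularization parameter lambda
    m         : number of training examples
    """
    penalty = (lam / (2 * m)) * np.sum(W ** 2)  # ||W||^2 = sum of squared weights
    return base_loss + penalty

# Toy usage with an assumed base loss value
W = np.array([0.5, -1.2, 2.0])
print(l2_regularized_loss(base_loss=0.8, W=W, lam=0.1, m=100))
```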
How L2 Regularization Affects Backpropagation
During backpropagation, the model computes the gradient of the loss function with respect to each weight W_i, then updates the weights in the direction that minimizes the loss. The weight update rule is:
W_i = W_i - learning_rate * dW_i
Here, dW_i is the gradient of the loss function with respect to W_i.
With L2 regularization, we also add the gradient of the penalty term (λ/2m) * ||W||^2 to the original gradient. The gradient of the L2 term with respect to W_i is (λ/m) * W_i. Therefore, the weight update becomes:
W_new = (1 - α * λ / m) * W_old - α * dW_i
Here, α is the learning rate.
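That update, sometimes called weight decay, takes only a couple of lines to implement. Here is a minimal sketch (plain NumPy; the function and variable names are my own), assuming dW is the gradient of the unregularized loss:

```python
import numpy as np

def l2_update(W, dW, alpha, lam, m):
    """One gradient step with L2 regularization.

    Implements W_new = (1 - alpha * lam / m) * W_old - alpha * dW.
    """
    return (1 - alpha * lam / m) * W - alpha * dW

# Toy example: even with a zero gradient, the weights shrink a little each step
W = np.array([1.0, -2.0, 0.5])
dW = np.zeros_like(W)
print(l2_update(W, dW, alpha=0.1, lam=0.5, m=10))  # weights scaled by 0.995
```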
Why L2 Regularization Works
L2 regularization assumes that a model with smaller weights is simpler than one with larger weights. By penalizing the squared values of the weights in the cost function, L2 regularization encourages smaller weights. This helps to distribute the weights more evenly, reducing the model's complexity and preventing overfitting. As a result, the model becomes smoother, with outputs that change more gradually as the inputs change.
Key Points to Remember When Implementing L2 Regularization
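In practice you rarely hand-code the penalty: most frameworks expose L2 regularization directly. As a rough sketch (the layer sizes and the λ value of 0.01 are arbitrary placeholders, not recommendations), here is how it is typically enabled in Keras. Note that Keras adds λ * Σ(W_i^2) per layer without the 1/2m scaling, so λ must be tuned accordingly; PyTorch users get a similar effect through the weight_decay argument of optimizers such as torch.optim.SGD.

```python
import tensorflow as tf

# L2 penalty on the hidden layer's weights: adds 0.01 * sum(W**2) to the loss
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(
        64,
        activation="relu",
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
    ),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
```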
Conclusion
Regularization is one of the main methods for reducing overfitting in machine learning models. In this article, we focused on L2 regularization. In the next article, we’ll dive into more techniques for preventing overfitting.