What is Regularization: Towards Deep Learning
What is Regularization?
Regularization is a technique used in machine learning to prevent models from overfitting. It helps the model focus on the essential patterns in the data by discouraging overly complex models that memorize noise or irrelevant details.
In simple terms: regularization adds a small penalty for complexity, so the model is rewarded for keeping things simple and generalizing to new data instead of memorizing the training set.
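A common way to achieve this (and the approach used by the linear models later in this article) is to add a penalty term to the loss the model minimizes. Written informally, with λ standing for the regularization strength (the alpha parameter in scikit-learn):
Loss = Error on training data + λ × (size of the model's weights)
The larger λ is, the more strongly the model is pushed toward small, simple weights.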
Why Regularization?
Imagine you're trying to guess someone's favorite ice cream flavor based on clues like their age, the day's weather, and even the color of their shirt.
Without Regularization: the model treats every clue as equally important, including coincidences like shirt color, so its guesses fall apart on new people.
With Regularization: the model is nudged toward the clues that consistently matter (age, weather) and learns to ignore the noise, so its guesses hold up on people it has never seen.
Real-Life Examples
Example 1: Packing for a Trip
No Regularization (Overpacking): You pack EVERYTHING—10 outfits, 5 pairs of shoes, a coffee maker, and books for every possible scenario.
Problem: Your bag becomes too heavy, and you don’t even use most of the stuff.
With Regularization (Smart Packing): You prioritize essentials like clothes, toiletries, and a charger.
Outcome: Your bag is lighter, and you have everything you need, without unnecessary clutter.
Example 2: Study Notes
No Regularization (Copying Everything): You write down every word the teacher says, including side stories and jokes.
Problem: Your notes are enormous, and before the exam you can't tell what actually matters.
With Regularization (Smart Summarizing): You keep only the key definitions, formulas, and worked examples.
Outcome: Your notes are short, and they still cover what the exam is likely to ask.
Types of Regularization
L1 Regularization (Lasso): adds a penalty proportional to the absolute values of the weights. It can shrink some weights all the way to zero, effectively removing unimportant features (built-in feature selection).
L2 Regularization (Ridge): adds a penalty proportional to the squared values of the weights. It shrinks all weights toward zero but rarely makes them exactly zero, so every feature keeps a smaller influence.
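Written out, both penalties attach to the usual loss. As a sketch using mean squared error (MSE) as the data-fit term, with λ corresponding to alpha in the scikit-learn code below:
Lasso (L1): Loss = MSE + λ × Σ |w_i|
Ridge (L2): Loss = MSE + λ × Σ w_i²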
Practical Machine Learning Example
Without Regularization (Overfitting Example)
The model learns unnecessary details, like the "color of shirts," instead of focusing on relevant patterns.
With Regularization
The model learns only the essential patterns, like "age" and "weather."
Python Code Example
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
# Generate some data
np.random.seed(42)
X = np.random.rand(100, 5) # 5 features
y = 3 * X[:, 0] - 2 * X[:, 1] + np.random.randn(100) * 0.1 # Only first two features matter
# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# No Regularization
model = LinearRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("No Regularization MSE:", mean_squared_error(y_test, y_pred))
# L2 Regularization (Ridge)
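# alpha is the regularization strength (the λ in the formulas above); larger alpha = stronger shrinkage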
ridge_model = Ridge(alpha=1.0)
ridge_model.fit(X_train, y_train)
ridge_pred = ridge_model.predict(X_test)
print("Ridge Regularization MSE:", mean_squared_error(y_test, ridge_pred))
# L1 Regularization (Lasso)
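# Lasso's absolute-value penalty can shrink some coefficients all the way to zero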
lasso_model = Lasso(alpha=0.1)
lasso_model.fit(X_train, y_train)
lasso_pred = lasso_model.predict(X_test)
print("Lasso Regularization MSE:", mean_squared_error(y_test, lasso_pred))
Visualizing Regularization
When plotted, the difference is easy to see: the unregularized model gives the three irrelevant features small but nonzero weights fit to noise, Ridge shrinks all of the weights toward zero, and Lasso typically sets the irrelevant weights to exactly zero.
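A minimal plotting sketch (assuming matplotlib is installed, and reusing the fitted models from the code above) could compare the three sets of coefficients as grouped bars:
import matplotlib.pyplot as plt
# Grouped bar chart: one group of bars per feature, one bar per model
coef_sets = {
    "No Regularization": model.coef_,
    "Ridge (L2)": ridge_model.coef_,
    "Lasso (L1)": lasso_model.coef_,
}
x = np.arange(X.shape[1])
width = 0.25
for i, (name, coefs) in enumerate(coef_sets.items()):
    plt.bar(x + i * width, coefs, width, label=name)
plt.xlabel("Feature index")
plt.ylabel("Learned coefficient")
plt.title("Effect of L1 and L2 regularization on coefficients")
plt.legend()
plt.show()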
Summary Table
No Regularization: no penalty; fits the training data as closely as possible, at the risk of overfitting.
L1 (Lasso): penalty on the sum of absolute weights; can set some weights exactly to zero; useful when only a few features really matter.
L2 (Ridge): penalty on the sum of squared weights; shrinks all weights but keeps them nonzero; useful when many features each matter a little.
Key Takeaway
Regularization is like teaching someone to focus on the big picture and not get lost in the tiny, irrelevant details. In machine learning, it ensures that your model learns what's truly important, making it accurate, reliable, and able to handle new data well.