Types of gradient descent
Hi guys, I’d like to give an intro to the different types of gradient descent in common use.
Gradient descent (GD) is an iterative algorithm for finding a minimum of a cost function. If the cost function is convex (for example, the sum of squared errors used in linear regression), the minimum found is global. For more complex cost functions, there is a possibility of getting stuck in a local minimum. There are three main flavors of GD:
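To make the idea concrete, here’s a minimal sketch of the basic update rule on a toy one-dimensional convex function, f(w) = (w - 3)^2. The function, learning rate and step count are purely illustrative choices, not anything specific from this article:

```python
# Gradient descent update rule on a toy convex function f(w) = (w - 3)**2,
# whose global minimum is at w = 3.
learning_rate = 0.1
w = 0.0  # arbitrary starting point

for step in range(100):
    grad = 2 * (w - 3)             # derivative of (w - 3)**2
    w = w - learning_rate * grad   # take a step towards the minimum

print(w)  # converges towards 3.0
```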
Batch gradient descent
All the training samples are used for every update of the model parameters. The gradient is averaged over the entire dataset, and then a single step is taken towards the minimum (a rough sketch follows the pros and cons below).
Pros -
- Smooth descent down the cost function space.
Cons -
- Takes a very long time for large datasets, since every example has to be processed for each single parameter update.
- If the entire dataset cannot fit into memory (which happens quite often with real-life datasets), batch GD can’t be used.
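Here’s a rough sketch of what batch GD could look like for the linear regression example mentioned earlier. The synthetic data, learning rate and number of epochs are assumptions made purely for illustration:

```python
import numpy as np

# Synthetic data: y = 2x + 1 plus noise (a stand-in for a real dataset).
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(200, 1))
y = 2 * X[:, 0] + 1 + 0.1 * rng.normal(size=200)

X_b = np.c_[X, np.ones(len(X))]   # add a bias column
theta = np.zeros(2)               # [slope, intercept]
lr = 0.1

for epoch in range(500):
    # Gradient of the mean squared error, averaged over ALL samples.
    error = X_b @ theta - y
    grad = 2 * X_b.T @ error / len(y)
    theta -= lr * grad            # one step per full pass over the data

print(theta)  # approximately [2.0, 1.0]
```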
Stochastic gradient descent
An update is made after every sample: the gradient is calculated for a single example and a step is taken towards the minimum (a sketch follows the pros and cons below).
Pros -
- Faster than batch gradient descent.
- Less prone to getting stuck in a local minimum. Because a step is taken for each individual training example, the updates carry some randomness, which makes it less likely to settle straight into a local minimum.
Cons -
- Noisy path to the minimum.
- If there are a large number of model parameters, updating the parameters after each example may not be time efficient.
- Doesn’t fully exploit vectorized operations (and therefore the benefits of parallelization) as it uses only one sample at a time.
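A minimal sketch of stochastic GD on the same assumed synthetic setup as the batch example, with one update per sample:

```python
import numpy as np

# Same illustrative setup as the batch sketch: y = 2x + 1 plus noise.
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(200, 1))
y = 2 * X[:, 0] + 1 + 0.1 * rng.normal(size=200)

X_b = np.c_[X, np.ones(len(X))]
theta = np.zeros(2)
lr = 0.05

for epoch in range(20):
    # Shuffle, then take one noisy step per individual sample.
    for i in rng.permutation(len(y)):
        error = X_b[i] @ theta - y[i]
        grad = 2 * error * X_b[i]
        theta -= lr * grad

print(theta)  # roughly [2.0, 1.0], reached via a noisier path
```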
Mini batch gradient descent
Why not the best of both worlds? Here ‘k’ examples are used for each update, where 1 < k < n and ‘n’ is the total number of examples in the dataset (a sketch follows the pros and cons below).
Pros -
- More time efficient than batch GD, since the parameters are updated many times per pass over the data.
- Each mini-batch fits comfortably in memory, so the full dataset never has to be loaded at once.
- The path to the minimum is more stable than with stochastic GD, while still keeping some of the helpful randomness in the update steps.
- Exploits vectorized operations
Cons -
- Introduces an additional batch size hyperparameter which requires tuning
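And a minimal sketch of mini-batch GD on the same assumed setup, with ‘k’ as the batch-size hyperparameter (the value 32 is just an illustrative choice):

```python
import numpy as np

# Same illustrative setup as above; k is the mini-batch size hyperparameter.
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(200, 1))
y = 2 * X[:, 0] + 1 + 0.1 * rng.normal(size=200)

X_b = np.c_[X, np.ones(len(X))]
theta = np.zeros(2)
lr = 0.1
k = 32  # mini-batch size (needs tuning, as noted above)

for epoch in range(100):
    idx = rng.permutation(len(y))
    for start in range(0, len(y), k):
        batch = idx[start:start + k]
        error = X_b[batch] @ theta - y[batch]
        grad = 2 * X_b[batch].T @ error / len(batch)  # averaged over the mini-batch
        theta -= lr * grad

print(theta)  # approximately [2.0, 1.0]
```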
Note: if ‘k’ = ‘n’ it ends up being batch gradient descent, and if ‘k’ = 1 it’s stochastic gradient descent. Thanks. Hope you guys enjoyed the article. Please let me know if you have anything to add on.