Types of gradient descent

Hi guys, I’d like to give an intro to the different types of gradient descent used in practice. 

Gradient descent (GD) is an iterative algorithm for finding the minimum of a cost function. If this cost function is strictly convex (for example, a sum of squared errors as in linear regression), the minimum found is global. For more complex cost functions, there is a possibility of getting stuck in a local minimum. There are 3 main flavors of GD -
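
Regardless of the flavor, the core update rule is the same (I’m using alpha for the learning rate and J for the cost function here, which is just notation for this sketch, not from any particular library):

    theta = theta - alpha * gradient_of_J(theta)

The three flavors below differ only in how much of the training data is used to compute that gradient for each update.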

Batch gradient descent

All the training samples are used to make a single update of the model parameters. The gradient is averaged over all the samples, and then a step is taken towards the minimum.

Pros -

  • Smooth, steady descent down the cost surface.  

Cons - 

  • Takes a very long time for large datasets: every example has to be processed before a single parameter update is made, so this method is not very time efficient. 
  • If the entire dataset cannot fit into memory (which happens quite often for real-life datasets), batch GD can’t be used directly. 
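
To make this concrete, here is a minimal NumPy sketch of batch GD for the squared-error linear regression case mentioned above. The function name, learning rate and iteration count are illustrative choices of mine, not taken from any particular library.

    import numpy as np

    def batch_gradient_descent(X, y, alpha=0.01, n_iters=1000):
        # X: (n_samples, n_features) design matrix, y: (n_samples,) targets
        n_samples, n_features = X.shape
        theta = np.zeros(n_features)
        for _ in range(n_iters):
            # gradient of the mean squared error, averaged over ALL samples
            gradient = X.T @ (X @ theta - y) / n_samples
            # take one step towards the minimum
            theta -= alpha * gradient
        return theta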

Stochastic gradient descent

An update is made after each sample: the gradient is calculated for that single sample and a step is taken towards the minimum.

Pros - 

  • Faster than batch gradient descent.
  • Less prone to getting stuck in a local minimum. Because a step is taken for each training example, the updates carry some randomness, which makes it less likely to settle straight into a local minimum.  

Cons - 

  • Noisy path to the minimum.
  • Updating the parameters after every single example adds overhead, which may not be time efficient when there are a large number of model parameters.
  • Doesn’t fully exploit vectorized operations (and therefore the benefits of parallelization) as it uses only one sample at a time. 
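
For comparison, here is a minimal sketch of the stochastic version on the same squared-error setup. Again, the names, random seed and hyperparameters are illustrative choices of mine.

    import numpy as np

    def stochastic_gradient_descent(X, y, alpha=0.01, n_epochs=10):
        n_samples, n_features = X.shape
        theta = np.zeros(n_features)
        rng = np.random.default_rng(0)
        for _ in range(n_epochs):
            # visit the samples in a random order each epoch
            for i in rng.permutation(n_samples):
                # gradient from ONE sample only
                gradient = (X[i] @ theta - y[i]) * X[i]
                # update after every single sample
                theta -= alpha * gradient
        return theta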

Mini batch gradient descent

Why not the best of both worlds? Here, ‘k’ examples are used for each update, where 1 < k < n and ‘n’ is the total number of examples in the dataset.

Pros -

  • More time efficient than the other two approaches
  • Each mini batch fits in memory, even when the full dataset does not
  • Path to the minimum is more stable than stochastic GD’s, while still incorporating some randomness in the update steps.
  • Exploits vectorized operations

Cons - 

  • Introduces an additional batch size hyperparameter which requires tuning
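
To show how ‘k’ enters, here is a minimal sketch of the mini batch version for the same squared-error setup. The default batch_size of 32 below is just a common illustrative choice, not a recommendation.

    import numpy as np

    def minibatch_gradient_descent(X, y, alpha=0.01, n_epochs=10, batch_size=32):
        n_samples, n_features = X.shape
        theta = np.zeros(n_features)
        rng = np.random.default_rng(0)
        for _ in range(n_epochs):
            # shuffle once per epoch, then walk through it k samples at a time
            indices = rng.permutation(n_samples)
            for start in range(0, n_samples, batch_size):
                batch = indices[start:start + batch_size]
                # gradient averaged over just the 'k' samples in this mini batch
                gradient = X[batch].T @ (X[batch] @ theta - y[batch]) / len(batch)
                theta -= alpha * gradient
        return theta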

Note: if ‘k’ = ‘n’ it ends up being batch gradient descent, and if ‘k’ = 1 it’s stochastic gradient descent. Thanks. Hope you guys enjoyed the article. Please let me know if you have anything to add.

