Solving Class Imbalance: Techniques and Strategies
Hey everyone! In my previous post, Unleashing the Power of Data, I talked about how data imbalance can severely impact the accuracy of our models. In this article, I'll dive into the technical side of things and share how I tackled this issue.
So, as we saw in the last post, the heavy-plastic class had four times as many images as the no-image class. When I trained my model on this set, it was biased towards the majority class, resulting in a measly 50% accuracy. But I didn't give up!
After doing some research, I came across a combo of techniques that proved to be super effective: SubsetRandomSampler and class weights. Let me break it down for you.
SubsetRandomSampler is a PyTorch utility that draws samples from a given list of indices in random order, without replacement. So if we hand it only a balanced subset of our indices, the loader sees exactly those examples, reshuffled every epoch. That makes it easy to undersample the majority class in our imbalanced data.
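Here's a minimal sketch of what that looks like. The labels tensor below is just a stand-in for your dataset's actual labels (the counts mirror the 4:1 imbalance from my dataset):

```python
import torch
from torch.utils.data import SubsetRandomSampler

# Stand-in labels: e.g. 400 heavy-plastic (class 0) vs. 100 no-image (class 1).
labels = torch.tensor([0] * 400 + [1] * 100)

# Undersample every class down to the size of the smallest one.
min_count = torch.bincount(labels).min().item()
balanced_indices = []
for cls in labels.unique():
    cls_indices = (labels == cls).nonzero(as_tuple=True)[0]
    # Shuffle within the class, then keep only min_count of its indices.
    perm = cls_indices[torch.randperm(len(cls_indices))]
    balanced_indices.extend(perm[:min_count].tolist())

# The sampler yields these indices in a fresh random order each epoch.
sampler = SubsetRandomSampler(balanced_indices)
```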
Next up, we calculate the class weights for our dataset. How? By counting the number of examples in each class and then computing the inverse frequency of each class. These weights are then passed to the loss function using the 'weight' argument.
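In code, that's only a few lines. Again, just a sketch, reusing the stand-in labels tensor from above:

```python
import torch
import torch.nn as nn

# Count the number of examples in each class.
class_counts = torch.bincount(labels).float()

# Inverse frequency: the rarer the class, the larger its weight.
class_weights = class_counts.sum() / class_counts
class_weights = class_weights / class_weights.mean()  # normalize to mean 1

# Hand the weights to the loss function via its 'weight' argument.
criterion = nn.CrossEntropyLoss(weight=class_weights)
```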
Now, we can pass our SubsetRandomSampler to the PyTorch DataLoader, which loads the data in batches during training.
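Wiring it all together looks roughly like this. Note that train_dataset, model, and optimizer are placeholders for your own setup, and that DataLoader's sampler argument can't be combined with shuffle=True, since the sampler already randomizes the order:

```python
from torch.utils.data import DataLoader

# 'train_dataset', 'model', and 'optimizer' are placeholders for your own setup.
# 'sampler' is mutually exclusive with shuffle=True in DataLoader.
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

for images, targets in loader:
    optimizer.zero_grad()
    outputs = model(images)             # forward pass on a balanced batch
    loss = criterion(outputs, targets)  # class-weighted loss from above
    loss.backward()
    optimizer.step()
```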
With both pieces in place, the class weights make the model pay more attention to the minority class, while the SubsetRandomSampler makes sure the model sees a balanced variety of examples during training.
So, what did all of this do for my model? Well, my accuracy skyrocketed from 50% to 80%! And it even generalized well on unseen data. Pretty cool, huh?
Alright, that's all for now, folks. Stay tuned for my next post, where I'll be sharing more on model selection and building. Cheers!