How to Handle Imbalanced Classes in Machine Learning
Intuition: Disease Screening Example
Let’s say your client is a leading research hospital, and they’ve asked you to train a model for detecting a disease based on biological inputs collected from patients.
But here’s the catch… the disease is relatively rare; it occurs in only 8% of patients who are screened.
Now, before you even start, do you see how the problem might break? Imagine if you didn’t bother training a model at all. Instead, what if you just wrote a single line of code that always predicts ‘No Disease’?
Well, guess what? Your “solution” would have 92% accuracy!
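Here is a small sketch that makes the arithmetic concrete. The population of 100 screened patients, 8 of whom have the disease, is made up. The "model" is literally one line that returns 0 (No Disease) for everyone. The sketch also computes the share of diseased patients the model catches.

```python
# Made-up screening data: 1 = disease, 0 = no disease (8% prevalence).
y_true = [1] * 8 + [0] * 92

# The "model": a single line that always predicts 'No Disease' (0).
y_pred = [0] * len(y_true)

# Fraction of all predictions that are correct.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Fraction of diseased patients the model actually flags (recall on class 1).
recall = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred)) / sum(y_true)

print(f"accuracy = {accuracy:.0%}, diseased patients detected = {recall:.0%}")
# prints: accuracy = 92%, diseased patients detected = 0%
```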
Unfortunately, that accuracy is misleading.
This is clearly a problem because many machine learning algorithms are designed to maximize overall accuracy. The rest of this guide will illustrate different tactics for handling imbalanced classes.
Important notes before we begin:
First, please note that we’re not going to split out a separate test set, tune hyperparameters, or implement cross-validation. In other words, we’re not necessarily going to follow best practices.
Instead, this tutorial is focused purely on addressing imbalanced classes.
In addition, not every technique below will work for every problem. However, in practice at least one of these techniques usually does the trick.
Balance Scale Dataset
For this guide, we’ll use a synthetic dataset called Balance Scale Data, which you can download from the UCI Machine Learning Repository.
This dataset was originally generated to model psychological experiment results, but it’s useful for us because it’s a manageable size and has imbalanced classes.
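The dataset contains information about whether a scale is balanced or not, based on the weights placed on its two arms and their distances from the center.
The target variable has 3 classes: L (the scale tips to the left), B (balanced), and R (tips to the right). Only 49 of the 625 observations are balanced, roughly the same 8% rate as in the disease example.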
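The file itself isn't reproduced here, so the sketch below rebuilds an equivalent DataFrame from the rule the dataset encodes. There are two assumptions to check against the UCI page. First, the file lists all 625 combinations of left weight, left distance, right weight and right distance, each from 1 to 5. Second, the class is whichever side has the larger weight * distance product, or B when the products are equal. The column names `balance` and `var1` to `var4` and the helper `make_balance_scale()` are our own, hypothetical choices.

```python
import itertools

import pandas as pd


def make_balance_scale():
    # Hypothetical helper. Assumption: the UCI file lists all 5**4 = 625
    # combinations of (left weight, left distance, right weight, right distance).
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        left, right = lw * ld, rw * rd  # weight * distance on each arm
        label = 'L' if left > right else 'R' if left < right else 'B'
        rows.append((label, lw, ld, rw, rd))
    # Class label first, then the four attributes.
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = make_balance_scale()
print(df.shape)                      # (625, 5)
print(df['balance'].value_counts())  # L: 288, R: 288, B: 49
```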
However, for this tutorial, we’re going to turn this into a binary classification problem.
We’re going to label each observation as 1 (positive class) if the scale is balanced or 0 (negative class) if the scale is not balanced:
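A minimal pandas sketch of this relabeling follows. To run without a download, it rebuilds the three-class data. It assumes the UCI file lists all 625 combinations of the four attributes (each 1 to 5) and labels each row by comparing weight * distance on the two arms. The helper and the column names are hypothetical.

```python
import itertools

import pandas as pd


def make_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        left, right = lw * ld, rw * rd
        label = 'L' if left > right else 'R' if left < right else 'B'
        rows.append((label, lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = make_balance_scale()

# 1 (positive class) if balanced, 0 (negative class) if it tips either way.
df['balance'] = (df['balance'] == 'B').astype(int)

print(df['balance'].value_counts())  # 0: 576, 1: 49
```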
Next, we’ll fit a very simple Logistic Regression model using default settings for everything.
As mentioned above, many machine learning algorithms are designed?to maximize overall accuracy by default.
We can confirm this:
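Below is a minimal sketch of that baseline. It trains a default `LogisticRegression` and scores it on the full imbalanced data, with no test set, per the notes above. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 combinations of the four attributes.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: the UCI file lists all 5**4 = 625
    # combinations of (left weight, left distance, right weight, right distance).
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        # 1 = balanced (equal weight * distance on both arms), 0 = tips over
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()

# Separate the target from the input features.
y = df['balance']
X = df.drop('balance', axis=1)

# Default Logistic Regression trained on the imbalanced data.
clf_0 = LogisticRegression().fit(X, y)
pred_y_0 = clf_0.predict(X)

# Scored on the training data itself (no hold-out set in this tutorial).
print(accuracy_score(y, pred_y_0))  # 0.9216
print(np.unique(pred_y_0))          # [0]
```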
So our model has 92% overall accuracy, but is it because it’s predicting only 1 class?
As you can see, this model is only predicting 0, which means it’s completely ignoring the minority class in favor of the majority class.
Next, we’ll look at the first technique for?handling imbalanced classes: up-sampling the minority class.
1. Up-sample Minority Class
Up-sampling is the process of randomly duplicating observations from the minority class in order to reinforce its signal.
There are several heuristics for doing so, but the most common way is to simply resample with replacement.
First, we’ll import the resample function from Scikit-Learn’s sklearn.utils module.
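Next, we’ll create a new DataFrame with an up-sampled minority class. Here are the steps: first, separate the observations from each class into different DataFrames; next, resample the minority class with replacement, setting the number of samples to match that of the majority class; finally, combine the up-sampled minority class DataFrame with the original majority class DataFrame.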
Here’s the code:
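The sketch below follows the three steps using `sklearn.utils.resample`, growing the minority class to the majority's 576 rows. `random_state=123` is an arbitrary seed. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations.

```python
import itertools

import pandas as pd
from sklearn.utils import resample


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        # 1 = balanced (equal weight * distance on both arms), 0 = tips over
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()

# Step 1: separate the majority (0) and minority (1) classes.
df_majority = df[df['balance'] == 0]
df_minority = df[df['balance'] == 1]

# Step 2: sample the minority class WITH replacement up to the majority's size.
df_minority_upsampled = resample(df_minority,
                                 replace=True,
                                 n_samples=len(df_majority),
                                 random_state=123)

# Step 3: combine the majority class with the up-sampled minority class.
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

print(df_upsampled['balance'].value_counts())  # 0: 576, 1: 576
```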
As you can see, the new DataFrame has more observations than the original, and the ratio of the two classes is now 1:1.
Let’s train another model using Logistic Regression, this time on the balanced dataset:
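The following sketch retrains on the up-sampled data. It repeats the up-sampling so it runs on its own. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations, and the seed is arbitrary. No output values are shown because they depend on the random duplication.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
df_majority = df[df['balance'] == 0]
df_minority = df[df['balance'] == 1]

# Up-sample the minority class to a 1:1 ratio.
df_upsampled = pd.concat([
    df_majority,
    resample(df_minority, replace=True, n_samples=len(df_majority), random_state=123),
])

y = df_upsampled['balance']
X = df_upsampled.drop('balance', axis=1)

# Same default Logistic Regression, now on balanced data.
clf_1 = LogisticRegression().fit(X, y)
pred_y_1 = clf_1.predict(X)

print(np.unique(pred_y_1))          # which classes the model now predicts
print(accuracy_score(y, pred_y_1))  # accuracy on the (balanced) training data
```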
Great, now the model is no longer predicting just one class. While the accuracy also took a nosedive, it’s now more meaningful as a performance metric.
2. Down-sample Majority Class
Down-sampling involves randomly removing observations from the majority class to prevent its signal from dominating the learning algorithm.
The most common heuristic for doing so is resampling without replacement.
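The process is similar to that of up-sampling. Here are the steps: first, separate the observations from each class into different DataFrames; next, resample the majority class without replacement, setting the number of samples to match that of the minority class; finally, combine the down-sampled majority class DataFrame with the original minority class DataFrame.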
Here’s the code:
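This sketch mirrors the up-sampling steps, but it draws 49 majority rows without replacement. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. `random_state=123` is an arbitrary seed.

```python
import itertools

import pandas as pd
from sklearn.utils import resample


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        # 1 = balanced (equal weight * distance on both arms), 0 = tips over
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()

# Step 1: separate the majority (0) and minority (1) classes.
df_majority = df[df['balance'] == 0]
df_minority = df[df['balance'] == 1]

# Step 2: sample the majority class WITHOUT replacement down to the minority's size.
df_majority_downsampled = resample(df_majority,
                                   replace=False,
                                   n_samples=len(df_minority),
                                   random_state=123)

# Step 3: combine the down-sampled majority class with the minority class.
df_downsampled = pd.concat([df_majority_downsampled, df_minority])

print(df_downsampled['balance'].value_counts())  # 0: 49, 1: 49
```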
This time, the new DataFrame has fewer observations than the original, and the ratio of the two classes is now 1:1.
Again, let’s train a model using Logistic Regression:
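This sketch trains the same default model on the down-sampled data. It repeats the down-sampling so it runs on its own. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. The seed is arbitrary, so no exact output is claimed.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
df_majority = df[df['balance'] == 0]
df_minority = df[df['balance'] == 1]

# Down-sample the majority class to a 1:1 ratio.
df_downsampled = pd.concat([
    resample(df_majority, replace=False, n_samples=len(df_minority), random_state=123),
    df_minority,
])

y = df_downsampled['balance']
X = df_downsampled.drop('balance', axis=1)

clf_2 = LogisticRegression().fit(X, y)
pred_y_2 = clf_2.predict(X)

print(np.unique(pred_y_2))          # which classes the model predicts
print(accuracy_score(y, pred_y_2))  # accuracy on the (balanced) training data
```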
The model isn’t predicting just one class, and its accuracy is higher than that of the model trained on the up-sampled dataset.
We’d still want to validate the model on an unseen test dataset, but the results are more encouraging.
3. Change Your Performance Metric
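So far, we’ve looked at two ways of addressing imbalanced classes by resampling the dataset. Next, we’ll look at using other performance metrics for evaluating the models.
As a popular saying (often attributed to Albert Einstein, though there’s no record of him saying it) goes, “if you judge a fish on its ability to climb a tree, it will live its whole life believing that it is stupid.” This quote highlights the importance of choosing the right evaluation metric.
For a general-purpose metric for classification, we recommend Area Under ROC Curve (AUROC). AUROC measures how well a model ranks positive observations above negative ones. A model that gives every observation the same score gets 0.5, so always predicting the majority class earns no credit.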
We can import this metric from Scikit-Learn:
To calculate AUROC, you’ll need predicted class probabilities instead of just the predicted classes. You can get them using the .predict_proba() method like so:
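A sketch of the mechanics on the down-sampled model follows. `.predict_proba()` returns one column per class in the order of `clf_2.classes_`, so column 1 holds the probability of the positive class. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. The seed is arbitrary.

```python
import itertools

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
df_majority = df[df['balance'] == 0]
df_minority = df[df['balance'] == 1]
df_downsampled = pd.concat([
    resample(df_majority, replace=False, n_samples=len(df_minority), random_state=123),
    df_minority,
])
y = df_downsampled['balance']
X = df_downsampled.drop('balance', axis=1)

clf_2 = LogisticRegression().fit(X, y)

# Columns of .predict_proba() follow clf_2.classes_.
print(clf_2.classes_)  # [0 1]

# Keep only the probability of the positive class (1 = balanced).
prob_y_2 = clf_2.predict_proba(X)[:, 1]
print(prob_y_2[:5])

# AUROC is computed from these probabilities, not from hard 0/1 predictions.
print(roc_auc_score(y, prob_y_2))
```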
So how did this model (trained on the down-sampled dataset) do in terms of AUROC?
Ok… and how does this compare to the original model trained on the imbalanced dataset?
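For the comparison, the sketch below computes the same metric for the original default model trained on the imbalanced data. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. The fitted probabilities are nearly constant here, so the exact AUROC may vary across library versions. For that reason no value is claimed.

```python
import itertools

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
y = df['balance']
X = df.drop('balance', axis=1)

# Original model: default settings, imbalanced data.
clf_0 = LogisticRegression().fit(X, y)

# Probability of the positive class (column 1, since classes_ is [0, 1]).
prob_y_0 = clf_0.predict_proba(X)[:, 1]

print(roc_auc_score(y, prob_y_0))
```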
Remember, our original model trained on the imbalanced dataset had an accuracy of 92%, which is much higher than the 58% accuracy of the model trained on the down-sampled dataset.
However, the latter model has an AUROC of 56.5%, which is higher than the 52.4% of the original model (but not by much).
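Note: if you got an AUROC of 0.476 instead, that is the mirror image of 0.524 (1 - 0.476 = 0.524). A common cause is scoring with the wrong column of .predict_proba(), that is, the probability of class 0 instead of class 1. This reverses the ranking of every observation, and in that case the correct AUROC is 0.524. AUROC is not guaranteed to be >= 0.5, though: a value below 0.5 can also mean the model genuinely ranks worse than random. Check which scores you passed in before flipping the result.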
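A toy example shows the flip. The four labels and scores below are made up. Scoring with the probability of class 0, which is 1 minus the probability of class 1, reverses every positive/negative comparison, so the AUROC becomes 1 minus the correct value.

```python
from sklearn.metrics import roc_auc_score

# Made-up labels and predicted probabilities for the positive class (1).
y_true = [0, 0, 1, 1]
p_class1 = [0.1, 0.4, 0.35, 0.8]

# What you'd get by mistakenly keeping column 0 of .predict_proba().
p_class0 = [1 - p for p in p_class1]

auc_right = roc_auc_score(y_true, p_class1)
auc_wrong = roc_auc_score(y_true, p_class0)
print(auc_right, auc_wrong)  # 0.75 0.25
```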
4. Penalize Algorithms (Cost-Sensitive Training)
The next tactic is to use penalized learning algorithms that increase the cost of classification mistakes on the minority class.
A popular algorithm for this technique is Penalized-SVM.
During training, we can use the argument class_weight='balanced' to penalize mistakes on the minority class by an amount proportional to how under-represented it is. Each class is weighted inversely to its frequency in the training data.
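To see what 'balanced' means numerically, consider scikit-learn's `compute_class_weight`. It returns the per-class weights that `class_weight='balanced'` uses, `n_samples / (n_classes * count of each class)`. The sketch uses the binary Balance Scale class counts (576 and 49). The labels are listed directly, so no dataset is needed.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Binary labels with the Balance Scale class counts: 576 zeros, 49 ones.
y = np.array([0] * 576 + [1] * 49)

# Weights applied by class_weight='balanced'.
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y)
print(weights)  # about 0.54 for class 0 and 6.38 for class 1

# Same numbers by hand: n_samples / (n_classes * per-class count).
manual = len(y) / (2 * np.bincount(y))
print(manual)
```

Each mistake on a balanced scale therefore costs about 576 / 49, or roughly 11.8, times as much as a mistake on an unbalanced one.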
We also include the argument probability=True, which enables probability estimates (.predict_proba()) for SVM models. We need those estimates to calculate AUROC.
Let’s train a model using Penalized-SVM on the original imbalanced dataset:
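Here is a sketch using scikit-learn's `SVC` with the two arguments above. The linear kernel is our assumption, since the text doesn't name one. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. With `probability=True`, scikit-learn fits the probability estimates using an internal cross-validation, so `random_state` is set for repeatability.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.svm import SVC


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
y = df['balance']
X = df.drop('balance', axis=1)

# Penalized SVM on the ORIGINAL imbalanced data:
# class_weight='balanced' raises the cost of errors on the rare class,
# probability=True enables .predict_proba() for AUROC.
clf_3 = SVC(kernel='linear', class_weight='balanced',
            probability=True, random_state=123)
clf_3.fit(X, y)

pred_y_3 = clf_3.predict(X)
prob_y_3 = clf_3.predict_proba(X)[:, 1]

print(np.unique(pred_y_3))
print(accuracy_score(y, pred_y_3))
print(roc_auc_score(y, prob_y_3))
```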
Again, our purpose here is only to illustrate this technique. To really determine which of these tactics works best for this problem, you’d want to evaluate the models on a hold-out test set.
5. Use Tree-Based Algorithms
The final tactic we’ll consider is using tree-based algorithms. Decision trees often perform well on imbalanced datasets because their hierarchical structure allows them to learn signals from both classes.
In modern applied machine learning, tree ensembles (Random Forests, Gradient Boosted Trees, etc.) usually outperform single decision trees, so we’ll jump right into those.
Now, let’s train a model using a Random Forest on the original imbalanced dataset.
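This sketch uses scikit-learn's `RandomForestClassifier` with default settings and an arbitrary seed. It is trained and scored on the full imbalanced data, as in the rest of this tutorial. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations. Because the forest is scored on its own training data, these numbers mostly reflect how well it memorizes.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
y = df['balance']
X = df.drop('balance', axis=1)

# Random Forest on the ORIGINAL imbalanced data.
clf_4 = RandomForestClassifier(random_state=123).fit(X, y)

pred_y_4 = clf_4.predict(X)
prob_y_4 = clf_4.predict_proba(X)[:, 1]

print(np.unique(pred_y_4))
print(accuracy_score(y, pred_y_4))  # training-set accuracy
print(roc_auc_score(y, prob_y_4))   # training-set AUROC
```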
Wow! 100% accuracy and 100% AUROC? Is this magic? A sleight of hand? Cheating? Too good to be true?
Well, tree ensembles have become very popular because they perform extremely well on many real-world problems. We certainly recommend them wholeheartedly.
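While these results are encouraging, we evaluated the model on the same data it was trained on, and a Random Forest can effectively memorize its training set. The model could well be overfit, so you should still evaluate it on an unseen test set before making the final decision.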
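The sketch below shows one way to hold out a test set with scikit-learn's `train_test_split`. `stratify=y` keeps the roughly 8% balanced rate in both splits, so the test set is guaranteed some positives to score. The 25% split and the seed are arbitrary choices. `load_binary_balance_scale()` is a hypothetical helper that regenerates the binary dataset, assuming the UCI file lists all 625 attribute combinations.

```python
import itertools

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
y = df['balance']
X = df.drop('balance', axis=1)

# Hold out 25% of the rows, preserving the class ratio in each split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=123)

# Train only on the training split...
clf_5 = RandomForestClassifier(random_state=123).fit(X_train, y_train)

# ...and score on rows the model has never seen.
prob_test = clf_5.predict_proba(X_test)[:, 1]
print(roc_auc_score(y_test, prob_test))
```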
Note: If your numbers differ slightly, it is due to the randomness in the algorithm. Remember to set a random seed with random_state=123 (or any number you like) for reproducible results.
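Here is a quick check of what the seed buys you. Two forests built with the same `random_state` on the same data produce identical probabilities. The data here is the hypothetical regenerated Balance Scale set, assuming the UCI file lists all 625 attribute combinations.

```python
import itertools

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier


def load_binary_balance_scale():
    # Hypothetical helper. Assumption: all 5**4 = 625 attribute combinations.
    rows = []
    for lw, ld, rw, rd in itertools.product(range(1, 6), repeat=4):
        rows.append((int(lw * ld == rw * rd), lw, ld, rw, rd))
    return pd.DataFrame(rows, columns=['balance', 'var1', 'var2', 'var3', 'var4'])


df = load_binary_balance_scale()
y = df['balance']
X = df.drop('balance', axis=1)

# Same seed, same data -> same bootstrap samples and feature choices.
clf_a = RandomForestClassifier(random_state=123).fit(X, y)
clf_b = RandomForestClassifier(random_state=123).fit(X, y)

same = np.array_equal(clf_a.predict_proba(X), clf_b.predict_proba(X))
print(same)  # True
```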