How to Handle Imbalanced Classes in Machine Learning

Imbalanced classes put “accuracy” out of business. This is a surprisingly common problem in machine learning (specifically in classification), occurring in datasets with a disproportionate ratio of observations in each class.

Standard accuracy no longer reliably measures performance, which makes model training much trickier.

Imbalanced classes appear in many domains, including:

  • Fraud detection
  • Spam filtering
  • Disease screening
  • SaaS subscription churn
  • Advertising click-throughs

In this guide, we’ll explore 5 effective ways to handle imbalanced classes.


Intuition: Disease Screening Example

Let's say your client is a leading research hospital, and they've asked you to train a model for detecting a disease based on biological inputs collected from patients.

But here's the catch... the disease is relatively rare; it occurs in only 8% of patients who are screened.

Now, before you even start, do you see how the problem might break? Imagine if you didn't bother training a model at all. Instead, what if you just wrote a single line of code that always predicts 'No Disease'?

Well, guess what? Your "solution" would have 92% accuracy!

Unfortunately, that accuracy is misleading.

  • For patients who do not have the disease, you'd have 100% accuracy.
  • For patients who do have the disease, you'd have 0% accuracy.
  • Your overall accuracy would be high simply because most patients do not have the disease (not because your model is any good).
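Here's that arithmetic as a tiny pure-Python check. It uses 1,000 hypothetical patients, 80 of whom (8%) have the disease. The labels are made up for illustration.

```python
# Hypothetical screening labels: 1 = disease, 0 = no disease (8% positive)
y_true = [1] * 80 + [0] * 920

# The "model": a single line that always predicts 'No Disease'
y_pred = [0 for _ in y_true]


def accuracy(truth, pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(truth, pred)) / len(truth)


overall = accuracy(y_true, y_pred)

# Per-class accuracy: keep only the patients who truly belong to each class
sick = [(t, p) for t, p in zip(y_true, y_pred) if t == 1]
healthy = [(t, p) for t, p in zip(y_true, y_pred) if t == 0]
acc_sick = accuracy(*zip(*sick))
acc_healthy = accuracy(*zip(*healthy))

print(overall, acc_healthy, acc_sick)  # 0.92 1.0 0.0
```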

This is clearly a problem because many machine learning algorithms are designed to maximize overall accuracy. The rest of this guide will illustrate different tactics for handling imbalanced classes.

Important notes before we begin:

First, please note that we're not going to split out a separate test set, tune hyperparameters, or implement cross-validation. In other words, we're not going to follow best practices (which are covered in our 7-day crash course).

Instead, this tutorial is focused purely on addressing imbalanced classes.

In addition, not every technique below will work for every problem. However, 9 times out of 10, at least one of these techniques should do the trick.

Balance Scale Dataset

For this guide, we'll use a synthetic dataset called Balance Scale Data, which you can download from the UCI Machine Learning Repository here.

This dataset was originally generated to model psychological experiment results, but it's useful for us because it's a manageable size and has imbalanced classes.


The dataset contains information about whether a scale is balanced or not, based on weights and distances of the two arms.

  • It has 1 target variable, which we've labeled balance.
  • It has 4 input features, which we've labeled var1 through var4.

The target variable has 3 classes.

  • R for right-heavy, i.e. when var3 * var4 > var1 * var2
  • L for left-heavy, i.e. when var3 * var4 < var1 * var2
  • B for balanced, i.e. when var3 * var4 = var1 * var2
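The three rules above fully determine the label, so you can rebuild the table without the download. This sketch makes two assumptions. First, that the UCI file lists every combination of the values 1 to 5 for the four features (5^4 = 625 rows). Second, that the rules apply to var1 through var4 exactly as written. If you use the downloaded file instead, something like pd.read_csv('balance-scale.data', names=['balance', 'var1', 'var2', 'var3', 'var4']) should load an equivalent table. That file name is itself an assumption.

```python
from itertools import product

import pandas as pd

# Assumption: one row per combination of the values 1..5 for var1..var4
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])

left = df.var1 * df.var2    # var1 * var2 from the rules above
right = df.var3 * df.var4   # var3 * var4 from the rules above

# Apply the three labelling rules
df['balance'] = 'B'
df.loc[right > left, 'balance'] = 'R'
df.loc[right < left, 'balance'] = 'L'

counts = df['balance'].value_counts()
print(counts)
```

Under that assumption, the counts come out to 288 for R, 288 for L and 49 for B.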

However, for this tutorial, we're going to turn this into a binary classification problem.

We're going to label each observation as 1 (positive class) if the scale is balanced or 0 (negative class) if the scale is not balanced:
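Here is one way to do the relabeling in pandas. The sketch is self-contained: it rebuilds the three-class table from the rules above, assuming all 625 combinations of the values 1 to 5, and then collapses it to a binary target.

```python
from itertools import product

import pandas as pd

# Rebuild the three-class table (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = 'B'
df.loc[df.var3 * df.var4 > df.var1 * df.var2, 'balance'] = 'R'
df.loc[df.var3 * df.var4 < df.var1 * df.var2, 'balance'] = 'L'

# Binary target: 1 if the scale is balanced ('B'), 0 if it tips either way
df['balance'] = (df['balance'] == 'B').astype(int)

print(df['balance'].value_counts())
print(df['balance'].mean())  # share of balanced rows, 49 / 625
```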

As you can see, only about 8% of the observations were balanced. Therefore, if we were to always predict 0, we'd achieve an accuracy of 92%.

The Danger of Imbalanced Classes

Now that we have a dataset, we can really show the dangers of imbalanced classes.

First, let's import the Logistic Regression algorithm and the accuracy metric from Scikit-Learn.

Next, we'll fit a very simple model using default settings for everything.

As mentioned above, many machine learning algorithms are designed to maximize overall accuracy by default.

We can confirm this:

So our model has 92% overall accuracy, but is it because it's predicting only 1 class?
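The steps above might look like the following sketch. It rebuilds the binary dataset (assuming all 625 value combinations) and fits LogisticRegression with default settings. It then scores the model on the same data it was trained on, as this tutorial does throughout. Exact numbers can vary with your scikit-learn version, since the default solver has changed over time.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# Separate input features (X) and target variable (y)
y = df.balance
X = df.drop('balance', axis=1)

# Train with default settings, then predict on the training data itself
clf_0 = LogisticRegression().fit(X, y)
pred_y_0 = clf_0.predict(X)

# Overall accuracy, and the set of classes the model actually predicts
print(accuracy_score(y, pred_y_0))
print(np.unique(pred_y_0))
```

If np.unique prints only [0], the model is simply echoing the majority class.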

As you can see, this model is only predicting 0, which means it's completely ignoring the minority class in favor of the majority class.

Next, we'll look at the first technique for handling imbalanced classes: up-sampling the minority class.

1. Up-sample Minority Class

Up-sampling is the process of randomly duplicating observations from the minority class in order to reinforce its signal.

There are several heuristics for doing so, but the most common way is to simply resample with replacement.

First, we'll import the resample() function from Scikit-Learn's sklearn.utils module.

Next, we'll create a new DataFrame with an up-sampled minority class. Here are the steps:

  1. First, we'll separate observations from each class into different DataFrames.
  2. Next, we'll resample the minority class with replacement, setting the number of samples to match that of the majority class.
  3. Finally, we'll combine the up-sampled minority class DataFrame with the original majority class DataFrame.

Here's the code:
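This is a sketch of the three steps above. It first rebuilds the binary dataset, assuming all 625 value combinations, so it runs on its own. The seed random_state=123 is arbitrary and only makes the result reproducible.

```python
from itertools import product

import pandas as pd
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# 1. Separate majority and minority classes
df_majority = df[df.balance == 0]
df_minority = df[df.balance == 1]

# 2. Resample the minority class WITH replacement up to the majority size
df_minority_upsampled = resample(df_minority,
                                 replace=True,                 # sample with replacement
                                 n_samples=len(df_majority),   # 576 here
                                 random_state=123)             # reproducible results

# 3. Combine the original majority class with the up-sampled minority class
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

print(df_upsampled.balance.value_counts())
```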

As you can see, the new DataFrame has more observations than the original, and the ratio of the two classes is now 1:1.

Let's train another model using Logistic Regression, this time on the balanced dataset:
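This sketch shows the retraining step. It repeats the dataset rebuild (assuming all 625 value combinations) and the up-sampling (arbitrary seed 123) so that it runs on its own.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# Up-sample the minority class to a 1:1 ratio
df_majority = df[df.balance == 0]
df_minority = df[df.balance == 1]
df_minority_upsampled = resample(df_minority, replace=True,
                                 n_samples=len(df_majority), random_state=123)
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

# Separate features and target of the balanced data
y_1 = df_upsampled.balance
X_1 = df_upsampled.drop('balance', axis=1)

# Train on the balanced data and predict on it
clf_1 = LogisticRegression().fit(X_1, y_1)
pred_y_1 = clf_1.predict(X_1)

print(np.unique(pred_y_1))            # classes the model now predicts
print(accuracy_score(y_1, pred_y_1))  # accuracy on the balanced training data
```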

Great, now the model is no longer predicting just one class. While the accuracy also took a nosedive, it's now more meaningful as a performance metric.

2. Down-sample Majority Class

Down-sampling involves randomly removing observations from the majority class to prevent its signal from dominating the learning algorithm.

The most common heuristic for doing so is resampling without replacement.

The process is similar to that of up-sampling. Here are the steps:

  1. First, we'll separate observations from each class into different DataFrames.
  2. Next, we'll resample the majority class without replacement, setting the number of samples to match that of the minority class.
  3. Finally, we'll combine the down-sampled majority class DataFrame with the original minority class DataFrame.

Here's the code:
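This is a sketch of the down-sampling steps. Again, it rebuilds the binary dataset first (assuming all 625 value combinations) and uses an arbitrary seed for reproducibility.

```python
from itertools import product

import pandas as pd
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# 1. Separate majority and minority classes
df_majority = df[df.balance == 0]
df_minority = df[df.balance == 1]

# 2. Resample the majority class WITHOUT replacement down to the minority size
df_majority_downsampled = resample(df_majority,
                                   replace=False,                # no duplicates
                                   n_samples=len(df_minority),   # 49 here
                                   random_state=123)             # reproducible results

# 3. Combine the down-sampled majority class with the original minority class
df_downsampled = pd.concat([df_majority_downsampled, df_minority])

print(df_downsampled.balance.value_counts())
```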

This time, the new DataFrame has fewer observations than the original, and the ratio of the two classes is now 1:1.

Again, let's train a model using Logistic Regression:
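This sketch trains on the down-sampled data. It repeats the rebuild and down-sampling (same assumptions: all 625 value combinations, arbitrary seed 123) so it runs on its own. As before, it scores the model on its own training rows.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# Down-sample the majority class to a 1:1 ratio
df_minority = df[df.balance == 1]
df_majority_downsampled = resample(df[df.balance == 0], replace=False,
                                   n_samples=len(df_minority), random_state=123)
df_downsampled = pd.concat([df_majority_downsampled, df_minority])

# Separate features and target of the down-sampled data
y_2 = df_downsampled.balance
X_2 = df_downsampled.drop('balance', axis=1)

clf_2 = LogisticRegression().fit(X_2, y_2)
pred_y_2 = clf_2.predict(X_2)

print(np.unique(pred_y_2))            # classes the model predicts
print(accuracy_score(y_2, pred_y_2))  # accuracy on the down-sampled training data
```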

The model isn't predicting just one class, and its accuracy seems higher than that of the up-sampled model.

We'd still want to validate the model on an unseen test dataset, but the results are more encouraging.

3. Change Your Performance Metric

So far, we've looked at two ways of addressing imbalanced classes by resampling the dataset. Next, we'll look at using other performance metrics for evaluating the models.

A saying often attributed to Albert Einstein goes: "If you judge a fish by its ability to climb a tree, it will live its whole life believing that it is stupid." It highlights why choosing the right evaluation metric matters.

As a general-purpose classification metric, we recommend the Area Under the ROC Curve (AUROC).

  • We won't dive into its details in this guide, but you can read more about it here.
  • Intuitively, AUROC represents the likelihood of your model distinguishing observations from two classes.
  • In other words, if you randomly select one observation from each class, what's the probability that your model will be able to "rank" them correctly?
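To make the ranking interpretation concrete, here's a brute-force sketch. It computes AUROC literally as the fraction of (positive, negative) pairs that the scores order correctly, counting ties as half. The labels and scores are made up. Scikit-Learn's roc_auc_score computes the same quantity far more efficiently.

```python
def auroc_by_ranking(y_true, scores):
    """AUROC as P(random positive is scored above random negative).

    Ties count as half a correct ranking. O(n_pos * n_neg): for intuition only.
    """
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # pair ranked correctly
            elif p == n:
                wins += 0.5   # tie: half credit
    return wins / (len(pos) * len(neg))


# Toy example: 2 positives, 3 negatives, made-up model scores
y_true = [1, 0, 1, 0, 0]
scores = [0.9, 0.7, 0.6, 0.3, 0.6]

print(auroc_by_ranking(y_true, scores))  # 0.75
```

Here 4.5 of the 6 possible (positive, negative) pairs are ranked correctly, giving 0.75.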

We can import this metric as roc_auc_score from Scikit-Learn's sklearn.metrics module.

To calculate AUROC, you'll need predicted class probabilities instead of just the predicted classes. You can get them from the model's .predict_proba() method, which returns one column per class; keep the column for the positive class (index 1).

So how did this model (trained on the down-sampled dataset) do in terms of AUROC?

Ok... and how does this compare to the original model trained on the imbalanced dataset?
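This sketch answers both questions. It trains the original and the down-sampled LogisticRegression models and computes AUROC from the positive-class probabilities. Each model is scored on the data it was trained on, since this tutorial skips the hold-out set. The rebuild assumes all 625 value combinations, and the seed is arbitrary, so your numbers may not match the figures below exactly.

```python
from itertools import product

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)
y = df.balance
X = df.drop('balance', axis=1)

# Original model: trained on the imbalanced data
clf_0 = LogisticRegression().fit(X, y)

# Down-sampled model: majority class reduced to the minority size (1:1)
df_minority = df[df.balance == 1]
df_majority_downsampled = resample(df[df.balance == 0], replace=False,
                                   n_samples=len(df_minority), random_state=123)
df_downsampled = pd.concat([df_majority_downsampled, df_minority])
y_2 = df_downsampled.balance
X_2 = df_downsampled.drop('balance', axis=1)
clf_2 = LogisticRegression().fit(X_2, y_2)

# predict_proba returns one column per class; [:, 1] is P(balance == 1)
prob_y_0 = clf_0.predict_proba(X)[:, 1]
prob_y_2 = clf_2.predict_proba(X_2)[:, 1]

# Each model is scored on the data it was trained on (no hold-out set here)
auc_0 = roc_auc_score(y, prob_y_0)
auc_2 = roc_auc_score(y_2, prob_y_2)
print('original data AUROC:    ', auc_0)
print('down-sampled data AUROC:', auc_2)
```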

Remember, our original model trained on the imbalanced dataset had an accuracy of 92%, which is much higher than the 58% accuracy of the model trained on the down-sampled dataset.

However, the latter model has an AUROC of 57%, which is higher than the 53% of the original model (but not by much).

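Note: an AUROC below 0.5 (say, 0.47) means the model ranks the two classes worse than random guessing on that data. Before trusting it, check that you passed the predicted probability of the positive class, i.e. column 1 of the .predict_proba() output; passing column 0 instead gives exactly 1 minus the true AUROC. AUROC is not guaranteed to be >= 0.5, so a genuine score below 0.5 is a sign that something is off with the model or the data.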
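Here is a small check of the column-flip point, using made-up labels and probabilities. Scoring with the negative-class probabilities (1 - p) instead of p turns an AUROC of about 0.83 into about 0.17.

```python
from sklearn.metrics import roc_auc_score

# Toy labels and made-up predicted probabilities for the positive class
y_true = [0, 0, 1, 1, 0]
p_pos = [0.1, 0.4, 0.35, 0.8, 0.2]

# The other predict_proba column: probability of the negative class
p_neg = [1 - p for p in p_pos]

auc_pos = roc_auc_score(y_true, p_pos)   # 5 of 6 pairs ranked correctly
auc_neg = roc_auc_score(y_true, p_neg)   # the mirror image: 1 of 6
print(auc_pos, auc_neg)  # the two add up to 1 (up to floating-point rounding)
```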

4. Penalize Algorithms (Cost-Sensitive Training)

The next tactic is to use penalized learning algorithms that increase the cost of classification mistakes on the minority class.

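A popular algorithm for this technique is Penalized-SVM, that is, a support vector machine (Scikit-Learn's SVC class from sklearn.svm) trained with class weights.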

When constructing the model, we can pass class_weight='balanced' to penalize mistakes on each class in inverse proportion to how often that class appears, so mistakes on the rare minority class cost much more than mistakes on the majority class.

We also need to pass probability=True, because SVC only exposes .predict_proba() (which we need for AUROC) when probability estimates are enabled.
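To see what "balanced" means numerically, this sketch uses Scikit-Learn's compute_class_weight helper on the rebuilt binary target (assuming all 625 value combinations). The 'balanced' weight for each class is n_samples / (n_classes * count of that class). SVC then multiplies its penalty C by that weight for each class.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

# Rebuild the binary target (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
y = (df.var3 * df.var4 == df.var1 * df.var2).astype(int).to_numpy()

# Weights used by class_weight='balanced': n_samples / (n_classes * class count)
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y)

# The same numbers computed by hand
manual = len(y) / (2 * np.bincount(y))

print(weights)                  # roughly [0.54, 6.38]
print(weights[1] / weights[0])  # a minority mistake costs about 11.8x a majority one
```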

Let's train a model using Penalized-SVM on the original imbalanced dataset:
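This sketch trains a Penalized-SVM on the full, imbalanced data. The linear kernel and random_state=0 are assumptions for illustration, not requirements, and the rebuild assumes all 625 value combinations. The model is scored on its own training data.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.svm import SVC

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)
y = df.balance
X = df.drop('balance', axis=1)

# Penalized SVM: 'balanced' class weights; probability=True enables predict_proba
clf_3 = SVC(kernel='linear', class_weight='balanced',
            probability=True, random_state=0)
clf_3.fit(X, y)

pred_y_3 = clf_3.predict(X)
prob_y_3 = clf_3.predict_proba(X)[:, 1]  # P(balance == 1)

print(np.unique(pred_y_3))
print(accuracy_score(y, pred_y_3))
print(roc_auc_score(y, prob_y_3))
```

Scikit-Learn fits SVC's probability estimates with an internal cross-validation (Platt scaling), so .predict_proba() can occasionally disagree with .predict().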

Again, our purpose here is only to illustrate this technique. To really determine which of these tactics works best for this problem, you'd want to evaluate the models on a hold-out test set.

5. Use Tree-Based Algorithms

The final tactic we'll consider is using tree-based algorithms. Decision trees often perform well on imbalanced datasets because their hierarchical structure allows them to learn signals from both classes.

In modern applied machine learning, tree ensembles (Random Forests, Gradient Boosted Trees, etc.) almost always outperform single decision trees, so we'll jump right into those.

Now, let's train a model using a Random Forest on the original imbalanced dataset.
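This sketch fits RandomForestClassifier with its default settings and a fixed random_state, so the run is reproducible. The number of trees that "default" means depends on your Scikit-Learn version. Like the earlier models, it is scored on its own training data, and the rebuild assumes all 625 value combinations.

```python
from itertools import product

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)
y = df.balance
X = df.drop('balance', axis=1)

# Random Forest on the original imbalanced data; fixed seed for reproducibility
clf_4 = RandomForestClassifier(random_state=0)
clf_4.fit(X, y)

pred_y_4 = clf_4.predict(X)
prob_y_4 = clf_4.predict_proba(X)[:, 1]  # P(balance == 1)

print(np.unique(pred_y_4))
print(accuracy_score(y, pred_y_4))
print(roc_auc_score(y, prob_y_4))
```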

Wow! 97% accuracy and nearly 100% AUROC? Is this magic? A sleight of hand? Cheating? Too good to be true?

Well, tree ensembles have become very popular because they perform extremely well on many real-world problems. We certainly recommend them wholeheartedly.

However:

While these results are encouraging, the model could be overfit, so you should still evaluate your model on an unseen test set before making the final decision.
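Here is a minimal sketch of that check. It holds out a stratified test split before training, then compares AUROC on the training rows with AUROC on the unseen rows. The 25% test size and the seeds are arbitrary choices, and the rebuild assumes all 625 value combinations.

```python
from itertools import product

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)
y = df.balance
X = df.drop('balance', axis=1)

# Hold out 25% of rows; stratify keeps the ~8% positive rate in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

clf_4 = RandomForestClassifier(random_state=0).fit(X_train, y_train)

# Compare performance on seen (training) vs unseen (test) rows
train_auc = roc_auc_score(y_train, clf_4.predict_proba(X_train)[:, 1])
test_auc = roc_auc_score(y_test, clf_4.predict_proba(X_test)[:, 1])
print('train AUROC:', train_auc)
print('test AUROC: ', test_auc)
```

A large gap between the two scores is the usual sign of overfitting.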

Note: your numbers may differ slightly due to the randomness in the algorithm. You can set a random seed for reproducible results.

Honorable Mentions

There were a few tactics that didn't make it into this tutorial:

Create Synthetic Samples (Data Augmentation)

Creating synthetic samples is a close cousin of up-sampling, and some people might categorize them together. For example, the SMOTE algorithm creates "new" minority-class samples by interpolating between a minority observation and one of its nearest minority-class neighbors, rather than copying existing rows.

You can find an implementation of SMOTE in the imblearn library.
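To show the interpolation idea, here's a simplified SMOTE-style sketch in NumPy. It is not imblearn's implementation; with imblearn you would normally call SMOTE(...).fit_resample(X, y) from imblearn.over_sampling. The helper name smote_like and its parameters are hypothetical, and the sketch only generates the synthetic minority rows.

```python
from itertools import product

import numpy as np


def smote_like(X_min, n_new, k=5, seed=0):
    """Hypothetical, simplified SMOTE-style over-sampler (illustration only).

    Each synthetic row lies on the segment between a random minority row
    and one of its k nearest minority-class neighbours.
    """
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    # Pairwise Euclidean distances among minority rows
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)                 # a row is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]  # k nearest per row

    synthetic = np.empty((n_new, X_min.shape[1]))
    for i in range(n_new):
        j = rng.integers(len(X_min))                # random minority row
        nn = X_min[rng.choice(neighbours[j])]       # one of its neighbours
        gap = rng.random()                          # position along the segment, in [0, 1)
        synthetic[i] = X_min[j] + gap * (nn - X_min[j])
    return synthetic


# Minority (balanced) rows of the Balance Scale data, rebuilt from its rules
# (assumes all 625 combinations of values 1..5)
rows = np.array(list(product(range(1, 6), repeat=4)))
X_min = rows[rows[:, 2] * rows[:, 3] == rows[:, 0] * rows[:, 1]]

new_rows = smote_like(X_min, n_new=100)
print(new_rows[:3])
```

On this dataset the synthetic rows have fractional feature values, so they need not satisfy the balance rule. That's a reminder that interpolated samples are plausible-looking guesses, not new ground truth.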

*Update: One of our readers, Marco, brought up a great point about the risks of using SMOTE without proper cross-validation. Check out the comments section for more details or read his blog post on the topic.
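Marco's point in practice: split off the test data first, then resample only the training portion, so that no test row (or a copy of one) leaks into training. This sketch uses plain up-sampling with resample() rather than SMOTE, but the same order applies to SMOTE and to every fold of a cross-validation loop. The rebuild assumes all 625 value combinations, and the seeds are arbitrary.

```python
from itertools import product

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Rebuild the binary dataset (assumes all 625 combinations of values 1..5)
df = pd.DataFrame(list(product(range(1, 6), repeat=4)),
                  columns=['var1', 'var2', 'var3', 'var4'])
df['balance'] = (df.var3 * df.var4 == df.var1 * df.var2).astype(int)

# 1. Split BEFORE any resampling
train, test = train_test_split(df, test_size=0.25, stratify=df.balance,
                               random_state=0)

# 2. Up-sample the minority class inside the training split only
train_major = train[train.balance == 0]
train_minor = train[train.balance == 1]
train_minor_up = resample(train_minor, replace=True,
                          n_samples=len(train_major), random_state=123)
train_up = pd.concat([train_major, train_minor_up])

# 3. The test split keeps its original, imbalanced class ratio
print(train_up.balance.value_counts())
print(test.balance.value_counts())
```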

Combine Minority Classes

Combining minority classes of your target variable may be appropriate for some multi-class problems.

For example, let's say you wished to predict credit card fraud. In your dataset, each method of fraud may be labeled separately, but you might not care about distinguishing them. You could combine them all into a single 'Fraud' class and treat the problem as binary classification.
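In pandas, collapsing the fraud types can be a single comparison. The column name label and the category values below are hypothetical.

```python
import pandas as pd

# Hypothetical fraud labels; the column and category names are made up
df = pd.DataFrame({'label': ['legit', 'card_theft', 'legit', 'phishing',
                             'legit', 'account_takeover', 'legit', 'legit']})

# Collapse every fraud type into one 'Fraud' class -> binary problem
df['is_fraud'] = (df['label'] != 'legit').astype(int)

print(df['is_fraud'].value_counts())
```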

Reframe as Anomaly Detection

Anomaly detection, a.k.a. outlier detection, is for detecting outliers and rare events. Instead of building a classification model, you'd have a "profile" of a normal observation. If a new observation strays too far from that "normal profile," it would be flagged as an anomaly.
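As a deliberately simple illustration of a "normal profile", this sketch stores the per-feature mean and standard deviation of made-up normal data. It flags a new observation if any feature lies more than k standard deviations away. The helper names and the threshold k=3 are hypothetical choices. Scikit-Learn's IsolationForest and OneClassSVM are more capable options for real data.

```python
import random
import statistics


def fit_profile(normal_rows):
    """Per-feature (mean, standard deviation) of the 'normal' observations."""
    columns = list(zip(*normal_rows))
    return [(statistics.mean(c), statistics.pstdev(c)) for c in columns]


def is_anomaly(row, profile, k=3.0):
    """Flag a row if any feature is more than k standard deviations from normal."""
    return any(abs(x - mu) > k * sd for x, (mu, sd) in zip(row, profile))


# Made-up 'normal' data: 500 two-feature observations around (0, 0)
rng = random.Random(0)
normal = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(500)]
profile = fit_profile(normal)

print(is_anomaly((0.0, 0.0), profile))   # typical point -> False
print(is_anomaly((10.0, 0.0), profile))  # far from the normal profile -> True
```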

Conclusion & Next Steps

In this guide, we covered 5 tactics for handling imbalanced classes in machine learning:

  1. Up-sample the minority class
  2. Down-sample the majority class
  3. Change your performance metric
  4. Penalize algorithms (cost-sensitive training)
  5. Use tree-based algorithms

These tactics are subject to the No Free Lunch theorem, and you should try several of them and use the results from the test set to decide on the best solution for your problem.

If you enjoyed this guide, we invite you to sign up for our free 7-day crash course on applied ML. We share lessons that are not found on our blog, and we'll also notify you when we publish new tutorials like this one.

2 Comments

  • Marco

    July 6, 2017

    Great article, I understand cross-validation has been kept out for clarity, but I’ve seen horrible papers in the past doing upsampling with SMOTE before cross-validation, then obtaining near perfect accuracy on ‘impossible’ problems just because validation data was contaminated with training data, so maybe still worth a mention. I’ve covered the problem here: http://www.marcoaltini.com/blog/dealing-with-imbalanced-data-undersampling-oversampling-and-proper-cross-validation (data and code also available). Cheers.

    Marco

    • EliteDataScience

      July 6, 2017

      Hi Marco,

      You bring up a great point. Just updated the article with a note after the SMOTE section and a link to your blog post.

      Thanks!