Exploring the Reasons for Unexpected Prediction Distributions in Machine Learning Models

When investigating unexpected model behavior, many data scientists I know start by analyzing distribution drift in the most important predictors. This approach can help, but it carries significant risks: you might draw the wrong conclusions or fail to identify the root cause altogether.

Let’s consider a classic example, the California Housing price-prediction task. Imagine scoring new properties in production using a model trained on this dataset. Normally, we expect predictor distributions to stay relatively consistent over weeks or months. Now, suppose one period shows a clear upward shift in predicted house prices.

This raises two key questions:

  1. What caused the shift?
  2. Can we trust these predictions?

The second question, arguably more critical, boils down to: Is the model's performance still acceptable? There’s a wealth of excellent material addressing this, including resources from NannyML and this wonderful article by Samuele Mazzanti, so I’ll refer you to those rather than summarizing them here.

Instead, I’d like to focus on the first question: understanding what caused the drift.

In our example, the Median Income in Block Group predictor is typically the most important feature in models trained on this dataset.

For simplicity, let’s assume that during previous inference periods with expected predictions, the median income followed a uniform distribution within the range of 0–15. While this assumption may not realistically reflect the dataset, our primary goal here is to explore how the model’s response function behaves across the range of predictor values.

Now, let’s say univariate metrics like the Population Stability Index (PSI) or Wasserstein Distance indicate a shift in the median income distribution. Is this enough to explain the unexpected predictions? Not necessarily.
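
Both metrics can be computed in a few lines. The sketch below uses hypothetical data (uniform reference values standing in for median income) and quantile bins derived from the reference period, which is one common PSI convention:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def psi(reference, current, n_bins=10):
    """Population Stability Index between two 1-D samples,
    using quantile bins derived from the reference period."""
    edges = np.quantile(reference, np.linspace(0, 1, n_bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf  # capture out-of-range values
    ref_frac = np.histogram(reference, bins=edges)[0] / len(reference)
    cur_frac = np.histogram(current, bins=edges)[0] / len(current)
    eps = 1e-6  # avoid log(0) in empty buckets
    ref_frac = np.clip(ref_frac, eps, None)
    cur_frac = np.clip(cur_frac, eps, None)
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))

rng = np.random.default_rng(0)
reference = rng.uniform(0, 15, 10_000)       # stand-in for past MedInc values
current = rng.uniform(0, 15, 10_000) + 1.0   # period with an upward shift

psi_value = psi(reference, current)
wd_value = wasserstein_distance(reference, current)
```

A common rule of thumb treats PSI above roughly 0.1–0.25 as a notable shift, but as argued next, a large (or small) univariate value alone doesn't tell you whether predictions will actually move.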

To see why, we can examine the model’s sensitivity to median income values, expressed, for instance, through Accumulated Local Effects (ALE). Such a plot might reveal zones where the model’s predictions are either highly sensitive or nearly insensitive to changes in median income:

  • Below ~2 and above ~9.25: the model’s response flattens, showing little or no sensitivity to shifts. A drift within these zones wouldn’t explain prediction changes. In this specific case, the plateau at high values is primarily due to the very low density of samples with high Median income values, causing them to fall into a terminal node. Again, our main interest here is in the shape of the sensitivity function, which, in another task, might have a plateau with a solid physical or business explanation behind it.
  • Around ~5.55 to ~5.95: this region may show high sensitivity, where even small shifts, undetected by, for example, PSI if they occur within one bucket, could drastically impact predictions.
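
To make the sensitivity analysis concrete, here is a minimal first-order ALE sketch. The model is a hypothetical stand-in that plateaus below 2 and above 9.25, mimicking the zones described above; in practice you would use a library implementation (for example `alibi`'s ALE explainer) on your real model:

```python
import numpy as np

def ale_1d(model_predict, X, feature, n_bins=20):
    """First-order Accumulated Local Effects for one feature.

    Within each quantile bin, the local effect is the average change in
    prediction when the feature moves from the bin's lower to upper edge,
    holding all other features at their observed values."""
    x = X[:, feature]
    edges = np.unique(np.quantile(x, np.linspace(0, 1, n_bins + 1)))
    effects = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (x >= lo) & (x <= hi)
        if not mask.any():
            effects.append(0.0)
            continue
        X_lo, X_hi = X[mask].copy(), X[mask].copy()
        X_lo[:, feature], X_hi[:, feature] = lo, hi
        effects.append(np.mean(model_predict(X_hi) - model_predict(X_lo)))
    ale = np.concatenate([[0.0], np.cumsum(effects)])
    return edges, ale - ale.mean()  # center the curve, as is conventional

# Toy response flat below 2 and above 9.25 -- a stand-in for the plateaus
# discussed above, not the actual California Housing model.
def toy_predict(X):
    return np.clip(X[:, 0], 2.0, 9.25)

rng = np.random.default_rng(1)
X = rng.uniform(0, 15, size=(5_000, 3))
grid, ale = ale_1d(toy_predict, X, feature=0)
```

For this toy model, the resulting curve is flat in the insensitive zones and rises with slope one in between, which is exactly the shape that tells you where a drift would (or would not) move predictions.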

Additionally, univariate drift analysis doesn’t account for interactions between predictors, which your model likely considers. Nor does it help detect cumulative effects from small changes across multiple key predictors.

A Better Starting Point

I recommend starting your investigation by comparing the distribution of predictor contributions between the reference and drifted periods. The most straightforward option is to use SHAP values, but other tools can work too. For instance, LightGBM allows you to retrieve a matrix of feature contributions without external libraries by calling predict (or predict_proba for classifiers) with pred_contrib=True.

The image below shows the SHAP Summary plots for the reference and examined periods. The examined period's data was generated from the reference data using the following modifications:

  • For Median Income, all values above 9.5 were replaced with random values from a uniform distribution between 14 and 15.
  • Random values from a normal distribution with a mean of -2 and a standard deviation of 1 were added to both Latitude and Longitude.
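
These two modifications can be sketched as follows. The column indices are assumptions based on the scikit-learn California Housing feature order (MedInc first, Latitude and Longitude last), and the reference matrix here is a toy stand-in:

```python
import numpy as np

def make_examined_period(X_ref, rng, medinc_col=0, lat_col=6, lon_col=7):
    """Apply the synthetic drift described above to a copy of the data."""
    X = X_ref.copy()
    # MedInc: values above 9.5 are resampled uniformly from [14, 15]
    high = X[:, medinc_col] > 9.5
    X[high, medinc_col] = rng.uniform(14, 15, size=high.sum())
    # Latitude / Longitude: additive N(-2, 1) noise
    X[:, lat_col] += rng.normal(-2, 1, size=len(X))
    X[:, lon_col] += rng.normal(-2, 1, size=len(X))
    return X

rng = np.random.default_rng(0)
X_ref = rng.uniform(0, 15, size=(1_000, 8))  # toy stand-in for the dataset
X_drift = make_examined_period(X_ref, rng)
```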

Comparing the distributions of feature contributions provides a correct interpretation of the reasons behind the shift in prediction distribution:

  • Despite the shift in its distribution, Median Income does not affect the prediction distribution.
  • The primary causes of the shift are Latitude (which, for an equal shift in distribution, is the more important feature for the model) and, to a somewhat lesser extent, Longitude.

If feature interactions play a significant role in driving the issue, uncovering an explanation may still require analyzing multiple plots that illustrate how these interactions impact predictions, such as Partial Dependence Plots or Pairwise Feature Interaction Plots. However, comparing the distribution of predictor impacts in the reference and examined periods will provide a clear direction and rationale for your investigation.
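
One simple way to operationalize that comparison is to rank features by how far their contribution distributions moved between the two periods, e.g. by one-dimensional Wasserstein distance. The `contribution_shift` helper below is a hypothetical sketch, demonstrated on synthetic contribution matrices in which only "Latitude" has shifted:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def contribution_shift(contrib_ref, contrib_cur, feature_names):
    """Rank features by how far their contribution distribution moved
    between the reference and examined periods (1-D Wasserstein)."""
    shifts = {
        name: wasserstein_distance(contrib_ref[:, j], contrib_cur[:, j])
        for j, name in enumerate(feature_names)
    }
    return sorted(shifts.items(), key=lambda kv: kv[1], reverse=True)

# Synthetic contribution matrices: pretend Latitude's contributions
# shifted upward by 2 while the other features stayed put.
rng = np.random.default_rng(3)
ref = rng.normal(0, 1, size=(5_000, 3))
cur = ref.copy()
cur[:, 1] += 2.0

ranking = contribution_shift(ref, cur, ["MedInc", "Latitude", "Longitude"])
```

The top of the ranking points the investigation at the features whose impact on predictions actually changed, rather than the features whose raw values drifted most.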

Francis Gichere

Lead Data Scientist @ BURN | Decision Science | Applied AI & ML | 6yrs experience in Data

4 months ago

Very informative

Denis Sidorenko

Data Scientist | ML Engineer | MScIT

4 months ago

Thanks for the revealing ideas, Ilia! I had a question while I was reading. Please clarify the given example regarding SHAP values and the previous explanation around shifts in median income. There are shifts in Lat and Long in SHAP plots, but almost no shifts around MedInc are displayed. Why would you recommend using this technique to determine shifts in MedInc?

More articles by Ilia Ekhlakov