Exploring the Reasons for Unexpected Prediction Distributions in Machine Learning Models
Ilia Ekhlakov
Senior Data Scientist @ Wrike | B2B SaaS | Revenue Strategy & Ops | MSc in Physics | 9 YoE
When investigating unexpected model behavior, many data scientists I know start by analyzing distribution drift in the most important predictors. This approach can help, but it carries significant risks: you might draw the wrong conclusions or miss the root cause altogether.
Let’s consider a classic example: the California Housing price prediction task. Imagine scoring new properties in production using a model trained on this dataset. Normally, we expect predictor distributions to stay relatively stable over weeks or months. Now, suppose one period shows a clear upward shift in predicted house prices.
This raises two key questions: what caused the shift in predictions, and can the model still be trusted on the new data?
The second question, arguably the more critical one, boils down to: is the model's performance still acceptable? There’s a wealth of excellent material on this, including resources from NannyML and this wonderful article by Samuele Mazzanti, so I’ll refer you to those rather than summarizing them here.
Instead, I’d like to focus on the first question: understanding what caused the drift.
In our example, median income in the block group (MedInc) is typically the most important feature in models trained on this dataset.
For simplicity, let’s assume that during previous inference periods with expected predictions, the median income followed a uniform distribution within the range of 0–15. While this assumption may not realistically reflect the dataset, our primary goal here is to explore how the model’s response function behaves across the range of predictor values.
Now, let’s say univariate metrics like the Population Stability Index (PSI) or Wasserstein Distance indicate a shift in the median income distribution. Is this enough to explain the unexpected predictions? Not necessarily.
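For reference, both univariate metrics are short to compute. The sketch below assumes NumPy and SciPy; the `psi` helper, the decile binning, the small clipping constant for empty bins, and the hypothetical +2 shift of median income are illustrative choices, not a standard implementation.

```python
import numpy as np
from scipy.stats import wasserstein_distance


def psi(reference, current, n_bins=10, eps=1e-6):
    """Population Stability Index with bins set by reference quantiles (illustrative helper)."""
    # Interior cut points at the reference quantiles (deciles for n_bins=10)
    cuts = np.quantile(reference, np.linspace(0, 1, n_bins + 1)[1:-1])
    # Bin index 0..n_bins-1; values outside the reference range fall into the end bins
    ref_idx = np.searchsorted(cuts, reference, side="right")
    cur_idx = np.searchsorted(cuts, current, side="right")
    ref_frac = np.bincount(ref_idx, minlength=n_bins) / len(reference)
    cur_frac = np.bincount(cur_idx, minlength=n_bins) / len(current)
    # Clip empty bins so the log term stays finite
    ref_frac = np.clip(ref_frac, eps, None)
    cur_frac = np.clip(cur_frac, eps, None)
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))


rng = np.random.default_rng(0)
# Reference period: median income ~ Uniform(0, 15), as assumed in the text
med_inc_ref = rng.uniform(0, 15, size=10_000)
# Hypothetical drifted period: every income shifted up by 2 units
med_inc_cur = med_inc_ref + 2.0

psi_value = psi(med_inc_ref, med_inc_cur)
w_value = wasserstein_distance(med_inc_ref, med_inc_cur)
print(f"PSI = {psi_value:.3f}, Wasserstein = {w_value:.3f}")
```

Both numbers only describe the input distribution; neither says anything about how the model responds to the shift.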
To see why, we can examine the model’s sensitivity to median income values, expressed, for instance, through Accumulated Local Effects (ALE). Such a plot might reveal zones where the model’s predictions are highly sensitive to changes in median income and zones where they are nearly insensitive. A drift confined to an insensitive zone would register in PSI or Wasserstein distance while barely moving the predictions.
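If you want to reproduce this kind of sensitivity analysis without an extra library, first-order ALE is short to write by hand. The sketch below is a simplified version (quantile bins, centering by count-weighted interval midpoints); `ale_1d` and `toy_predict` are hypothetical names, and `toy_predict` stands in for a trained model's `predict`, built with a flat zone above 8 to mimic an insensitive region.

```python
import numpy as np


def ale_1d(predict, X, feature, n_bins=20):
    """First-order Accumulated Local Effects for one numeric feature.

    Returns the bin edges and the centered ALE value at each edge.
    """
    x = X[:, feature]
    # Quantile-based bin edges; np.unique guards against duplicate edges
    edges = np.unique(np.quantile(x, np.linspace(0, 1, n_bins + 1)))
    n_int = len(edges) - 1
    # Interval index per row: (edges[k], edges[k+1]] -> k; the minimum goes to 0
    idx = np.clip(np.searchsorted(edges, x, side="left") - 1, 0, n_int - 1)

    # Local effect: move each row to the lower and upper edge of its own interval
    X_lo, X_hi = X.copy(), X.copy()
    X_lo[:, feature] = edges[idx]
    X_hi[:, feature] = edges[idx + 1]
    local = predict(X_hi) - predict(X_lo)

    # Average local effects per interval, then accumulate across intervals
    counts = np.bincount(idx, minlength=n_int)
    sums = np.bincount(idx, weights=local, minlength=n_int)
    mean_local = np.divide(sums, counts, out=np.zeros(n_int), where=counts > 0)
    ale = np.concatenate([[0.0], np.cumsum(mean_local)])

    # Center so the count-weighted average effect is zero
    midpoints = (ale[:-1] + ale[1:]) / 2
    ale -= np.sum(midpoints * counts) / counts.sum()
    return edges, ale


# Hypothetical stand-in for a trained model: sensitive to feature 0 below 8, flat above
def toy_predict(X):
    return 0.5 * np.minimum(X[:, 0], 8.0) + 0.1 * X[:, 1]


rng = np.random.default_rng(0)
X = np.column_stack([rng.uniform(0, 15, 5_000), rng.uniform(0, 1, 5_000)])
edges, ale = ale_1d(toy_predict, X, feature=0)
```

In this toy setup the ALE curve rises up to 8 and is flat afterwards, so a drift of feature 0 between, say, 10 and 14 would change PSI but not the predictions.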
Additionally, univariate drift analysis doesn’t account for interactions between predictors, which your model likely considers. Nor does it help detect cumulative effects from small changes across multiple key predictors.
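Both blind spots are easy to reproduce on synthetic data. The sketch below uses two hypothetical toy models rather than a real housing model: in the first case each marginal distribution stays identical (Wasserstein distance of zero) yet the mean prediction moves because the features are re-paired; in the second, each of ten features moves by only a tenth of a standard deviation, but an additive model sums those shifts.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
n = 10_000

# Case 1: interaction. Hypothetical model: prediction = x1 * x2.
x1 = rng.uniform(0, 1, n)
x2_ref = rng.uniform(0, 1, n)
# Drifted period: the same x2 values, re-paired so they rise together with x1.
# The marginal of x2 is unchanged (same set of values); only the joint moves.
x2_cur = np.sort(x2_ref)[np.argsort(np.argsort(x1))]

marginal_w = wasserstein_distance(x2_ref, x2_cur)  # 0: identical value sets
pred_ref = (x1 * x2_ref).mean()  # close to 1/4 for independent uniforms
pred_cur = (x1 * x2_cur).mean()  # close to 1/3 once x1 and x2 move together

# Case 2: cumulative drift. Hypothetical additive model over 10 features,
# each shifted by only 0.1 (a tenth of its standard deviation).
X_ref = rng.normal(0, 1, size=(n, 10))
X_cur = X_ref + 0.1
per_feature_w = [wasserstein_distance(X_ref[:, j], X_cur[:, j]) for j in range(10)]
pred_shift = X_cur.sum(axis=1).mean() - X_ref.sum(axis=1).mean()  # 10 * 0.1

print(f"Case 1: marginal W = {marginal_w}, mean prediction {pred_ref:.3f} -> {pred_cur:.3f}")
print(f"Case 2: per-feature W = {per_feature_w[0]:.2f}, prediction shift = {pred_shift:.2f}")
```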
A Better Starting Point
I recommend starting your investigation by comparing the distributions of predictor contributions between the reference and drifted periods. The most straightforward option is to use SHAP values, but other tools work too. For instance, LightGBM can return a matrix or tensor of feature contributions without external libraries: call the predict or predict_proba method with pred_contrib=True.
The image below shows the SHAP Summary plots for the reference and examined periods. The examined period's data was generated from the reference data using the following modifications:
Comparing the distributions of feature contributions points to the actual reasons behind the shift in the prediction distribution.
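One way to make this comparison quantitative, sketched under the assumption that contributions for both periods come from the same model in LightGBM's `pred_contrib` layout: rank features by how far their contribution distributions moved. The `compare_contributions` helper and the synthetic arrays are hypothetical. Because contributions are additive per row and the bias column is shared across periods, the per-feature mean shifts add up to the shift in the mean prediction.

```python
import numpy as np
from scipy.stats import wasserstein_distance


def compare_contributions(contrib_ref, contrib_cur, feature_names):
    """Rank features by how far their contribution distribution moved.

    Inputs are (n_rows, n_features + 1) arrays: one column per feature,
    last column = expected value (bias), as returned with pred_contrib=True.
    """
    report = []
    for j, name in enumerate(feature_names):
        ref_j, cur_j = contrib_ref[:, j], contrib_cur[:, j]
        report.append({
            "feature": name,
            # Change in the average contribution, in prediction units
            "mean_shift": float(cur_j.mean() - ref_j.mean()),
            # Change in the whole shape of the contribution distribution
            "wasserstein": float(wasserstein_distance(ref_j, cur_j)),
        })
    report.sort(key=lambda r: r["wasserstein"], reverse=True)
    return report


# Synthetic contributions standing in for two scoring periods (hypothetical)
rng = np.random.default_rng(0)
names = ["MedInc", "HouseAge", "Latitude", "Longitude"]
contrib_ref = np.column_stack([rng.normal(0, 1, (1_000, 4)), np.full(1_000, 2.0)])
contrib_cur = contrib_ref.copy()
contrib_cur[:, 1] += 0.5  # pretend only one feature's contributions drifted

report = compare_contributions(contrib_ref, contrib_cur, names)
pred_shift = contrib_cur.sum(axis=1).mean() - contrib_ref.sum(axis=1).mean()
total = sum(r["mean_shift"] for r in report)  # equals pred_shift (additivity)
for r in report:
    print(r)
```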
If feature interactions play a significant role in driving the issue, uncovering an explanation may still require analyzing multiple plots that illustrate how these interactions impact predictions, such as Partial Dependence Plots or Pairwise Feature Interaction Plots. However, comparing the distribution of predictor impacts in the reference and examined periods will provide a clear direction and rationale for your investigation.
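Interaction strength can also be summarized as a single number before plotting anything. The sketch below uses Friedman's H-statistic, a technique not mentioned above, which is built from the same partial dependence functions that PDPs display. It is a brute-force version (n² predictions per call, so subsample large datasets); `friedman_h2`, `additive`, and `interacting` are hypothetical names for illustration.

```python
import numpy as np


def _pd_at_data(predict, X, cols):
    """Partial dependence on `cols`, evaluated at each row's own values of `cols`."""
    n = len(X)
    # Block i holds every background row with `cols` overwritten by row i's values
    grid = np.tile(X, (n, 1))
    grid[:, cols] = np.repeat(X[:, cols], n, axis=0)
    return predict(grid).reshape(n, n).mean(axis=1)


def friedman_h2(predict, X, j, k):
    """Friedman's H^2 for the pair (j, k): share of the two-way partial
    dependence not explained by the sum of the one-way partial dependences."""
    pd_jk = _pd_at_data(predict, X, [j, k])
    pd_j = _pd_at_data(predict, X, [j])
    pd_k = _pd_at_data(predict, X, [k])
    # Center each partial dependence function before comparing
    pd_jk, pd_j, pd_k = (a - a.mean() for a in (pd_jk, pd_j, pd_k))
    return float(np.sum((pd_jk - pd_j - pd_k) ** 2) / np.sum(pd_jk ** 2))


rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(200, 3))  # 200^2 = 40,000 predictions per call


def additive(X):  # hypothetical model without interactions
    return X[:, 0] + X[:, 1] ** 2 + X[:, 2]


def interacting(X):  # hypothetical model with a feature 0 x feature 1 interaction
    return X[:, 0] * X[:, 1] + X[:, 2]


h2_add = friedman_h2(additive, X, 0, 1)
h2_int = friedman_h2(interacting, X, 0, 1)
print(f"H^2 additive = {h2_add:.2e}, H^2 interacting = {h2_int:.3f}")
```

Pairs with a high H² are natural candidates for a two-way partial dependence or interaction plot.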
Lead Data Scientist @ BURN | Decision Science | Applied AI & ML | 6yrs experience in Data
4 months ago: Very informative
Data Scientist | ML Engineer | MScIT
4 months ago: Thanks for the revealing ideas, Ilia! I had a question while reading. Could you clarify the example with SHAP values in light of the earlier explanation about shifts in median income? The SHAP plots show shifts in Lat and Long, but almost no shift in MedInc. Why would you recommend this technique for determining shifts in MedInc?