How to use Explainable Machine Learning with Python
 Interpret with Permutation Feature Importance, PDP, SHAP, LIME

Lianne & Justin


In this guide for beginners, you’ll learn about explainable machine learning and use it to interpret models in Python.

After training a machine learning model, it is critical to interpret its predictions. Doing so helps us improve the model’s performance and sell the model to others. By following this tutorial, you’ll learn:

  • What is explainable machine learning?
  • How to apply popular explainable ML methods in Python:
    • Permutation Feature Importance
    • Partial Dependence Plots (PDP)
    • SHapley Additive exPlanations (SHAP)
    • Local Interpretable Model-agnostic Explanations (LIME)
  • Plus some tips on using these methods!

We’ll fit an XGBoost model on a real-world dataset as an example throughout the guide.

If you want to interpret your complicated machine learning models, this quick and practical tutorial with examples will get you started!

Let’s jump in!



Intro to Explainable Machine Learning

Background

Before diving into explainable machine learning, let’s review some background knowledge first.

This tutorial focuses on supervised machine learning models, where machines learn from existing labels to make predictions. If you need a quick recap of ML, check out Machine Learning for Beginners: Overview of Algorithm Types. Based on explainability, we can divide supervised machine learning models into two main categories:

  • glass box: interpretable by domain experts
  • black box: hard to interpret by humans

The ‘glass box’ models include decision trees, linear regression, logistic regression, etc. The simplicity of these models makes them possible to understand. For example, if we build a decision tree model, we can clearly see which features lead to the final prediction. But on the other hand, these ‘glass box’ models might not provide ideal prediction results.

Then there are the ‘black box’ models, including random forests, gradient boosting models, neural networks, etc. They often win competitions since they tend to provide the best predictions. But due to their complicated structures, it is very hard for humans to understand how they arrive at a prediction.

These ‘black box’ models are the focus of this tutorial. Their lack of explainability is a significant challenge.

Imagine working in the marketing team, and you’ve built a ‘black-box’ model to predict customer churn. What else do you need to consider?

  • you would like to interpret the model to ensure it’s working correctly. This also helps to debug and improve the model performance
  • before deploying the model into production, you need to explain the model to others, including the management, for approval
  • also, understanding the features contributing to each customer’s churn provides valuable insights for the business

All of the above scenarios require model explainability. To trust ‘black box’ models, we need methods that explain how they make decisions. This is why we need explainable machine learning.

What is Explainable Machine Learning?

Explainable machine learning (or Explainable Artificial Intelligence) includes methods that extract and interpret information from ‘black box’ models in a way humans can understand, so that we can explain how the model makes predictions.

With the help of these explainable ML methods, we can answer questions such as:

  • what are the critical features for the predictions?
  • how much did this particular feature contribute to the prediction?
  • why was this specific instance classified as positive?

Explainable machine learning methods fall into two main categories:

  • Summary-based: explain the average behavior of the model
  • Instance-based: explain an individual instance’s prediction

In general, when we try to understand the overall picture or debug the model, the summary-based methods are more appropriate. In contrast, instance-based methods help us focus on one prediction at a time.

So in this tutorial, we’ll cover four explainable machine learning methods of both types:

  • Permutation Feature Importance and PDP: summary-based
  • SHAP: both summary and instance-based
  • LIME: instance-based

Also, all of the above are so-called model-agnostic methods. This means they are generally independent of machine learning models, so we can apply them to explain various types of models.

Before jumping into these methods, we must fit a model to explain. So in the next section, let’s look at our example dataset and build a model.

Example dataset and model

We’ll explain an XGBoost model that predicts students’ scores throughout the tutorial. We’ll use a dataset of student performance from UCI. Since our goal is only to explain the model, I’ve only done ‘casual’ modeling. Please use the below Python code to transform the original dataset.

In the end, you should have df['score'] being the target and df_features being the 15 features of students.

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 395 entries, 0 to 394
Data columns (total 15 columns):
 #   Column                  Non-Null Count  Dtype
---  ------                  --------------  -----
 0   travel_time_school      395 non-null    int64
 1   study_time              395 non-null    int64
 2   past_failures           395 non-null    int64
 3   family_relationships    395 non-null    int64
 4   free_time               395 non-null    int64
 5   number_absences         395 non-null    int64
 6   sex_M                   395 non-null    uint8
 7   family_size_LE3         395 non-null    uint8
 8   parent_cohab_status_T   395 non-null    uint8
 9   extra_support_yes       395 non-null    uint8
 10  extra_paid_classes_yes  395 non-null    uint8
 11  extra_curricular_yes    395 non-null    uint8
 12  nursery_school_yes      395 non-null    uint8
 13  want_higher_edu_yes     395 non-null    uint8
 14  internet_access_yes     395 non-null    uint8
dtypes: int64(6), uint8(9)
memory usage: 22.1 KB
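
The transformation code is not reproduced here, so the following is only a minimal sketch of what such a step might look like, run on a tiny made-up sample. The raw column names (sex, famsize, higher, studytime, failures, absences, G3) exist in the UCI file, but the mapping to this guide’s names and the choice of G3 as score are assumptions.

```python
import pandas as pd

# Tiny made-up stand-in for the raw UCI file (the real data has 395 rows).
raw = pd.DataFrame({
    "sex": ["F", "M", "F"],
    "famsize": ["GT3", "LE3", "GT3"],
    "higher": ["yes", "yes", "no"],
    "studytime": [2, 1, 3],
    "failures": [0, 1, 0],
    "absences": [6, 4, 10],
    "G3": [6, 10, 15],
})

# Assumed mapping from raw column names to the names used in this guide.
renames = {
    "famsize": "family_size",
    "higher": "want_higher_edu",
    "studytime": "study_time",
    "failures": "past_failures",
    "absences": "number_absences",
    "G3": "score",  # assumption: the final grade is the target
}
df = raw.rename(columns=renames)

# Dummy-encode the text columns, dropping the first level of each,
# which yields names such as sex_M and want_higher_edu_yes.
df_features = pd.get_dummies(
    df.drop(columns="score"), drop_first=True, dtype="uint8"
)
print(df_features.columns.tolist())
```

Dropping the first level of each categorical column is what produces a single 0/1 column per binary feature, matching the uint8 columns in the listing above.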

Then, we’ll build an XGBoost model to predict score based on df_features. We chose XGBoost for the following reasons:

  • popularity: it is often the winner of competitions since the model provides excellent prediction results
  • complexity: it is an ensemble of decision trees, which makes the model a ‘black box’

So XGBoost is a great example to learn explainable machine learning.

Let’s use the XGBRegressor. Again, our goal is to fit a model that we can explain, not necessarily a good one. So we omit steps like the train-test split and hyperparameter tuning.

Now we have the model called model. We’re ready to explain it.
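
As a rough sketch of this step (not the exact code behind this guide), fitting the regressor could look like the block below. The synthetic data, the three feature columns, and the hyperparameters are assumptions made for illustration. The block falls back to scikit-learn’s GradientBoostingRegressor if xgboost is not installed, so it still runs.

```python
import numpy as np
import pandas as pd

try:
    from xgboost import XGBRegressor as Regressor
except ImportError:
    # Stand-in with the same fit/predict interface, used only if xgboost is missing.
    from sklearn.ensemble import GradientBoostingRegressor as Regressor

# Synthetic stand-in for df_features and df['score'] (assumed, not the UCI data).
rng = np.random.default_rng(0)
n = 200
df_features = pd.DataFrame({
    "study_time": rng.integers(1, 5, n),
    "past_failures": rng.integers(0, 4, n),
    "number_absences": rng.integers(0, 30, n),
})
score = (
    10
    + df_features["study_time"]
    - 2 * df_features["past_failures"]
    + rng.normal(0, 1, n)
)

# Fit on all rows: no train-test split or tuning, as in the guide.
model = Regressor(n_estimators=100, random_state=0)
model.fit(df_features, score)
print(round(model.score(df_features, score), 3))  # R² on the training data
```

Because the score is computed on the training data, it says little about how well the model generalizes. That is acceptable here only because the goal is to have something to explain.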


Explainable ML method #1: Permutation Feature Importance

The permutation feature importance method summarizes how important each feature is to a particular model. It measures a feature’s importance by calculating how much a model score changes after that feature is permuted. Here are the basic steps:

  1. based on the original dataset, calculate a score for the model, such as R² or accuracy
  2. for each feature or column in the dataset:
    • randomly shuffle/permute its value. This breaks the relationship between the feature and the target
    • calculate the new score based on the permuted sample
    • you may repeat this process
  3. compare the scores on the original dataset and the permuted datasets: the larger the change for a feature, the more the model relies on it for prediction, and the more important that feature is to the model
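
The steps above can be sketched from scratch in a few lines of NumPy. This is only an illustration: the predict function is a toy stand-in for model.predict, and the data, the r2 helper, and the function name permutation_importance_sketch are all made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
# The target depends strongly on column 0, weakly on column 1, not at all on column 2.
y = 3.0 * X[:, 0] + 1.0 * X[:, 1] + rng.normal(scale=0.1, size=500)

def predict(X):
    """Toy stand-in for model.predict."""
    return 3.0 * X[:, 0] + 1.0 * X[:, 1]

def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def permutation_importance_sketch(predict, X, y, n_repeats=5):
    baseline = r2(y, predict(X))                      # step 1: original score
    importances = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):                       # step 2: one feature at a time
        for r in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # shuffle one column
            importances[j, r] = baseline - r2(y, predict(X_perm))
    return importances                                # step 3: drop in score

imp = permutation_importance_sketch(predict, X, y)
means = imp.mean(axis=1)
print(means)
```

Column 2 is never used by the toy model, so shuffling it leaves the predictions unchanged and its importance is zero.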

Let’s apply the permutation feature importance method to our XGBoost model. We use the permutation_importance function from sklearn. Besides the basic settings, we also include in the arguments:

  • scoring: sets the score function(s) for importance measurement. We are feeding two scorers stored in the variable scoring, which will produce two sets of importances
    • the feature importances based on different scorers could be different
  • n_repeats: sets the number of times to permute a feature. The feature is usually shuffled multiple times
{'r2': {'importances_mean': array([0.17602671, 0.2593226 , 0.45623922, 0.15198917, 0.31820088,
         0.70361245, 0.11848649, 0.09711127, 0.0450286 , 0.07920763,
         0.07026565, 0.10263182, 0.10923853, 0.06280063, 0.02629473]),
  'importances_std': array([0.019029  , 0.03318512, 0.03342523, 0.01361083, 0.02545227,
         0.06263065, 0.01626824, 0.01737502, 0.01477137, 0.00668625,
         0.00539641, 0.01185956, 0.01408107, 0.00618802, 0.0019774 ]),
  'importances': array([[0.14546738, 0.19518244, 0.18289745, 0.1632562 , 0.19333007],
         [0.20183863, 0.24620682, 0.28595884, 0.26800123, 0.29460751],
         [0.46917819, 0.48623832, 0.47561145, 0.39190052, 0.4582676 ],
         [0.16847948, 0.16397574, 0.13456689, 0.13791707, 0.15500664],
         [0.28528702, 0.29134219, 0.32917372, 0.33474997, 0.3504515 ],
         [0.71663196, 0.8176794 , 0.63674347, 0.66491067, 0.68209674],
         [0.13105339, 0.11758379, 0.11118837, 0.13977409, 0.0928328 ],
         [0.09562589, 0.06970246, 0.08979126, 0.12005011, 0.11038661],
         [0.04299708, 0.02512653, 0.04538924, 0.07094652, 0.04068361],
         [0.08924325, 0.08122129, 0.07265676, 0.08196891, 0.07094794],
         [0.07915642, 0.06867696, 0.07251366, 0.06807712, 0.06290406],
         [0.11102472, 0.11071997, 0.08007307, 0.11014394, 0.10119742],
         [0.12114419, 0.09811054, 0.10505943, 0.12968165, 0.09219684],
         [0.07129185, 0.0631908 , 0.05744222, 0.05454735, 0.06753092],
         [0.02623967, 0.02693123, 0.02334381, 0.02943233, 0.02552661]])},
 'neg_mean_squared_error': ......}

As you can see, perm_importance returns results for each scorer (r2 and neg_mean_squared_error; only the r2 results are printed above because the output is long):

  • importances_mean: means of importances over n_repeats of the 15 features
  • importances_std: standard deviations over n_repeats
  • importances: raw importances
    • the importance in our example is defined as the decrease in a model score when a single feature’s values are permuted

Among those above, we’ll compare the importances_mean of each feature. We can extract its data for the scoring metric r2 and make a bar chart showing the features ranked by their importance.
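
As a small sketch of the ranking step, the block below reuses the r2 importances_mean values printed above, paired with the feature names in the order of the df_features listing. Hard-coding the numbers is only for illustration. With the real objects you would presumably read them from perm_importance and df_features.columns.

```python
import numpy as np
import pandas as pd

# Feature order taken from the df_features listing earlier in this guide.
feature_names = [
    "travel_time_school", "study_time", "past_failures",
    "family_relationships", "free_time", "number_absences",
    "sex_M", "family_size_LE3", "parent_cohab_status_T",
    "extra_support_yes", "extra_paid_classes_yes", "extra_curricular_yes",
    "nursery_school_yes", "want_higher_edu_yes", "internet_access_yes",
]

# r2 importances_mean copied from the printed output above.
importances_mean = np.array([
    0.17602671, 0.2593226, 0.45623922, 0.15198917, 0.31820088,
    0.70361245, 0.11848649, 0.09711127, 0.0450286, 0.07920763,
    0.07026565, 0.10263182, 0.10923853, 0.06280063, 0.02629473,
])

# Pair names with values and sort from most to least important.
ranked = pd.Series(importances_mean, index=feature_names).sort_values(ascending=False)
print(ranked.head(3))
```

If matplotlib is installed, ranked.sort_values().plot.barh() draws the horizontal bar chart with the most important feature at the top.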

This is a nice summary of feature importance. You can see which features are more important for the model.

Figure: Permutation Feature Importance chart

Use with caution: when two features are correlated, both can appear less important than they really are. When one of them is permuted, the model can still get much of the same information from the other, so the permuted feature looks less important, and vice versa.

Explainable ML method #2: Partial Dependence Plots (PDP)

Partial Dependence Plots (PDP) visualize the effect of one or two features of interest on the prediction results while marginalizing over the other features. More specifically, partial dependence is the expected target response as a function of those one or two features of interest. We then plot this relationship to see how the features affect the predictions. When the PDP is flat, the feature is less important to the model.

PDP is a summary-based explainable machine learning method because it includes all instances and summarizes the relationship of the features of interest with the model’s predictions.
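
To make the idea of marginalizing concrete, here is a minimal from-scratch sketch of one-feature partial dependence. The predict function is a toy stand-in for model.predict, and the function name partial_dependence_1d and the data are made up for this example.

```python
import numpy as np

def predict(X):
    """Toy stand-in for model.predict: column 0 acts linearly, column 1 quadratically."""
    return 2.0 * X[:, 0] + X[:, 1] ** 2

def partial_dependence_1d(predict, X, feature, grid):
    """For each grid value, set the feature to that value in every row
    and average the predictions over all rows."""
    pd_values = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value
        pd_values.append(predict(X_mod).mean())
    return np.array(pd_values)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
grid = np.array([-1.0, 0.0, 1.0])

pd0 = partial_dependence_1d(predict, X, feature=0, grid=grid)
# Relative to the middle grid point, the curve rises by 2 per unit of column 0.
print(pd0 - pd0[1])
```

Every row of the dataset contributes to each point on the curve, which is why PDP counts as a summary-based method.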

You might wonder, why don’t we look at more features at a time?

This is due to our constraints: we, as humans, can’t process higher dimensional plots. When we have one feature of interest, we have a 2D plot. When we have two features of interest, we have a 3D plot. And that’s the limit of what we can interpret.

We’ll use the partial_dependence function from shap. We’ll cover the shap package in more detail in the SHAP section, but it also includes functions to plot PDPs. We’ll test out a couple of examples, considering one feature at a time.

First, let’s look at the PDP for the feature want_higher_edu_yes.

want_higher_edu_yes is a dummy variable with values of 0 and 1. The grey bars on the plot indicate its data distribution. Looking at the blue line, we can see the expected score is higher when this feature has a value of 1 versus 0. This makes sense, since students who want higher education may tend to have higher marks.

Figure: PDP for want_higher_edu_yes

Next, let’s plot PDP for the feature past_failures.

We can see that, controlling for all other features, the more past failures a student has, the lower the expected score. This also makes sense.

Figure: PDP for past_failures

The previous two features both show a clear monotonic relationship with the score. Let’s look at a more complex one, number_absences.

The plot shows that controlling for other features, the expected scores are higher when there are a couple of absences. And there’s a local peak when the absences are over 20, although that could be due to a lack of data.

Figure: PDP for number_absences

While all the above features show different relationships with the target, they are all important for the model since the PDP varied instead of being flat.

Use with caution: first, this method assumes the features of interest are independent of the rest of the features, which is often not true. Also, PDP only shows the average marginal effects, which could hide a heterogeneous relationship caused by interactions. In that case, it is better to look at individual instances through the Individual Conditional Expectation (ICE) plot.
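
A toy case shows how averaging can hide an interaction. In this sketch the predict function is a made-up stand-in for a model in which the effect of column 0 flips sign depending on a binary column 1. The ICE curves (one per row) slope in opposite directions, while their average, the PDP, is flat.

```python
import numpy as np

def predict(X):
    """Toy stand-in for model.predict: column 0 helps when column 1 is 1 and hurts when it is 0."""
    return np.where(X[:, 1] == 1, X[:, 0], -X[:, 0])

# Four made-up rows, balanced between the two groups of column 1.
X = np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 1.0]])
grid = np.array([-1.0, 0.0, 1.0])

# ICE: one curve per row, obtained by varying column 0 for that row only.
ice = np.empty((len(X), len(grid)))
for k, value in enumerate(grid):
    X_mod = X.copy()
    X_mod[:, 0] = value
    ice[:, k] = predict(X_mod)

# PDP: the average of the ICE curves.
pdp = ice.mean(axis=0)
print(ice)
print(pdp)
```

Judging by the flat PDP alone, column 0 would look unimportant, even though it changes every single prediction.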

Explainable ML method #3: SHapley Additive exPlanations (SHAP)

SHapley Additive exPlanations (SHAP) is a practical method based on Shapley values. So let’s start with the question: what are Shapley values?

The Shapley value method is an instance-based explanation method derived from coalitional game theory.

Imagine a game of predicting a student’s score. Each feature value of that student ‘plays’ the game by contributing to the model. How do we know the contribution of each feature to the prediction result? We can calculate the Shapley values to distribute the ‘payout’ among the feature values fairly.

To be more specific, let’s consider an example.

Assume the model predicts a score of 12 for this particular student, while the average predicted score of all students is 10. Why is there a difference of 2 between the average (10) and this particular prediction (12)? Shapley values help us quantify the contribution of each of this student’s feature values. For example, perhaps study_time contributed +3, past_failures contributed -4, and so on. All the feature contributions (Shapley values) add up to the difference of 2. The larger the absolute value of a Shapley value, the more that feature value influenced the prediction.

We won’t cover the details of the process. In short, the exact Shapley value is calculated as the average marginal contribution of one feature value across all possible coalitions of feature values. It is the only explanation method in this tutorial with a solid theoretical foundation. You can read more about it here.
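
For a handful of features, the exact calculation can be sketched by brute force. The block below is only an illustration. The predict function is a toy stand-in for model.predict, and the value of a coalition is taken as the average prediction when the coalition’s features are fixed to the student’s values while the others come from a background sample. That is one common choice, assumed here, and the function names are made up.

```python
from itertools import combinations
from math import factorial

import numpy as np

def predict(X):
    """Toy stand-in for model.predict, with an interaction between columns 0 and 2."""
    return 10.0 + 2.0 * X[:, 0] - 4.0 * X[:, 1] + X[:, 0] * X[:, 2]

def coalition_value(predict, x, background, subset):
    """Average prediction with the features in `subset` fixed to x's values."""
    X_mix = background.copy()
    for k in subset:
        X_mix[:, k] = x[k]
    return predict(X_mix).mean()

def shapley_values(predict, x, background):
    n = len(x)
    phi = np.zeros(n)
    for j in range(n):
        others = [k for k in range(n) if k != j]
        for size in range(n):
            for subset in combinations(others, size):
                # Weight of this coalition in the Shapley average.
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                gain = (coalition_value(predict, x, background, subset + (j,))
                        - coalition_value(predict, x, background, subset))
                phi[j] += weight * gain
    return phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 0.5, 2.0])

phi = shapley_values(predict, x, background)
difference = predict(x[None, :])[0] - predict(background).mean()
print(phi, phi.sum(), difference)
```

The contributions sum exactly to the gap between this prediction and the average prediction, which is the additivity property described in the student example above.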

As you can imagine, as the number of features increases, the number of possible coalitions increases exponentially, and the computation quickly becomes expensive. So we usually approximate the Shapley values rather than applying the exact calculations.

And that’s why we will apply SHAP. SHAP provides ways to estimate Shapley values, and more. Besides being an instance-based method to explain one instance, SHAP also contains methods for combining the Shapley values of all instances to summarize the model predictions.

Let’s use the TreeExplainer from shap and make a summary plot. We picked this explainer since it uses the Tree SHAP algorithm, which is appropriate for our XGBoost model, an ensemble of trees. If you are using different models, you can find other SHAP explainers here.

So on the below SHAP beeswarm summary plot, we can see both feature importances and the effects on the predictions. Each point marks a SHAP value for a feature and a student:

  • Along the y-axis, the features are sorted from top to bottom by the sum of SHAP value magnitudes over all instances
  • Along the x-axis, for each feature, you can see the distribution of its impacts on the model’s predictions
    • the color of the dots represents the values of the features: red for high and blue for low

Figure: SHAP summary plot

For example, we can see that past_failures is the second most important feature for our model. The higher values of past_failures (red dots) tend to contribute negatively to the prediction. In comparison, the lower values (blue dots) have positive contributions. This makes sense since the more past failures students have, the more likely their scores will be lower.

Besides this summary plot of all features and all instances, we can also focus on one feature’s effect across the entire dataset. The below scatter plot shows the SHAP values for the feature past_failures.

The grey bars represent the distribution of the feature, while the dots show the SHAP values.

The vertical dispersion at a single value of past_failures shows interaction effects with other features. With the color=shap_values argument, the scatter plot picks the best feature to color by to reveal the interactions. In our example, it picked number_absences. We can see that the feature past_failures has less impact on scores (lower SHAP values) with higher number_absences values.

Figure: SHAP scatter plot

So far, we’ve been looking at SHAP as a summary-based method. We can also use it as an instance-based method, i.e., look at one student at a time.

For example, we can display a waterfall plot for the first explanation.

On this plot, let’s find two values first. E(f(x)), at the bottom of the plot, is the average predicted score of the dataset, 10.448. f(x), at the top of the plot, is the predicted score for this student, 6. Between these two values, the waterfall shows how each feature moves the prediction from E(f(x)) to f(x). For example, extra_support_yes had the most impact and pulled the prediction down, while the past_failures value increased the prediction.

Figure: SHAP waterfall plot

Another instance’s waterfall plot could look completely different.

Figure: SHAP waterfall plot

Great!

As you can see, SHAP can be both a summary-based and an instance-based approach to explaining our machine learning models. There are also other convenient plots in the shap package. Please explore them if you need them.

Use with caution: SHAP is my personal favorite explainable ML method. But it may not fit all your specific needs, so please only use it when it answers your business questions.

Explainable ML method #4: Local Interpretable Model-agnostic Explanations (LIME)

Local Interpretable Model-Agnostic Explanations (LIME) is another popular method to explain one instance. Unlike SHAP, LIME learns interpretable local surrogate models around the prediction to estimate the features’ effects.

Suppose you want to explain how a ‘black box’ model makes a specific prediction on one instance. Here are the general steps of LIME:

  • perturb the dataset, and get the ‘black box’ model predictions for the new points
  • weight the new samples based on their proximity to the instance of interest
  • train a weighted, interpretable model on the dataset with the variations, i.e., learn a local surrogate model
    • this local surrogate model should approximate the ‘black box’ model’s prediction locally
    • common local surrogate models include linear regression and decision trees
  • interpret the local model to explain the prediction
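
The steps above can be sketched from scratch with NumPy. This is a simplified illustration rather than what the lime package does internally. The predict function is a toy stand-in for the ‘black box’, and the noise scale, kernel, and kernel width are arbitrary assumptions.

```python
import numpy as np

def predict(X):
    """Toy stand-in for the 'black box': nonlinear in column 0, linear in column 1."""
    return X[:, 0] ** 2 + 3.0 * X[:, 1]

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0])  # the instance to explain

# 1. perturb the instance and get the black-box predictions for the new points
Z = x + rng.normal(scale=0.5, size=(5000, 2))
y_z = predict(Z)

# 2. weight the new samples by their proximity to the instance
kernel_width = 0.75  # assumed value
dist = np.linalg.norm(Z - x, axis=1)
weights = np.exp(-(dist ** 2) / kernel_width ** 2)

# 3. fit a weighted linear model (the local surrogate) by least squares
A = np.column_stack([np.ones(len(Z)), Z])
sw = np.sqrt(weights)
coef, *_ = np.linalg.lstsq(A * sw[:, None], y_z * sw, rcond=None)
intercept, slopes = coef[0], coef[1:]

# 4. interpret the surrogate: its slopes are the local feature effects
prediction_local = intercept + slopes @ x
right = predict(x[None, :])[0]
print("Intercept", intercept)
print("Prediction_local", prediction_local)
print("Right:", right)
```

Around this instance the surrogate’s slopes come out close to 2 and 3, the local rates of change of the toy function. The surrogate’s own prediction is near, but not equal to, the black-box prediction.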

We’ll use the lime_tabular module from lime since our data is tabular. You can also explore other modules for image or text data as needed.

Then we’ll generate explanations for one instance. We’ve also set the maximum number of features to include in the explanation to be 5.

Intercept 5.7634568235854395
Prediction_local [10.11461869]
Right: 8.998626

Here is how we can show the results in JupyterLab.

We can see that the predicted score for this student is 8.998626 (the Right value in the output above), while the local surrogate model predicted about 10.11 (Prediction_local). The features and their contributions (blue being negative, orange being positive) to this prediction are shown, as well as their feature values for this student. For example, the feature past_failures, with a value of 0, had the strongest positive effect on the prediction.

Figure: LIME plot

Use with caution: when using the LIME method, we need to define the neighborhood of the instance and other settings. It can take several trials to get a good and relatively stable result.

That’s it! Those are all the explainable machine learning methods covered in this tutorial.


How to choose Explainable Machine Learning methods?

In the last section of this tutorial, I would like to provide a general procedure for when to use and how to choose explainable machine learning methods:

  1. if your business demands high model explainability, you should try ‘glass box’ models first
  2. if you find out that only the ‘black box’ models offer satisfying performance, you can try interpreting them with explainable machine learning methods
  3. during this process, you need to consider what business questions you are trying to answer, for example:
    • do I need to explain the overall prediction or individual predictions?
    • do I want to look at the detailed effect of each feature on the prediction?

You’ve learned four popular explainable ML methods in this tutorial, and they can be your starting point. Each of them has its advantages and disadvantages. Please don’t hesitate to apply each of them and also dive deeper.

If none of them satisfy your needs, you can always explore other methods!


In this guide, you’ve learned about explainable machine learning and its four popular methods in Python.

We hope you are now ready to apply them to interpret your machine learning models.

We’d love to hear from you. Leave a comment for any questions you may have or anything else.
