Model Interpreter via Contextual AI

In this tutorial, we use Contextual AI (xai) to analyze a trained model and its features. The tutorial contains three parts:

- feature distribution analysis
- trained model feature importance analysis
- model interpretation via explanation aggregation

Prerequisites: Import libraries

xai.model.interpreter is the main package that users of Contextual AI interact with.

[1]:
# Some auxiliary imports for the tutorial
import sys
import random
import numpy as np

from pprint import pprint
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Set seed for reproducibility
np.random.seed(123456)

# Set the path so that we can import ExplainerFactory
sys.path.append('../../')
[2]:
# Main Contextual AI imports
import xai
from xai.model.interpreter.model_interpreter import ModelInterpreter

In this tutorial, we train a sample RandomForestClassifier model on the Wisconsin breast cancer dataset, a binary classification problem provided by scikit-learn (details can be found in the scikit-learn documentation). We will use the APIs in the xai.model.interpreter package to conduct feature analysis, feature importance ranking, and model interpretation.

[3]:
# Load the dataset and prepare training and test sets
raw_data = datasets.load_breast_cancer()
X, y = raw_data['data'], raw_data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

1. Feature Analysis

Key parameters used in feature analysis:

- feature_names: a list of str, the name of each feature.
- feature_types: a list of str, the pre-defined type of each feature.
- train_x: numpy.ndarray, the training data; each row is a training sample and each column is a feature.
- labels: a list of str/int, the class label for each training sample.
- trained_model: the trained model object.

Feature Distribution

To analyze the feature distribution, we import FeatureInterpreter from xai.model.interpreter.feature_interpreter, initialize a FeatureInterpreter, and call its get_feature_distribution() function.

[4]:
# Define the feature data type and feature names
from xai.data.constants import DATATYPE
feature_names = raw_data['feature_names'].tolist()
feature_types = [DATATYPE.NUMBER]*len(feature_names)

from xai.model.interpreter.feature_interpreter import FeatureInterpreter
feature_interpreter = FeatureInterpreter(feature_names=feature_names)

stats = feature_interpreter.get_feature_distribution(feature_types=feature_types,train_x=X_train,labels=y_train)
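The returned stats object maps each feature name to a pair of (per-class stats, overall stats). As a quick sanity check, we can peek at one entry directly (a minimal sketch; the concrete stats classes are documented in xai.data.explorer):

[ ]:
# Each entry of `stats` is keyed by feature name and holds
# (stats for each class, stats over all classes)
label_feature_stats, all_feature_stats = stats['mean radius']
print(type(label_feature_stats), type(all_feature_stats))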

We use our plot helper class NotebookPlots to plot the feature distribution results for the first 3 features. Details for each type of stats can be found in xai.data.explorer under different data type packages.

[5]:
from xai.plots.data_stats_notebook_plots import NotebookPlots
plotter = NotebookPlots()

sample_features = feature_names[:3]
for feature_name in sample_features:
    (label_feature_stats, all_feature_stats) = stats[feature_name]

    plotter.plot_labelled_numerical_stats(labelled_stats=label_feature_stats, # stats for each class
                                          all_stats=all_feature_stats, # stats for all classes
                                          label_column='Class', # column name for label
                                          feature_column=feature_name) # column name for feature

mean radius

../../_images/tutorials_model_interpreter_tutorial_model_interpreter_9_1.png
   class     min    max       mean    median     sd  total_count
0      1   6.981  16.84  12.095455  1.700759  16.84          286
1      0  11.420  28.11  17.468462  3.158721  28.11          169
2    all   6.981  28.11  14.091143  3.502028  28.11          455

mean texture

../../_images/tutorials_model_interpreter_tutorial_model_interpreter_9_4.png
   class    min    max       mean    median     sd  total_count
0      1   9.71  33.81  18.066643  4.013245  33.81          286
1      0  10.38  39.28  21.600118  3.738992  39.28          169
2    all   9.71  39.28  19.379077  4.269827  39.28          455

mean perimeter

../../_images/tutorials_model_interpreter_tutorial_model_interpreter_9_7.png
   class    min    max        mean     median     sd  total_count
0      1  43.79  108.4   77.730105  11.292902  108.4          286
1      0  75.00  188.5  115.335503  21.653977  188.5          169
2    all  43.79  188.5   91.697824  24.176163  188.5          455

Feature Correlation

To analyze the feature correlation, we use the FeatureInterpreter from xai.model.interpreter.feature_interpreter and call its get_feature_correlation() function. The same FeatureInterpreter object created above can be reused here for correlation analysis.

In this example, since all the features are numerical, correlations are calculated with Pearson's correlation coefficient by default. We then use the plot helper to draw a heatmap of the correlations between all features.

[6]:
types, values = feature_interpreter.get_feature_correlation(feature_types=feature_types,train_x=X_train)
plotter.plot_correlation_heatmap(types,values)

Correlation Type: pearson

../../_images/tutorials_model_interpreter_tutorial_model_interpreter_11_2.png
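As a sanity check, the Pearson correlation for any pair of features can also be computed directly with NumPy, independent of Contextual AI (a minimal sketch using the feature names defined above):

[ ]:
# Pearson correlation between two features, computed directly on the training matrix
i, j = feature_names.index('mean radius'), feature_names.index('mean perimeter')
r = np.corrcoef(X_train[:, i], X_train[:, j])[0, 1]
print('Pearson correlation between %s and %s: %.3f' % (feature_names[i], feature_names[j], r))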

2. Feature Importance

As feature importance is associated with a model, we first need to train one. In this tutorial, we train a sample RandomForestClassifier on the dataset.

[7]:
# Instantiate a classifier, train, and evaluate on test set
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)
[7]:
0.956140350877193

To analyze the feature importance ranking, we use the FeatureInterpreter from xai.model.interpreter.feature_interpreter and call its get_feature_importance_ranking() function. The same FeatureInterpreter object can be reused here for feature importance ranking.

The code below shows the feature importance obtained with the default method provided by the model itself.

[8]:
feature_importance_ranking = feature_interpreter.get_feature_importance_ranking(trained_model=clf,
                                                                                train_x=X_train,
                                                                                method='default')
plotter.plot_feature_importance_ranking(feature_importance_ranking)
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_15_0.png
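For a random forest, the default method presumably relies on the importances the model itself exposes; as a quick cross-check (a minimal sketch using plain scikit-learn, not the Contextual AI API), the impurity-based importances can be read off the trained estimator:

[ ]:
# Top-5 impurity-based importances reported directly by scikit-learn
top5 = sorted(zip(feature_names, clf.feature_importances_), key=lambda t: -t[1])[:5]
for name, score in top5:
    print('%-25s %.4f' % (name, score))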

By changing the method, we can get feature importance based on different criteria. The code below shows the feature importance calculated from SHAP values.

[9]:
feature_importance_ranking = feature_interpreter.get_feature_importance_ranking(trained_model=clf,
                                                                                train_x=X_train,
                                                                                method='shap')
plotter.plot_feature_importance_ranking(feature_importance_ranking)
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_17_0.png
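If you want to inspect the underlying SHAP values outside of Contextual AI, the shap package can be applied to the trained forest directly. The sketch below assumes shap is installed and may differ in detail from what get_feature_importance_ranking(method='shap') computes internally:

[ ]:
import shap

# Explain the trained forest with a tree explainer
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_train)

# Older shap versions return one array per class for classifiers;
# newer ones return a single (samples, features, classes) array
values = shap_values[1] if isinstance(shap_values, list) else shap_values
if values.ndim == 3:
    values = values[:, :, 1]

# Rank features by mean absolute SHAP value
mean_abs = np.abs(values).mean(axis=0)
for name, score in sorted(zip(feature_names, mean_abs), key=lambda t: -t[1])[:5]:
    print('%-25s %.4f' % (name, score))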

We can also plot the SHAP values in a summary plot to show the per-sample SHAP values for all the features.

[10]:
feature_shap_values = feature_interpreter.get_feature_shap_values(trained_model=clf,
                                                                  train_x=X_train)

plotter.plot_feature_shap_values(feature_shap_values,class_id = 1, X_train=X_train)
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_19_0.png
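A similar summary plot can also be produced with the shap package itself, reusing the values computed in the sketch above (again an illustration outside the Contextual AI API):

[ ]:
# Beeswarm summary plot of per-sample SHAP values for the positive class
shap.summary_plot(values, X_train, feature_names=feature_names)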

3. Model Interpretation by Aggregating Explanations

One way to interpret a model is to aggregate individual explanations, each of which explains the model locally.

Step 0. Import the ModelInterpreter

[11]:
from xai.model.interpreter.model_interpreter import ModelInterpreter

Step 1. Define domain and algorithm

As the model interpreter uses a model-agnostic explainer, the available domains and algorithms depend on the xai.explainer package. See the details in xai.explainer.config.

[12]:
from xai.explainer.config import DOMAIN, ALG
model_interpreter = ModelInterpreter(domain=DOMAIN.TABULAR, algorithm=ALG.LIME)

Step 2. Build interpreter

Based on the domain and algorithm chosen, build the explainer inside the interpreter by passing in the required parameters.

Required parameters include:

- training data
- training labels
- model prediction function

See details in xai.explainer.

[13]:
model_interpreter.build_interpreter(
    training_data=X_train,
    training_labels=y_train,
    mode=xai.MODE.CLASSIFICATION,
    predict_fn=clf.predict_proba,
    feature_names=raw_data['feature_names'],
    class_names=list(raw_data['target_names'])
)

Step 3. Interpret the model with training data

Model Interpretation

The interpreter explains the model by aggregating explanations based on predicted classes. When interpret_model() is called with training data, the explainer explains each sample locally and then statistically aggregates the local explanations per class to provide a global interpretation. For now, we support 3 types of statistical aggregation:

- top_k: how often a feature appears among the top K features of an explanation
- average_score: the average score of each feature across explanations
- average_ranking: the average ranking of each feature across explanations

The default type is top_k; a minimal sketch of the top_k idea follows below.
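To make the idea concrete, here is a small, self-contained sketch of top_k aggregation on toy, hypothetical explanations (not the library's implementation):

[ ]:
from collections import Counter

# Toy per-sample explanations: (feature, score) pairs, sorted by importance
explanations = [
    [('worst radius', 0.30), ('worst perimeter', 0.22), ('mean texture', 0.05)],
    [('worst perimeter', 0.28), ('worst radius', 0.21), ('mean area', 0.11)],
    [('worst radius', 0.35), ('mean area', 0.14), ('mean texture', 0.02)],
]

# top_k: count how often each feature appears among the top k features of a sample
k = 2
top_k_counts = Counter()
for explanation in explanations:
    top_k_counts.update(feature for feature, _ in explanation[:k])

print(top_k_counts.most_common())
# [('worst radius', 3), ('worst perimeter', 2), ('mean area', 1)]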

[14]:
stats = model_interpreter.interpret_model(samples=X_train, stats_type='top_k',k=5)
[15]:
class_stats, total_count = stats
num_of_top_explanation = 15
for _class, _explanation_ranking in class_stats.items():
    print('Interpretation for Class %s' % _class)
    ranking = list(_explanation_ranking.items())[:num_of_top_explanation]
    plotter.plot_feature_importance_ranking(ranking)
Interpretation for Class 1
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_29_1.png
Interpretation for Class 0
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_29_3.png

Error Analysis

Error analysis aggregates explanations for samples in the validation set that are wrongly classified.

Calling error_analysis() returns statistics of the top explanations for those wrongly classified samples.
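Before aggregating, it can help to see which validation samples are actually misclassified; this can be checked directly with the trained classifier (plain scikit-learn, independent of error_analysis()):

[ ]:
# Identify the misclassified validation samples
y_pred = clf.predict(X_test)
wrong_idx = np.where(y_pred != y_test)[0]
print('%d of %d validation samples are misclassified' % (len(wrong_idx), len(y_test)))
for i in wrong_idx:
    print('sample %d: true class %d, predicted class %d' % (i, y_test[i], y_pred[i]))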

[16]:
stats = model_interpreter.error_analysis(class_num=2, valid_x=X_test, valid_y=y_test, stats_type='average_score', k=5)
[17]:
num_of_top_explanation = 10
for (gt_class, predict_class), (_explanation_dict, num_sample) in stats.items():
    print('%s sample from class [%s] is wrongly classified as class[%s]' % (num_sample, gt_class, predict_class))
    print(' - Top reasons that they are predicted as class[%s]' % predict_class)
    ranking = list(_explanation_dict[predict_class].items())[:num_of_top_explanation]
    plotter.plot_feature_importance_ranking(ranking)
3 sample from class [1] is wrongly classified as class[0]
 - Top reasons that they are predicted as class[0]
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_32_1.png
2 sample from class [0] is wrongly classified as class[1]
 - Top reasons that they are predicted as class[1]
../../_images/tutorials_model_interpreter_tutorial_model_interpreter_32_3.png