LIME Tabular Explainer via XAI¶
This tutorial demonstrates how to generate explanations using LIME’s tabular explainer implemented by the Contextual AI library.
At a high level, explanations can be obtained from any Contextual AI explanation algorithm in three steps (sketched in code after the list):
1. Create an explainer via the ExplainerFactory class, which serves as the primary interface between the user and all Contextual AI-supported explanation algorithms
2. Build the explainer by calling the build_explainer method (which is implemented by any Contextual AI explanation algorithm) and providing arguments that are specific to that algorithm
3. Get explanations for some data instance by calling the explain_instance method (which is also common among all algorithms) and providing arguments that are specific to that algorithm
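In code, the three steps look roughly like the following condensed sketch. It simply previews what the rest of this tutorial walks through step by step; X_train, X_test, and clf refer to the training data, test data, and trained model introduced in the steps below.
import xai
from xai.explainer import ExplainerFactory

# Step 1: create an explainer for the tabular domain
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR)

# Step 2: build the explainer with algorithm-specific arguments
explainer.build_explainer(
    training_data=X_train,
    mode=xai.MODE.CLASSIFICATION,
    predict_fn=clf.predict_proba)

# Step 3: explain a single data instance
explanations = explainer.explain_instance(instance=X_test[0], num_features=5)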
Step 1: Import libraries¶
xai.explainer.ExplainerFactory is the main class that users of Contextual AI interact with. xai contains some constants that are used to instantiate an AbstractExplainer object.
[1]:
# Some auxiliary imports for the tutorial
import sys
import random
import numpy as np
from pprint import pprint
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
# Set seed for reproducibility
np.random.seed(123456)
# Set the path so that we can import ExplainerFactory
sys.path.append('../../')
# Main Contextual AI imports
import xai
from xai.explainer import ExplainerFactory
Step 2: Train a model on a sample dataset¶
We train a RandomForestClassifier model on the Wisconsin breast cancer dataset, a binary classification problem provided by scikit-learn (details can be found in the scikit-learn documentation).
[2]:
# Load the dataset and prepare training and test sets
raw_data = datasets.load_breast_cancer()
X, y = raw_data['data'], raw_data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Instantiate a classifier, train, and evaluate on test set
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
clf.score(X_test, y_test)
/Users/i330688/venv_xai/lib/python3.6/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
"10 in version 0.20 to 100 in 0.22.", FutureWarning)
[2]:
0.956140350877193
Step 3: Instantiate the explainer¶
This is where we instantiate the Contextual AI explainer. The ExplainerFactory class is in charge of loading a particular explanation algorithm. The user is required to provide one argument, domain, which indicates the domain of the training data (e.g. tabular or text). The available domains can be found in xai.DOMAIN. Users can also select a particular explainer algorithm by providing the algorithm's name (registered in xai.ALG) to the algorithm parameter. If this argument is not provided, the ExplainerFactory.get_explainer method defaults to a pre-set algorithm for that domain, which can be found in xai/explainer/config.py.
We want to load the LimeTabularExplainer, so we provide xai.DOMAIN.TABULAR as the argument to domain and xai.ALG.LIME as the argument to algorithm. Note that xai.ALG.LIME is the default tabular explanation algorithm, so this also works:
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR)
[3]:
# Instantiate LimeTabularExplainer via the Explainer interface
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR, algorithm=xai.ALG.LIME)
Step 4: Build the explainer¶
build_explainer calls the explanation algorithm's initialization routine, which can include things like setting parameters or a pre-training loop. The LimeTabularExplainer requires the following parameters:
training_data (np.ndarray): 2D NumPy array representing the training data (or some representative subset) (required)
mode (str): Whether the problem is ‘classification’ or ‘regression’ (required)
predict_fn (function): A function that wraps the target model’s prediction function. It takes in a 1D NumPy array and outputs a vector of probabilities which should sum to 1 (required)
Here are some other optional parameters (a sketch of a custom predict_fn wrapper follows this list):
training_labels (list): Training labels, which can be used by the continuous feature discretizer
feature_names (list): The names of the columns of the training data
categorical_features (list): Integer list indicating the indices of categorical features
dict_categorical_mapping (dict): Mapping of the integer index of a categorical feature (same as in categorical_features) to a list of values for that column, so that dict_categorical_mapping[x][y] is the yth value of column x
kernel_width (float): Width of the exponential kernel used in the LIME loss function
verbose (bool): Controls verbosity; if True, local prediction values of the LIME model are printed
class_names (list): Class names (positional index corresponding to class index)
feature_selection (str): Feature selection method; see the original LIME docs for more details
discretize_continuous (bool): Whether to discretize non-categorical features
discretizer (str): Type of discretization; see the original LIME docs for more details
sample_around_instance (bool): If True, continuous features in perturbed samples are sampled from a normal distribution centered at the instance being explained; otherwise, the normal is centered on the mean of the feature data
random_state (int): The random seed used to generate random numbers during training
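Note that predict_fn only needs to return a probability vector that sums to 1 for each input. If your model does not expose predict_proba, you can wrap it yourself. Below is a minimal, hypothetical sketch (not part of the Contextual AI API) that adapts a binary scikit-learn model exposing only decision_function by converting its raw margins into probabilities with a softmax:
import numpy as np

def make_predict_fn(model):
    # Hypothetical helper: adapts a binary classifier that only exposes
    # decision_function (e.g. an SVM trained without probability=True)
    # to the probability interface that the explainer expects.
    def predict_fn(x):
        x = np.atleast_2d(x)                           # accept single rows as well
        margins = model.decision_function(x)           # shape (n_samples,) for binary problems
        scores = np.column_stack([-margins, margins])  # one column per class
        exp = np.exp(scores - scores.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)    # rows sum to 1
    return predict_fn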
In this particular example, we pass the RandomForestClassifier’s predict_proba function to predict_fn and get explanations for the two classes.
[4]:
explainer.build_explainer(
training_data=X_train,
training_labels=y_train,
mode=xai.MODE.CLASSIFICATION,
predict_fn=clf.predict_proba,
feature_names=raw_data['feature_names'],
class_names=list(raw_data['target_names'])
)
Step 5: Generate some explanations¶
Once we build the explainer, we can start generating some explanations via the explain_instance method. The LimeTabularExplainer expects one required argument:
instance (np.ndarray): A 1D NumPy array corresponding to a row/single example (required)
You can also pass the following:
labels (list): The list of class indexes to produce explanations for
top_labels (int): If not None, this overwrites labels and the explainer instead produces explanations for the top k classes
num_features (int): Number of features to include in an explanation
num_samples (int): The number of perturbed samples to train the LIME model with
distance_metric (str): The distance metric to use for weighting the loss function
We restrict explanations to 5 features (meaning only 5 features will have scores attached to them). The output of explain_instance is a dictionary that maps each class to two things: the model's prediction confidence and a list of explanations (see the post-processing sketch after the output below).
[5]:
exp = explainer.explain_instance(
instance=X_test[0],
top_labels=2,
num_features=5)
pprint(exp)
{0: {'explanation': [{'feature': 'worst perimeter <= 83.79',
'score': -0.10193695487658752},
{'feature': 'worst area <= 509.25',
'score': -0.09601666088375639},
{'feature': 'worst radius <= 12.93',
'score': -0.06025582708518221},
{'feature': 'mean area <= 419.25',
'score': -0.056302517885391166},
{'feature': 'worst texture <= 21.41',
'score': -0.05509499962470648}],
'prediction': 0.0},
1: {'explanation': [{'feature': 'worst perimeter <= 83.79',
'score': 0.10193695487658752},
{'feature': 'worst area <= 509.25',
'score': 0.0960166608837564},
{'feature': 'worst radius <= 12.93',
'score': 0.06025582708518222},
{'feature': 'mean area <= 419.25',
'score': 0.05630251788539119},
{'feature': 'worst texture <= 21.41',
'score': 0.05509499962470641}],
'prediction': 1.0}}
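Since the explanation is a plain dictionary, it is straightforward to post-process. As an illustration, here is a small hypothetical helper (not part of the Contextual AI API) that returns the features for a given class sorted by absolute score:
def top_features(exp, label):
    # Return (feature, score) pairs for `label`, sorted by |score| descending
    items = exp[label]['explanation']
    return sorted(((e['feature'], e['score']) for e in items),
                  key=lambda pair: abs(pair[1]), reverse=True)

for feature, score in top_features(exp, 1):
    print('{}: {:+.4f}'.format(feature, score))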
Step 6: Save and load the explainer¶
Finally, every Contextual AI explainer provides functions for saving and loading the explainer.
[6]:
# Save the explainer somewhere
explainer.save_explainer('artefacts/lime_tabular.pkl')
[7]:
# Load the saved explainer in a new Explainer instance
new_explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TABULAR, algorithm=xai.ALG.LIME)
new_explainer.load_explainer('artefacts/lime_tabular.pkl')
exp = new_explainer.explain_instance(
instance=X_test[0],
top_labels=2,
num_features=5)
pprint(exp)
{0: {'explanation': [{'feature': 'worst perimeter <= 83.79',
'score': -0.09985606175737251},
{'feature': 'worst area <= 509.25',
'score': -0.08623375147255567},
{'feature': 'mean area <= 419.25',
'score': -0.07671371631709668},
{'feature': 'worst radius <= 12.93',
'score': -0.06861610584095608},
{'feature': 'worst texture <= 21.41',
'score': -0.05078617133441289}],
'prediction': 0.0},
1: {'explanation': [{'feature': 'worst perimeter <= 83.79',
'score': 0.09985606175737251},
{'feature': 'worst area <= 509.25',
'score': 0.08623375147255567},
{'feature': 'mean area <= 419.25',
'score': 0.0767137163170967},
{'feature': 'worst radius <= 12.93',
'score': 0.0686161058409561},
{'feature': 'worst texture <= 21.41',
'score': 0.05078617133441288}],
'prediction': 1.0}}