LIME Text Explainer via XAI

This tutorial demonstrates how to generate explanations using LIME’s text explainer implemented by the Contextual AI library. Much of the tutorial overlaps with what is covered in the LIME tabular tutorial. To recap, the main steps for generating explanations are:

  1. Get an explainer via the ExplainerFactory class

  2. Build the text explainer

  3. Call explain_instance

Step 1: Import libraries

[1]:
# Some auxiliary imports for the tutorial
import sys
import random
import numpy as np
from pprint import pprint
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer

# Set seed for reproducibility
np.random.seed(123456)

# Set the path so that we can import ExplainerFactory
sys.path.append('../../')

# Main Contextual AI imports
import xai
from xai.explainer import ExplainerFactory

Step 2: Load dataset and train a model

In this tutorial, we use the 20newsgroups text dataset, which can be loaded with scikit-learn's fetch_20newsgroups utility. Documentation on the dataset itself can be found here. To keep things simple, we extract posts from 3 topics: baseball, Christianity, and medicine.

Our target model is a multinomial Naive Bayes classifier trained on TF-IDF vectors. The last line of the cell reports its mean accuracy on the test split.

[2]:
# Train on a subset of categories

categories = [
    'rec.sport.baseball',
    'soc.religion.christian',
    'sci.med'
]

raw_train = datasets.fetch_20newsgroups(subset='train', categories=categories)
print(list(raw_train.keys()))
print(raw_train.target_names)
print(raw_train.target[:10])
raw_test = datasets.fetch_20newsgroups(subset='test', categories=categories)

vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(raw_train.data)
y_train = raw_train.target

X_test = vectorizer.transform(raw_test.data)
y_test = raw_test.target

clf = MultinomialNB(alpha=0.1)
clf.fit(X_train, y_train)
clf.score(X_test, y_test)
['data', 'filenames', 'target_names', 'target', 'DESCR']
['rec.sport.baseball', 'sci.med', 'soc.religion.christian']
[1 0 2 2 0 2 0 0 0 1]
[2]:
0.9689336691855583
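The printed target_names are in alphabetical order, not in the order of the categories list, so label 1 is sci.med and label 2 is soc.religion.christian. The labels=[0, 1, 2] argument passed to explain_instance later refers to these indices. A minimal sketch that decodes the first ten training labels, with both lists copied from the output above:

```python
# Copied from the printed output above: index i of target_names is label i
target_names = ['rec.sport.baseball', 'sci.med', 'soc.religion.christian']
first_ten = [1, 0, 2, 2, 0, 2, 0, 0, 0, 1]

# Map each integer label to its newsgroup name
index_to_name = dict(enumerate(target_names))
decoded = [index_to_name[i] for i in first_ten]
print(decoded[:3])  # ['sci.med', 'rec.sport.baseball', 'soc.religion.christian']
```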

Step 3: Instantiate the explainer

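Here, we request an explainer for the text domain. Because no algorithm is specified, the factory returns its default text explainer, LIME; Step 6 requests the same explainer explicitly with algorithm=xai.ALG.LIME.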

[3]:
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TEXT)

Step 4: Build the explainer

build_explainer initializes the underlying explainer object. We will pass raw text to explain_instance, and LIME's text explainer does its own preprocessing: it splits the text into tokens and creates perturbed copies with some tokens removed, which serve as interpretable representations of the data. Hence we must define a custom predict_fn that takes a list of raw text strings, vectorizes them with the fitted TF-IDF vectorizer, and passes the resulting vectors to the trained Naive Bayes model to obtain class probabilities. LIME calls predict_fn on the perturbed texts to learn how our Naive Bayes model behaves around the provided data instance.

[4]:
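def predict_fn(texts):
    # LIME passes a list of raw strings; transform expects an iterable of documents
    vec = vectorizer.transform(texts)
    # Array of shape (len(texts), n_classes) with class probabilities
    return clf.predict_proba(vec)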


explainer.build_explainer(predict_fn)
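To make the role of predict_fn more concrete, here is a heavily simplified, self-contained sketch of what a LIME-style text explainer does with it. It is not Contextual AI's or the lime package's code, and it skips lime's feature-selection step. It splits on whitespace (treating each word position as a feature), removes random subsets of words, queries the model on the perturbed strings in one batch (a list of strings in, one row of class probabilities out), weights each sample by its similarity to the original, and fits a weighted ridge regression whose coefficients become the token scores. The distance, kernel, and ridge penalty loosely follow the lime package's defaults (cosine distance scaled by 100, kernel width 25, penalty 1). The keyword-counting toy_predict_fn is a hypothetical stand-in for the Naive Bayes model.

```python
import numpy as np


def toy_predict_fn(texts):
    """Hypothetical stand-in for predict_fn: list of strings -> (n, 2) probabilities."""
    probs = []
    for text in texts:
        words = text.lower().split()
        # 'pitcher' is evidence for class 0 ("baseball"), 'doctor' for class 1 ("medicine")
        score = words.count('pitcher') - words.count('doctor')
        p0 = 1.0 / (1.0 + np.exp(-score))
        probs.append([p0, 1.0 - p0])
    return np.array(probs)


def lime_text_sketch(text, predict_fn, label, num_samples=500, num_features=3,
                     kernel_width=25.0, seed=0):
    """Simplified LIME-style token scores for one label (assumes >= 2 words)."""
    rng = np.random.default_rng(seed)
    words = text.split()  # each word position is one interpretable feature
    d = len(words)
    # Binary masks: 1 keeps a word, 0 removes it; row 0 is the original text
    masks = np.ones((num_samples, d), dtype=int)
    for row in masks[1:]:
        n_remove = rng.integers(1, d)  # remove between 1 and d - 1 words
        row[rng.choice(d, size=n_remove, replace=False)] = 0
    perturbed = [' '.join(w for w, keep in zip(words, m) if keep) for m in masks]
    # Query the black-box model on all perturbed texts in one batch
    target = predict_fn(perturbed)[:, label]
    # Cosine distance between each mask and the all-ones mask, scaled by 100
    kept = masks.sum(axis=1)
    distance = (1.0 - kept / (np.sqrt(kept) * np.sqrt(d))) * 100
    weights = np.sqrt(np.exp(-distance ** 2 / kernel_width ** 2))
    # Weighted ridge regression (penalty 1, intercept unpenalized), closed form
    X = np.hstack([np.ones((num_samples, 1)), masks])
    Xw = X * weights[:, None]
    penalty = np.eye(d + 1)
    penalty[0, 0] = 0.0
    coef = np.linalg.solve(X.T @ Xw + penalty, Xw.T @ target)[1:]
    # Keep the tokens with the largest absolute weights
    top = np.argsort(-np.abs(coef))[:num_features]
    return [{'feature': words[i], 'score': float(coef[i])} for i in top]


text = 'the pitcher threw a strike before seeing the doctor'
exp = {
    label: {
        'confidence': float(toy_predict_fn([text])[0, label]),
        'explanation': lime_text_sketch(text, toy_predict_fn, label),
    }
    for label in [0, 1]
}
print(exp[0]['explanation'])  # 'pitcher' scores positive, 'doctor' negative
```

The result deliberately mirrors the dictionary layout that explain_instance returns, so the sketch can be compared with the outputs in Step 5.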

Step 5: Generate some explanations

[5]:
exp = explainer.explain_instance(
    labels=[0, 1, 2],
    instance=raw_test.data[0],
    num_features=10
)

print('Label', raw_train.target_names[raw_test.target[0]], raw_test.target[0])
pprint(exp)
/Users/i330688/venv_xai/lib/python3.6/re.py:212: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
Label rec.sport.baseball 0
{0: {'confidence': 0.9604247937223921,
     'explanation': [{'feature': 'Mattingly', 'score': 0.1630204569197586},
                     {'feature': 'njit', 'score': -0.05400846560032084},
                     {'feature': 'Yankee', 'score': 0.047128435532711524},
                     {'feature': 'Lurie', 'score': 0.0459271027896729},
                     {'feature': 'PLAYERS', 'score': 0.045541508852427214},
                     {'feature': 'tesla', 'score': -0.04552783302602691},
                     {'feature': 'Allegheny', 'score': 0.0440014710417496},
                     {'feature': 'luriem', 'score': 0.04385267215867704},
                     {'feature': 'Liberalizer', 'score': 0.042445765884872304},
                     {'feature': 'Don', 'score': -0.030393475108189762}]},
 1: {'confidence': 0.015984823571617023,
     'explanation': [{'feature': 'Mattingly', 'score': -0.05443408204951863},
                     {'feature': 'alleg', 'score': -0.023071281337399444},
                     {'feature': 'Yankee', 'score': -0.0204790656431549},
                     {'feature': 'Allegheny', 'score': -0.019319586624860205},
                     {'feature': 'game', 'score': -0.019075909341883114},
                     {'feature': 'Lurie', 'score': -0.01823234234170473},
                     {'feature': 'tesla', 'score': 0.016795268385738336},
                     {'feature': '1993Apr21', 'score': 0.012968253445169269},
                     {'feature': 'Don', 'score': 0.011628382538193854},
                     {'feature': 'njit', 'score': 0.011514128262241137}]},
 2: {'confidence': 0.02359038270598772,
     'explanation': [{'feature': 'Mattingly', 'score': -0.11653224143274481},
                     {'feature': 'njit', 'score': 0.036924001047734877},
                     {'feature': 'PLAYERS', 'score': -0.034825273342100574},
                     {'feature': 'Lurie', 'score': -0.03388829751366763},
                     {'feature': 'Yankee', 'score': -0.033169483357673},
                     {'feature': 'luriem', 'score': -0.032474210174719534},
                     {'feature': 'Liberalizer', 'score': -0.03079880516104312},
                     {'feature': 'Jesus', 'score': 0.027367258249072796},
                     {'feature': 'tesla', 'score': 0.022433618170210407},
                     {'feature': 'christ', 'score': 0.02164039206739657}]}}

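Just like with the LIME tabular explainer, the output of explain_instance is a JSON-compatible dictionary keyed by the requested class indices. For each class, 'confidence' is the target model's predicted probability for that class, and 'explanation' lists up to num_features features with the weights LIME assigned to them. For text, each feature is a token; tokens are case-sensitive, so kidney and Kidney in the next example are scored separately. A positive score means the token pushes the prediction toward that class, and a negative score pushes it away.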
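As a sketch of how this structure might be consumed downstream, the snippet below uses an abbreviated copy of the first explanation printed above (only the top entries are kept). It picks the class with the highest confidence, splits that class's tokens into supporting and opposing evidence, and checks that the dictionary survives a JSON round trip. Note that json.dumps turns the integer class keys into strings.

```python
import json

# Abbreviated copy of the first explanation printed above (top entries only)
exp = {
    0: {'confidence': 0.9604247937223921,
        'explanation': [{'feature': 'Mattingly', 'score': 0.1630204569197586},
                        {'feature': 'njit', 'score': -0.05400846560032084},
                        {'feature': 'Yankee', 'score': 0.047128435532711524}]},
    1: {'confidence': 0.015984823571617023,
        'explanation': [{'feature': 'Mattingly', 'score': -0.05443408204951863}]},
    2: {'confidence': 0.02359038270598772,
        'explanation': [{'feature': 'Mattingly', 'score': -0.11653224143274481}]},
}

# Predicted class = label with the highest confidence
predicted = max(exp, key=lambda label: exp[label]['confidence'])

# Split the predicted class's tokens into supporting and opposing evidence
entries = exp[predicted]['explanation']
supporting = [e['feature'] for e in entries if e['score'] > 0]
opposing = [e['feature'] for e in entries if e['score'] < 0]
print(predicted, supporting, opposing)  # 0 ['Mattingly', 'Yankee'] ['njit']

# JSON-compatible, but integer keys come back as strings
print(list(json.loads(json.dumps(exp)).keys()))  # ['0', '1', '2']
```

The three confidences sum to 1, consistent with them being the classifier's predicted probabilities.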

[6]:
exp = explainer.explain_instance(
    instance=raw_test.data[7],
    labels=[0, 1, 2],
    num_features=5
)

print('Label', raw_train.target_names[raw_test.target[7]], raw_test.target[7])
pprint(exp)
Label sci.med 1
{0: {'confidence': 0.006374625691451515,
     'explanation': [{'feature': 'pain', 'score': -0.027402611439935602},
                     {'feature': 'sr', 'score': 0.026176880833875864},
                     {'feature': 'ai', 'score': -0.023919836440025922},
                     {'feature': 'Covington', 'score': -0.02087504251506631},
                     {'feature': 'mcovingt', 'score': -0.02069997767962776}]},
 1: {'confidence': 0.8824748491424798,
     'explanation': [{'feature': 'hp', 'score': 0.06962985800565995},
                     {'feature': 'doctor', 'score': 0.06779310792572511},
                     {'feature': 'pain', 'score': 0.0668010276930299},
                     {'feature': 'kidney', 'score': 0.0549079057920354},
                     {'feature': 'Kidney', 'score': 0.05326854053175146}]},
 2: {'confidence': 0.11115052516607107,
     'explanation': [{'feature': 'hp', 'score': -0.0799997479251323},
                     {'feature': 'doctor', 'score': -0.04754155417624489},
                     {'feature': 'pain', 'score': -0.041227319748901106},
                     {'feature': 'kidney', 'score': -0.03950550045278837},
                     {'feature': 'Kidney', 'score': -0.03753117417614439}]}}
[7]:
exp = explainer.explain_instance(
    instance=raw_test.data[9],
    labels=[0, 1, 2],
    num_features=5
)

print('Label', raw_train.target_names[raw_test.target[9]], raw_test.target[9])
pprint(exp)
Label soc.religion.christian 2
{0: {'confidence': 6.798212345437472e-05,
     'explanation': [{'feature': 'Bible', 'score': -0.0023500809763485468},
                     {'feature': 'Scripture', 'score': -0.0014344577715211986},
                     {'feature': 'Heaven', 'score': -0.001381196356886895},
                     {'feature': 'Sin', 'score': -0.0013723724408794883},
                     {'feature': 'specific', 'score': -0.0013611914394935848}]},
 1: {'confidence': 0.00044272540371258136,
     'explanation': [{'feature': 'Bible', 'score': -0.007407412195931125},
                     {'feature': 'Scripture', 'score': -0.003658367757678809},
                     {'feature': 'Heaven', 'score': -0.003652181996607397},
                     {'feature': 'immoral', 'score': -0.003469502264458387},
                     {'feature': 'Sin', 'score': -0.003246609821338066}]},
 2: {'confidence': 0.9994892924728337,
     'explanation': [{'feature': 'Bible', 'score': 0.009736539971486623},
                     {'feature': 'Scripture', 'score': 0.005124375636024145},
                     {'feature': 'Heaven', 'score': 0.005053514624616295},
                     {'feature': 'immoral', 'score': 0.004781252244149238},
                     {'feature': 'Sin', 'score': 0.004596128058053568}]}}

Step 6: Save and load the explainer

As with the LIME tabular explainer, we can save and load the explainer via save_explainer and load_explainer, respectively.

[8]:
# Save the explainer to disk, creating the artefacts/ directory if needed
import os
os.makedirs('artefacts', exist_ok=True)

explainer.save_explainer('artefacts/lime_text.pkl')
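The tutorial does not show what save_explainer writes to the .pkl file, so we make no assumption about whether it also stores the vectorizer and classifier behind predict_fn. If you need to restore those in a fresh session, one common approach is to bundle them in a scikit-learn Pipeline, whose predict_proba accepts raw strings and can therefore serve directly as a predict_fn, and pickle it. The toy corpus and the file name model.pkl below are hypothetical.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Hypothetical toy corpus standing in for 20newsgroups
texts = ['the pitcher threw a strike', 'the batter hit a home run',
         'the doctor treated the patient', 'the nurse checked the kidney']
labels = [0, 0, 1, 1]

# Vectorizer and classifier in one object: predict_proba takes raw strings
pipeline = make_pipeline(TfidfVectorizer(), MultinomialNB(alpha=0.1))
pipeline.fit(texts, labels)

# Round-trip the fitted pipeline through a pickle file
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model.pkl')
    with open(path, 'wb') as f:
        pickle.dump(pipeline, f)
    with open(path, 'rb') as f:
        restored = pickle.load(f)

query = ['the pitcher saw the doctor']
# The restored pipeline reproduces the original probabilities
print(np.allclose(pipeline.predict_proba(query), restored.predict_proba(query)))  # True
```

After loading, restored.predict_proba could be passed to build_explainer in place of the hand-written predict_fn, since it performs the same vectorize-then-classify steps.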
[9]:
# Load the saved explainer in a new Explainer instance

new_explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TEXT, algorithm=xai.ALG.LIME)
new_explainer.load_explainer('artefacts/lime_text.pkl')

exp = new_explainer.explain_instance(
    instance=raw_test.data[9],
    labels=[0, 1, 2],
    num_features=5
)

print('Label', raw_train.target_names[raw_test.target[9]], raw_test.target[9])
pprint(exp)
Label soc.religion.christian 2
{0: {'confidence': 6.798212345437472e-05,
     'explanation': [{'feature': 'Bible', 'score': -0.002291036085092343},
                     {'feature': 'Heaven', 'score': -0.001386727909779096},
                     {'feature': 'babies', 'score': -0.0013482141842248723},
                     {'feature': 'Scripture', 'score': -0.0012967367558917526},
                     {'feature': 'infants', 'score': -0.0012887203369136644}]},
 1: {'confidence': 0.00044272540371258136,
     'explanation': [{'feature': 'Bible', 'score': -0.007441841401906927},
                     {'feature': 'Heaven', 'score': -0.003699731572404996},
                     {'feature': 'Scripture', 'score': -0.003493032440072657},
                     {'feature': 'God', 'score': -0.0030701936621817727},
                     {'feature': 'doctrine', 'score': -0.003026287136219051}]},
 2: {'confidence': 0.9994892924728337,
     'explanation': [{'feature': 'Bible', 'score': 0.009764693171821786},
                     {'feature': 'Heaven', 'score': 0.0051058553475867505},
                     {'feature': 'Scripture', 'score': 0.00481801754635917},
                     {'feature': 'God', 'score': 0.004325649143393945},
                     {'feature': 'doctrine', 'score': 0.00424415351934624}]}}
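The reloaded explainer reports exactly the same confidences as before, since those come from the unchanged Naive Bayes model, but the token scores differ slightly and some tokens change (for example, babies and infants now appear for class 0). LIME fits its surrogate model on randomly sampled perturbations, so two calls on the same instance generally draw different samples. The np.random.seed call in Step 1 seeds NumPy's global state once; if the explainer draws from that state, this makes a full top-to-bottom run repeatable but does not make two calls within a run identical. The sketch below illustrates the difference with a hypothetical mask sampler. Whether Contextual AI exposes a seed for its text explainer is not shown in this tutorial; the lime package's own LimeTextExplainer accepts a random_state argument.

```python
import numpy as np


def sample_masks(num_words, num_samples, rng):
    """Hypothetical sampler: each row marks which words a perturbation keeps."""
    masks = np.ones((num_samples, num_words), dtype=int)
    for row in masks:
        n_remove = rng.integers(1, num_words)  # drop 1..num_words-1 words
        row[rng.choice(num_words, size=n_remove, replace=False)] = 0
    return masks


# One shared generator: consecutive calls continue the random stream
shared = np.random.default_rng(123456)
first = sample_masks(10, 50, shared)
second = sample_masks(10, 50, shared)

# A fresh generator with the same seed for each call: identical samples
again_a = sample_masks(10, 50, np.random.default_rng(123456))
again_b = sample_masks(10, 50, np.random.default_rng(123456))

print(np.array_equal(first, second))     # False: the second call drew new samples
print(np.array_equal(again_a, again_b))  # True
```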