Contextual AI LIME Text Explainer with Keras

This tutorial is similar to lime_text_explainer.ipynb, but instead of a Naive Bayes model we generate explanations for a neural network implemented in Keras.

The neural network is a simple multi-layer CNN with GloVe embeddings. This tutorial requires you to download the pre-trained word embeddings from this link (caution: this link initiates an 822MB download).

The modelling/text processing portions of this tutorial are heavily borrowed from this Keras blog.

As with other Contextual AI tutorials, the main steps for generating explanations are as follows (a minimal sketch of this flow appears after the list):

  1. Get an explainer via the ExplainerFactory class

  2. Build the text explainer

  3. Call explain_instance
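In code, the overall flow looks roughly like this (a minimal sketch; predict_fn is a placeholder for the model's prediction function, which we define in Step 4 below):

import xai
from xai.explainer import ExplainerFactory

# 1. Get a text explainer via the factory
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TEXT)

# 2. Build the explainer around the model's prediction function
explainer.build_explainer(predict_fn)

# 3. Explain a single raw-text instance
explanation = explainer.explain_instance(
    labels=[0, 1, 2],
    instance='some raw document text',
    num_features=10
)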

Step 1: Import libraries

[1]:
# Some auxiliary imports for the tutorial
import os
import sys
import random
import math
import numpy as np
from pprint import pprint
from sklearn import datasets
from sklearn.model_selection import train_test_split

import keras
from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Embedding, Layer, Activation, \
    Conv1D, MaxPooling1D, Convolution1D, Dropout, BatchNormalization, Concatenate, Flatten
from keras.optimizers import Adam
from keras.initializers import Constant
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from sklearn.utils import class_weight

# Set seed for reproducibility
np.random.seed(123456)

# Set the path so that we can import the ExplainerFactory
sys.path.append('../../')

# Main Contextual AI imports
import xai
from xai.explainer import ExplainerFactory

###################################################
# Set the directory to the GloVe embeddings here! #
###################################################
GLOVE_DIR = ''
MAX_SEQUENCE_LENGTH = 1000
MAX_NUM_WORDS = 20000
EMBEDDING_DIM = 100
VALIDATION_SPLIT = 0.2
HIDDEN_UNITS = 128
Using TensorFlow backend.

Step 2: Load dataset and train a model

In this tutorial, we rely on the 20newsgroups text dataset, which can be loaded via sklearn’s dataset utility. Documentation on the dataset itself can be found here. To keep things simple, we will extract data for three topics: baseball, Christianity, and medicine.

Our target model is a CNN that ingests pre-trained word embeddings.

[2]:
# Train on a subset of categories

categories = [
    'rec.sport.baseball',
    'soc.religion.christian',
    'sci.med'
]

raw_train = datasets.fetch_20newsgroups(subset='train', categories=categories)
print(list(raw_train.keys()))
print(raw_train.target_names)
print(raw_train.target[:10])
raw_test = datasets.fetch_20newsgroups(subset='test', categories=categories)

# Turn text into lowercase
raw_train_text = [doc.lower() for doc in raw_train.data]
y_train = raw_train.target
raw_test_text = [doc.lower() for doc in raw_test.data]
y_test = raw_test.target

# Character-level tokenizer with an out-of-vocabulary token
tokenizer = Tokenizer(num_words=None, char_level=True, oov_token='UNK')
tokenizer.fit_on_texts(raw_train_text)
vocab_size = len(tokenizer.word_index)
word_index = tokenizer.word_index

# Convert strings to sequences of indices
train_sequences = tokenizer.texts_to_sequences(raw_train_text)
test_sequences = tokenizer.texts_to_sequences(raw_test_text)

# Padding
train_data = pad_sequences(train_sequences, maxlen=MAX_SEQUENCE_LENGTH, padding='post')
test_data = pad_sequences(test_sequences, maxlen=MAX_SEQUENCE_LENGTH, padding='post')

# Convert to numpy array
X_train = np.array(train_data, dtype='float32')
X_test = np.array(test_data, dtype='float32')

X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2)

y_train_onehot = to_categorical(y_train, num_classes=3)
y_valid_onehot = to_categorical(y_valid, num_classes=3)
y_test_onehot = to_categorical(y_test, num_classes=3)
['DESCR', 'target', 'data', 'filenames', 'target_names']
['rec.sport.baseball', 'sci.med', 'soc.religion.christian']
[1 0 2 2 0 2 0 0 0 1]
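As a quick optional check (assuming the arrays defined above), the resulting shapes should line up with the 1,432-sample training set and 358-sample validation set reported during training below:

print(X_train.shape, X_valid.shape, X_test.shape)
print(y_train_onehot.shape, y_valid_onehot.shape, y_test_onehot.shape)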

Prepare the embedding matrix

[3]:
# Prepare the embedding matrix
# Code comes from https://keras.io/examples/pretrained_word_embeddings/
embeddings_index = {}
with open(os.path.join(GLOVE_DIR, 'glove.6B.100d.txt'), encoding='utf-8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

# prepare embedding matrix
num_words = min(MAX_NUM_WORDS, len(word_index)) + 1
embedding_matrix = np.zeros((num_words, EMBEDDING_DIM))
for word, i in word_index.items():
    if i > MAX_NUM_WORDS:
        continue
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        # words not found in embedding index will be all-zeros.
        embedding_matrix[i] = embedding_vector

# load pre-trained word embeddings into an Embedding layer
# note that we set trainable = False so as to keep the embeddings fixed
embedding_layer = Embedding(num_words,
                            EMBEDDING_DIM,
                            embeddings_initializer=Constant(embedding_matrix),
                            input_length=MAX_SEQUENCE_LENGTH,
                            trainable=False)
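A quick diagnostic (a sketch, assuming embeddings_index and word_index from the cells above) shows how many tokens in the tokenizer's vocabulary received a pre-trained GloVe vector; tokens without a match keep the all-zero rows initialized above:

# Count vocabulary entries that have a corresponding GloVe vector
hits = sum(1 for token in word_index if token in embeddings_index)
print('Tokens with a GloVe vector: {} / {}'.format(hits, len(word_index)))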

Define the model

[4]:
# Prepare the model

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(128, 5, activation='relu')(embedded_sequences)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Dropout(0.2)(x)
x = MaxPooling1D(4)(x)
x = Conv1D(128, 5)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Dropout(0.2)(x)
x = MaxPooling1D(4)(x)
x = Conv1D(128, 5)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Dropout(0.2)(x)
x = Flatten()(x)

preds = Dense(y_train_onehot.shape[1], activation='softmax')(x)

model = Model(sequence_input, preds)
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['acc'])
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 1000)              0
_________________________________________________________________
embedding_1 (Embedding)      (None, 1000, 100)         8100
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 996, 128)          64128
_________________________________________________________________
batch_normalization_1 (Batch (None, 996, 128)          512
_________________________________________________________________
activation_1 (Activation)    (None, 996, 128)          0
_________________________________________________________________
dropout_1 (Dropout)          (None, 996, 128)          0
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 249, 128)          0
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 245, 128)          82048
_________________________________________________________________
batch_normalization_2 (Batch (None, 245, 128)          512
_________________________________________________________________
activation_2 (Activation)    (None, 245, 128)          0
_________________________________________________________________
dropout_2 (Dropout)          (None, 245, 128)          0
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 61, 128)           0
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 57, 128)           82048
_________________________________________________________________
batch_normalization_3 (Batch (None, 57, 128)           512
_________________________________________________________________
activation_3 (Activation)    (None, 57, 128)           0
_________________________________________________________________
dropout_3 (Dropout)          (None, 57, 128)           0
_________________________________________________________________
flatten_1 (Flatten)          (None, 7296)              0
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 21891
=================================================================
Total params: 259,751
Trainable params: 250,883
Non-trainable params: 8,868
_________________________________________________________________

Train the model

[5]:
model.fit([X_train], y_train_onehot, epochs=100, batch_size=50,
          validation_data=([X_valid], y_valid_onehot))
Train on 1432 samples, validate on 358 samples
Epoch 1/100
1432/1432 [==============================] - 5s 3ms/step - loss: 2.1773 - acc: 0.3897 - val_loss: 2.4187 - val_acc: 0.3743
Epoch 2/100
1432/1432 [==============================] - 0s 264us/step - loss: 1.1169 - acc: 0.5740 - val_loss: 1.2449 - val_acc: 0.4749
Epoch 3/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.7730 - acc: 0.6899 - val_loss: 0.7144 - val_acc: 0.6732
Epoch 4/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.4197 - acc: 0.8156 - val_loss: 3.1116 - val_acc: 0.3575
Epoch 5/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.2677 - acc: 0.8897 - val_loss: 0.5152 - val_acc: 0.7821
Epoch 6/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.2428 - acc: 0.9064 - val_loss: 0.9468 - val_acc: 0.6732
Epoch 7/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.1793 - acc: 0.9385 - val_loss: 3.4665 - val_acc: 0.5754
Epoch 8/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.1346 - acc: 0.9497 - val_loss: 3.2727 - val_acc: 0.3994
Epoch 9/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.0958 - acc: 0.9728 - val_loss: 6.0987 - val_acc: 0.3855
Epoch 10/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0677 - acc: 0.9777 - val_loss: 4.9628 - val_acc: 0.3939
Epoch 11/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0732 - acc: 0.9742 - val_loss: 3.5400 - val_acc: 0.4302
Epoch 12/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0802 - acc: 0.9728 - val_loss: 1.5637 - val_acc: 0.6061
Epoch 13/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.1385 - acc: 0.9567 - val_loss: 8.2588 - val_acc: 0.3575
Epoch 14/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0578 - acc: 0.9818 - val_loss: 4.9207 - val_acc: 0.6117
Epoch 15/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0652 - acc: 0.9825 - val_loss: 3.7316 - val_acc: 0.5615
Epoch 16/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0436 - acc: 0.9853 - val_loss: 1.5448 - val_acc: 0.6508
Epoch 17/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0824 - acc: 0.9804 - val_loss: 5.0774 - val_acc: 0.6006
Epoch 18/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.0448 - acc: 0.9853 - val_loss: 7.6370 - val_acc: 0.3855
Epoch 19/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0578 - acc: 0.9797 - val_loss: 4.4715 - val_acc: 0.4385
Epoch 20/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.0439 - acc: 0.9846 - val_loss: 5.1295 - val_acc: 0.5084
Epoch 21/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.0285 - acc: 0.9881 - val_loss: 7.6370 - val_acc: 0.3715
Epoch 22/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0509 - acc: 0.9846 - val_loss: 5.2909 - val_acc: 0.4972
Epoch 23/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0324 - acc: 0.9860 - val_loss: 3.9168 - val_acc: 0.5475
Epoch 24/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0345 - acc: 0.9909 - val_loss: 5.7225 - val_acc: 0.3966
Epoch 25/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0425 - acc: 0.9867 - val_loss: 2.1274 - val_acc: 0.6899
Epoch 26/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0276 - acc: 0.9881 - val_loss: 1.1043 - val_acc: 0.7514
Epoch 27/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0339 - acc: 0.9874 - val_loss: 2.7648 - val_acc: 0.5866
Epoch 28/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0335 - acc: 0.9881 - val_loss: 0.9610 - val_acc: 0.7905
Epoch 29/100
1432/1432 [==============================] - 0s 220us/step - loss: 0.0227 - acc: 0.9930 - val_loss: 5.7261 - val_acc: 0.5782
Epoch 30/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0163 - acc: 0.9944 - val_loss: 10.2775 - val_acc: 0.3436
Epoch 31/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0206 - acc: 0.9930 - val_loss: 1.2157 - val_acc: 0.7542
Epoch 32/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0239 - acc: 0.9923 - val_loss: 5.4944 - val_acc: 0.4385
Epoch 33/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0108 - acc: 0.9951 - val_loss: 1.3424 - val_acc: 0.7905
Epoch 34/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0397 - acc: 0.9881 - val_loss: 4.4665 - val_acc: 0.5112
Epoch 35/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0212 - acc: 0.9923 - val_loss: 2.5467 - val_acc: 0.6788
Epoch 36/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0455 - acc: 0.9860 - val_loss: 5.6098 - val_acc: 0.4860
Epoch 37/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0227 - acc: 0.9930 - val_loss: 5.3960 - val_acc: 0.4553
Epoch 38/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0435 - acc: 0.9888 - val_loss: 7.5935 - val_acc: 0.3687
Epoch 39/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0309 - acc: 0.9902 - val_loss: 1.2254 - val_acc: 0.7514
Epoch 40/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0143 - acc: 0.9937 - val_loss: 1.7116 - val_acc: 0.6983
Epoch 41/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0425 - acc: 0.9895 - val_loss: 0.3884 - val_acc: 0.8883
Epoch 42/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0233 - acc: 0.9923 - val_loss: 1.0631 - val_acc: 0.7626
Epoch 43/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0224 - acc: 0.9944 - val_loss: 0.7163 - val_acc: 0.8575
Epoch 44/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0209 - acc: 0.9937 - val_loss: 2.8861 - val_acc: 0.6564
Epoch 45/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0717 - acc: 0.9818 - val_loss: 2.7609 - val_acc: 0.6955
Epoch 46/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0056 - acc: 0.9972 - val_loss: 0.9425 - val_acc: 0.8128
Epoch 47/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0185 - acc: 0.9958 - val_loss: 0.3938 - val_acc: 0.8966
Epoch 48/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0162 - acc: 0.9944 - val_loss: 0.8557 - val_acc: 0.8296
Epoch 49/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0204 - acc: 0.9916 - val_loss: 2.1519 - val_acc: 0.7123
Epoch 50/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0042 - acc: 0.9993 - val_loss: 3.7780 - val_acc: 0.6313
Epoch 51/100
1432/1432 [==============================] - 0s 221us/step - loss: 0.0214 - acc: 0.9930 - val_loss: 3.6579 - val_acc: 0.6313
Epoch 52/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0263 - acc: 0.9916 - val_loss: 0.7616 - val_acc: 0.8268
Epoch 53/100
1432/1432 [==============================] - 0s 221us/step - loss: 0.0056 - acc: 0.9993 - val_loss: 4.4331 - val_acc: 0.5615
Epoch 54/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0109 - acc: 0.9965 - val_loss: 3.3715 - val_acc: 0.6173
Epoch 55/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0337 - acc: 0.9888 - val_loss: 3.3990 - val_acc: 0.6480
Epoch 56/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0045 - acc: 0.9979 - val_loss: 7.0738 - val_acc: 0.4413
Epoch 57/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0191 - acc: 0.9951 - val_loss: 2.2512 - val_acc: 0.7151
Epoch 58/100
1432/1432 [==============================] - 0s 215us/step - loss: 0.0170 - acc: 0.9930 - val_loss: 2.9241 - val_acc: 0.6480
Epoch 59/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0213 - acc: 0.9916 - val_loss: 5.3108 - val_acc: 0.6089
Epoch 60/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0032 - acc: 0.9993 - val_loss: 1.3068 - val_acc: 0.7961
Epoch 61/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0233 - acc: 0.9909 - val_loss: 1.2623 - val_acc: 0.7207
Epoch 62/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0146 - acc: 0.9944 - val_loss: 1.4777 - val_acc: 0.7402
Epoch 63/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0070 - acc: 0.9972 - val_loss: 4.0222 - val_acc: 0.5196
Epoch 64/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0075 - acc: 0.9979 - val_loss: 1.5516 - val_acc: 0.7709
Epoch 65/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0134 - acc: 0.9930 - val_loss: 1.0965 - val_acc: 0.7514
Epoch 66/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0249 - acc: 0.9909 - val_loss: 3.6485 - val_acc: 0.6480
Epoch 67/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0281 - acc: 0.9923 - val_loss: 2.7969 - val_acc: 0.5838
Epoch 68/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0053 - acc: 0.9986 - val_loss: 5.0894 - val_acc: 0.5503
Epoch 69/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0095 - acc: 0.9972 - val_loss: 3.8370 - val_acc: 0.5782
Epoch 70/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0107 - acc: 0.9972 - val_loss: 1.3765 - val_acc: 0.7877
Epoch 71/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0075 - acc: 0.9965 - val_loss: 2.2304 - val_acc: 0.7346
Epoch 72/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0134 - acc: 0.9979 - val_loss: 1.8926 - val_acc: 0.7179
Epoch 73/100
1432/1432 [==============================] - 0s 222us/step - loss: 0.0222 - acc: 0.9944 - val_loss: 3.0531 - val_acc: 0.6732
Epoch 74/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0099 - acc: 0.9958 - val_loss: 0.7451 - val_acc: 0.8324
Epoch 75/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0103 - acc: 0.9951 - val_loss: 4.3561 - val_acc: 0.5615
Epoch 76/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0070 - acc: 0.9965 - val_loss: 3.4871 - val_acc: 0.6844
Epoch 77/100
1432/1432 [==============================] - 0s 221us/step - loss: 0.0027 - acc: 0.9993 - val_loss: 2.5512 - val_acc: 0.7291
Epoch 78/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0108 - acc: 0.9958 - val_loss: 1.6322 - val_acc: 0.7207
Epoch 79/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0177 - acc: 0.9937 - val_loss: 1.0522 - val_acc: 0.8101
Epoch 80/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0197 - acc: 0.9958 - val_loss: 1.4012 - val_acc: 0.7793
Epoch 81/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0280 - acc: 0.9916 - val_loss: 1.6188 - val_acc: 0.7486
Epoch 82/100
1432/1432 [==============================] - 0s 221us/step - loss: 0.0038 - acc: 0.9986 - val_loss: 0.7318 - val_acc: 0.8464
Epoch 83/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0033 - acc: 0.9986 - val_loss: 1.8894 - val_acc: 0.7514
Epoch 84/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0074 - acc: 0.9972 - val_loss: 0.8990 - val_acc: 0.8603
Epoch 85/100
1432/1432 [==============================] - 0s 219us/step - loss: 0.0202 - acc: 0.9951 - val_loss: 5.4534 - val_acc: 0.6117
Epoch 86/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0212 - acc: 0.9937 - val_loss: 4.7464 - val_acc: 0.6145
Epoch 87/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0072 - acc: 0.9979 - val_loss: 6.7823 - val_acc: 0.4581
Epoch 88/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0101 - acc: 0.9965 - val_loss: 9.6006 - val_acc: 0.3603
Epoch 89/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0094 - acc: 0.9986 - val_loss: 3.7251 - val_acc: 0.6816
Epoch 90/100
1432/1432 [==============================] - 0s 214us/step - loss: 0.0062 - acc: 0.9979 - val_loss: 3.1987 - val_acc: 0.6844
Epoch 91/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0092 - acc: 0.9979 - val_loss: 3.2621 - val_acc: 0.6760
Epoch 92/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0285 - acc: 0.9916 - val_loss: 7.5566 - val_acc: 0.3939
Epoch 93/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0016 - acc: 1.0000 - val_loss: 3.0771 - val_acc: 0.5894
Epoch 94/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0166 - acc: 0.9965 - val_loss: 2.3383 - val_acc: 0.7011
Epoch 95/100
1432/1432 [==============================] - 0s 218us/step - loss: 0.0101 - acc: 0.9965 - val_loss: 0.5835 - val_acc: 0.8715
Epoch 96/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0020 - acc: 0.9993 - val_loss: 1.6642 - val_acc: 0.7849
Epoch 97/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0074 - acc: 0.9986 - val_loss: 5.3298 - val_acc: 0.4665
Epoch 98/100
1432/1432 [==============================] - 0s 216us/step - loss: 0.0143 - acc: 0.9951 - val_loss: 2.5091 - val_acc: 0.6760
Epoch 99/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0180 - acc: 0.9937 - val_loss: 2.2953 - val_acc: 0.7095
Epoch 100/100
1432/1432 [==============================] - 0s 217us/step - loss: 0.0145 - acc: 0.9972 - val_loss: 0.4745 - val_acc: 0.8911
[5]:
<keras.callbacks.History at 0x7fb629d771d0>
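The validation accuracy fluctuates considerably across epochs, so before explaining individual predictions it is worth checking held-out performance (a minimal sketch, assuming the model and test arrays prepared earlier):

# Evaluate on the held-out test set
test_loss, test_acc = model.evaluate(X_test, y_test_onehot, verbose=0)
print('Test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_acc))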

Step 3: Instantiate the explainer

Here, we use the ExplainerFactory to get a LIME text explainer.

[6]:
explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TEXT)

Step 4: Build the explainer

This initializes the underlying explainer object. We provide the explain_instance method below with raw text; LIME’s text explainer algorithm conducts its own preprocessing in order to generate interpretable representations of the data. Hence we must define a custom predict_fn which takes a raw piece of text, vectorizes it using the trained tokenizer, and passes the resulting array to the Keras model to generate class probabilities. LIME uses predict_fn to query our neural network in order to learn its behavior around the provided data instance.

[7]:
def predict_fn(instance):
    # Convert string to index
    sequence = tokenizer.texts_to_sequences(instance)

    # Padding
    data = pad_sequences(sequence, maxlen=MAX_SEQUENCE_LENGTH, padding='post')

    # Convert to numpy array
    arr = np.array(data, dtype='float32')

    return model.predict([arr])

explainer.build_explainer(predict_fn)
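Before generating explanations, it can help to sanity-check predict_fn on a raw document: it should return one row of class probabilities per input string, with each row summing to approximately 1 (a quick optional check, assuming the objects defined above):

probs = predict_fn([raw_train.data[100]])
print(probs.shape)        # expected: (1, 3)
print(probs.sum(axis=1))  # each row should be close to 1.0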

Step 5: Generate some explanations

[8]:
exp = explainer.explain_instance(
    labels=[0, 1, 2],
    instance=raw_train.data[100],
    num_features=10
)

print('Label', raw_train.target_names[raw_train.target[100]])
pprint(exp)
/experiments/venv_codebook_defenses/lib/python3.5/re.py:203: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
Label rec.sport.baseball
{0: {'confidence': 0.9999999,
     'explanation': [{'feature': 'game', 'score': 0.2595229194688601},
                     {'feature': 'again', 'score': 0.18668745076997575},
                     {'feature': 'Yankees', 'score': 0.1860972771358493},
                     {'feature': 'pitches', 'score': 0.15736066007125038},
                     {'feature': 'Liberalizer', 'score': 0.1347915665789044},
                     {'feature': 'can', 'score': 0.12968895498952704},
                     {'feature': 'think', 'score': 0.11919896484535476},
                     {'feature': 'am', 'score': 0.11277455237479057},
                     {'feature': 'I', 'score': -0.05480655587778587},
                     {'feature': 'believe', 'score': -0.043605727115050126}]},
 1: {'confidence': 2.9914535e-09,
     'explanation': [{'feature': 'game', 'score': -0.001329492360436989},
                     {'feature': 'going', 'score': -0.0012736316324735816},
                     {'feature': 'the', 'score': -0.0011308983744577278},
                     {'feature': 'this', 'score': -0.0009705440718428874},
                     {'feature': 'Allegheny', 'score': 0.0008826584173278725},
                     {'feature': 'College', 'score': -0.0008800347780071557},
                     {'feature': 'Yankees', 'score': 0.0008542046802342881},
                     {'feature': 'can', 'score': 0.0007478587703458676},
                     {'feature': 'it', 'score': 0.0007078390948858462},
                     {'feature': 'believe', 'score': -0.0006601853762163212}]},
 2: {'confidence': 7.7463795e-08,
     'explanation': [{'feature': 'game', 'score': -0.2580333837550864},
                     {'feature': 'Yankees', 'score': -0.18679854511492092},
                     {'feature': 'again', 'score': -0.1860578505206049},
                     {'feature': 'pitches', 'score': -0.15752556251321137},
                     {'feature': 'Liberalizer', 'score': -0.13402499305687204},
                     {'feature': 'can', 'score': -0.13028440845443395},
                     {'feature': 'think', 'score': -0.11928233701465708},
                     {'feature': 'am', 'score': -0.11278246486555822},
                     {'feature': 'I', 'score': 0.054717479221668884},
                     {'feature': 'believe', 'score': 0.04448656555204294}]}}

Step 6: Save and load the explainer

As with the LIME tabular explainer, we can save and load the explainer via save_explainer and load_explainer respectively.

[9]:
# Save the explainer somewhere

explainer.save_explainer('artefacts/lime_text_keras.pkl')
[10]:
# Load the saved explainer in a new ExplainerFactory instance

new_explainer = ExplainerFactory.get_explainer(domain=xai.DOMAIN.TEXT, algorithm=xai.ALG.LIME)
new_explainer.load_explainer('artefacts/lime_text_keras.pkl')

exp = new_explainer.explain_instance(
    instance=raw_train.data[20],
    labels=[0, 1, 2],
    num_features=5
)

print('Label', raw_train.target_names[raw_train.target[20]])
pprint(exp)
/experiments/venv_codebook_defenses/lib/python3.5/re.py:203: FutureWarning: split() requires a non-empty pattern match.
  return _compile(pattern, flags).split(string, maxsplit)
Label rec.sport.baseball
{0: {'confidence': 1.0,
     'explanation': [{'feature': 'baseball', 'score': 0.20855838996714443},
                     {'feature': 'stadium', 'score': 0.13433699819432926},
                     {'feature': 'football', 'score': 0.07024692251378775},
                     {'feature': 'in', 'score': -0.031260284138860533},
                     {'feature': 'with', 'score': -0.03063213505227814}]},
 1: {'confidence': 3.4470056e-13,
     'explanation': [{'feature': 'baseball', 'score': -0.009481833034551843},
                     {'feature': 'the', 'score': -0.007167239192790489},
                     {'feature': 'multipurpose', 'score': 0.004503424932271883},
                     {'feature': 'It', 'score': 0.004496400507397244},
                     {'feature': 'let', 'score': 0.004477692674225241}]},
 2: {'confidence': 3.5067134e-14,
     'explanation': [{'feature': 'baseball', 'score': -0.1942792635465446},
                     {'feature': 'stadium', 'score': -0.12629712306511712},
                     {'feature': 'football', 'score': -0.06171824737116527},
                     {'feature': 'with', 'score': 0.0384714317863913},
                     {'feature': 'play', 'score': -0.030882082530022472}]}}