Source code for model.interpreter.explanation_aggregator

#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2019 SAP SE or an SAP affiliate company. All rights reserved
# ============================================================================

import operator
from collections import defaultdict, Counter
from typing import Dict, Tuple, List, Set

from xai.explainer.constants import OUTPUT
from xai.model.interpreter.exceptions import InvalidExplanationFormat, \
    MutipleScoresFoundForSameFeature, UnsupportedMethodType, \
    InvalidArgumentError


################################################################################
### Explanation Aggregator
################################################################################
[docs]class ExplanationAggregator: """ Class for explanation aggregator. It aggregates the explanations based on classes, feature and scores. """ def __init__(self, confidence_threshold=0.8): self._explanation_list = defaultdict(list) self._total_count = 0 self._class_counter = defaultdict(int) self._confidence_threshold = confidence_threshold
[docs] def get_feature_names(self, list_explanations: List[Dict]) -> Set: """ Get feature names for an explanation, plus schema validation Args: list_explanations (list): List of explanations Returns: (set) feature names """ feature_names = set() for item in list_explanations: if type(item) != dict: raise InvalidExplanationFormat(item) if OUTPUT.FEATURE not in item.keys(): raise InvalidExplanationFormat(item) if type(item[OUTPUT.FEATURE]) != str: raise InvalidExplanationFormat(item) if item[OUTPUT.FEATURE] in feature_names: raise MutipleScoresFoundForSameFeature(item[OUTPUT.FEATURE], list_explanations) else: feature_names.add(item[OUTPUT.FEATURE]) if OUTPUT.SCORE not in item.keys(): raise InvalidExplanationFormat(item) if type(item[OUTPUT.SCORE]) != float: raise InvalidExplanationFormat(item) return feature_names
[docs] def feed(self, explanation: Dict): """ Feed explanation into the aggregator for further analysis Args: explanation: dict, the pre-defined format as the output in `xai.explainer.utils.explanation_to_json` """ if OUTPUT.EXPLANATION in explanation: # Regression schema _ = self.get_feature_names(explanation[OUTPUT.EXPLANATION]) else: # Classification schema for _label, _exp in explanation.items(): if type(_exp) != dict: raise InvalidExplanationFormat(_exp) if OUTPUT.EXPLANATION not in _exp.keys(): raise InvalidExplanationFormat(_exp) if type(_exp[OUTPUT.EXPLANATION]) != list: raise InvalidExplanationFormat(_exp) _ = self.get_feature_names(_exp[OUTPUT.EXPLANATION]) if OUTPUT.EXPLANATION in explanation: # Regression schema # To follow downstream schema, we set the "label" of regression prediction to 'NA' _label = 'NA' if explanation[OUTPUT.PREDICTION] > self._confidence_threshold: self._explanation_list[_label].append( {item[OUTPUT.FEATURE]: item[OUTPUT.SCORE] for item in explanation[OUTPUT.EXPLANATION]}) self._class_counter[_label] += 1 else: # Classification schema for _label, _exp in explanation.items(): if _exp[OUTPUT.PREDICTION] > self._confidence_threshold: self._explanation_list[_label].append( {item[OUTPUT.FEATURE]: item[OUTPUT.SCORE] for item in _exp[OUTPUT.EXPLANATION]}) self._class_counter[_label] += 1 self._total_count += 1
[docs] def get_statistics(self, stats_type: str = 'top_k', k: int = 5) -> Tuple[Dict[int, Dict], int]: """ return statistics of explanations in the aggregator based on the type Args: stats_type: str, not None. The pre-defined types of statistics. For now, it supports 3 types: - top_k: how often a feature appears in the top K features in the explanation - average_score: average score for each feature in the explanation - average_ranking: average ranking for each feature in the explanation Default type is `top_k`. k: int, not None. the k value for `top_k` method and `average_ranking`. It will be ignored if the stats type are not `top_k` or `average_ranking`. Default value of k is 5. Returns: A dictionary maps the label to its aggregated statistics. An integer to indicate the total number of explanations to generate the statistics. """ if stats_type not in ['top_k', 'average_score', 'average_ranking']: raise UnsupportedMethodType(stats_type) if type(k) != int: raise InvalidArgumentError('k', '<int>') if self._total_count == 0: return dict(), 0 label_counter = defaultdict(Counter) for _label, _exp_list in self._explanation_list.items(): for _exp in _exp_list: if stats_type == 'top_k': top_k_list = sorted(_exp.items(), key=operator.itemgetter(1), reverse=True)[:k] _exp_counter = Counter({feature_name: 1 for feature_name, _ in top_k_list}) elif stats_type == 'average_score': _exp_counter = Counter(_exp) elif stats_type == 'average_ranking': top_k_list = sorted(_exp.items(), key=operator.itemgetter(1), reverse=True)[:k] _exp_counter = Counter( {feature_name: k - idx for idx, (feature_name, _) in enumerate(top_k_list)}) label_counter[_label].update(_exp_counter) stats = dict() for _label, _counter in label_counter.items(): stats[_label] = {name: (score / self._class_counter[_label]) for name, score in list(sorted(_counter.items(), key=operator.itemgetter(1), reverse=True))} return stats, self._total_count