
Explanations

mercury.explainability.explanations

anchors

AnchorsWithImportanceExplanation(explain_data, explanations, categorical={})

Bases: object

Extended Anchors Explanations

Parameters:

Name Type Description Default
explain_data DataFrame

A pandas DataFrame containing the observations for which an explanation has to be found.

required
explanations List

A list containing the results of computing the explanations for explain_data.

required
categorical dict

A dictionary whose keys are the categorical features and whose values are the possible categorical values.

{}
Source code in mercury/explainability/explanations/anchors.py
def __init__(
        self,
        explain_data: pd.DataFrame,
        explanations: TP.List,
        categorical: dict = {}
    ) -> None:
    self.explain_data = explain_data
    self.explanations = explanations
    self.categorical = categorical
interpret_explanations(n_important_features)

This method prints a report of the important features obtained.

Parameters:

Name Type Description Default
n_important_features int

The number of important features that will appear in the report.

required
Source code in mercury/explainability/explanations/anchors.py
def interpret_explanations(self, n_important_features: int) -> str:
    """
    This method prints a report of the important features obtained.

    Args:
        n_important_features:
            The number of important features that will appear in the report.
    """
    names = []
    explanations_found = [explan for explan in self.explanations if not isinstance(explan, str)]
    for expl in explanations_found:
        for name in expl.data['anchor']:
            # split without an argument splits by spaces, and in every item of
            # expl.data['anchor'] the first word refers to the feature name.
            if (
                (' = ' in name) or 
                ((len(self.categorical) > 0) and (name in [item for sublist in list(self.categorical.values()) for item in sublist]))
                ):
                names.append(name)
            else:
                names.append(' '.join(name[::-1].split('.', 1)[1][::-1].split()[:-1]))

    unique_names, count_names = np.unique(names, return_counts=True)
    top_feats = heapq.nlargest(n_important_features, count_names)
    print_values = ['The ', str(n_important_features), ' most common features are: ']
    unique_names_ordered = sorted(unique_names.tolist(), key=lambda x: count_names[unique_names.tolist().index(x)], reverse=True)
    count_names_ordered = sorted(count_names.tolist(), reverse=True)
    n_explanations = 0
    for unique_name, count_name in zip(unique_names_ordered[:n_important_features], count_names_ordered[:n_important_features]):
        if n_explanations == 0:
            print_values.append([unique_name, ' with a frequency of ', 
                str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) '])
        elif n_explanations == n_important_features - 1:
            print_values.append([' and ', unique_name, ' with a frequency of ', 
                str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) '])
        else:
            print_values.append([', ',unique_name, ' with a frequency of ', 
                str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) '])
        n_explanations += 1
    interpretation = ''.join(list(itertools.chain(*print_values)))
    print(interpretation)
    return interpretation
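A minimal usage sketch: the anchor results below are hand-built stand-ins that expose only the .data['anchor'] field read by interpret_explanations; in practice, explanations comes from the anchors explainer.

import pandas as pd
from types import SimpleNamespace

explain_data = pd.DataFrame({'age': [25, 47], 'job': ['clerk', 'manager']})

# Stand-ins for real anchor results; only .data['anchor'] is needed here.
anchor_results = [
    SimpleNamespace(data={'anchor': ['age > 30.00', 'job = manager']}),
    SimpleNamespace(data={'anchor': ['age > 30.00']}),
]

explanation = AnchorsWithImportanceExplanation(
    explain_data=explain_data,
    explanations=anchor_results,
    categorical={'job': ['clerk', 'manager']},
)
report = explanation.interpret_explanations(n_important_features=2)  # prints and returns the report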

clustering_tree_explanation

ClusteringTreeExplanation(tree, feature_names=None)

Explanation for ClusteringTreeExplainer. Represents a decision tree that explains a clustering algorithm. Calling the plot method generates a visualization of the decision tree (requires the graphviz package).

Parameters:

Name Type Description Default
tree Node

the fitted decision tree

required
feature_names List

the feature names used in the decision tree

None
Source code in mercury/explainability/explanations/clustering_tree_explanation.py
def __init__(
    self,
    tree: "Node",  # noqa: F821
    feature_names: List = None,
):
    self.tree = tree
    self.feature_names = feature_names
plot(filename='tree_explanation', feature_names=None, scalers=None)

Generates a graphviz.Source object representing the decision tree, which can be visualized in a notebook or saved in a file.

Parameters:

Name Type Description Default
filename str

filename to save if render() method is called over the returned object

'tree_explanation'
feature_names List

the feature names to use. If not specified, the feature names specified in the constructor are used.

None
scalers dict

Dictionary of scalers. If passed, the tree will show the denormalized value in the split instead of the normalized value. The key is the feature name, and the scaler must have the inverse_transform method.

None

Returns:

Type Description
Source

object representing the decision tree.

Source code in mercury/explainability/explanations/clustering_tree_explanation.py
def plot(self, filename: str = "tree_explanation", feature_names: List = None, scalers: dict = None):

    """
    Generates a graphviz.Source object representing the decision tree, which can be visualized in a notebook
    or saved in a file.

    Args:
        filename: filename to save if render() method is called over the returned object
        feature_names: the feature names to use. If not specified, the feature names specified in the constructor
            are used.
        scalers: dictionary of scalers. If passed, the tree will show the denormalized value in the split instead
            of the normalized value. The key is the feature name and the scaler must have the `inverse_transform`
            method

    Returns:
        (graphviz.Source): object representing the decision tree.
    """

    feature_names = self.feature_names if feature_names is None else feature_names
    scalers = {} if scalers is None else scalers

    if not graphviz_available:
        raise Exception("Required package is missing. Please install graphviz")

    if self.tree is not None:
        dot_str = ["digraph ClusteringTree {\n"]
        queue = [self.tree]
        nodes = []
        edges = []
        id = 0
        while len(queue) > 0:
            curr = queue.pop(0)
            if curr.is_leaf():
                label = "%s\nsamples=\%d\nmistakes=\%d" % (str(self._get_node_split_value(curr)), curr.samples, curr.mistakes) # noqa
            else:
                feature_name = curr.feature if feature_names is None else feature_names[curr.feature]
                condition = "%s <= %.3f" % (feature_name, self._get_node_split_value(curr, feature_name, scalers))
                label = "%s\nsamples=\%d" % (condition, curr.samples) # noqa
                queue.append(curr.left)
                queue.append(curr.right)
                edges.append((id, id + len(queue) - 1))
                edges.append((id, id + len(queue)))
            nodes.append({"id": id,
                          "label": label,
                          "node": curr})
            id += 1
        for node in nodes:
            dot_str.append("n_%d [label=\"%s\"];\n" % (node["id"], node["label"]))
        for edge in edges:
            dot_str.append("n_%d -> n_%d;\n" % (edge[0], edge[1]))
        dot_str.append("}")
        dot_str = "".join(dot_str)
        s = Source(dot_str, filename=filename + '.gv', format="png")
        return s
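A sketch of the typical flow, hedged: the ClusteringTreeExplainer import path, constructor arguments, and explain method below are assumptions for illustration, not the verified API; check the explainers documentation.

import pandas as pd
from sklearn.cluster import KMeans
# Assumed import path and API; adjust to the actual explainers module.
from mercury.explainability.explainers import ClusteringTreeExplainer

df = pd.DataFrame({'age': [22, 25, 61, 64], 'balance': [200, 300, 4000, 4200]})
kmeans = KMeans(n_clusters=2, n_init=10).fit(df)

explainer = ClusteringTreeExplainer(kmeans, max_leaves=4)  # hypothetical arguments
explanation = explainer.explain(df)                        # assumed method name

graph = explanation.plot(filename='clusters_tree', feature_names=list(df.columns))
graph            # a graphviz.Source; renders inline in a notebook
graph.render()   # writes clusters_tree.gv and clusters_tree.gv.png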

counter_factual

CounterfactualBasicExplanation(from_, to_, p, path, path_ps, bounds, explored=np.array([]), explored_ps=np.array([]), labels=[])

Bases: object

A basic counterfactual explanation.

Parameters:

Name Type Description Default
from_ ndarray

Starting point.

required
to_ ndarray

Found solution.

required
p float

Probability of found solution.

required
path ndarray

Path followed to get to the found solution.

required
path_ps ndarray

Probabilities of each path step.

required
bounds ndarray

Feature bounds used when exploring the probability space.

required
explored ndarray

Points explored but not visited (only available when the backtracking strategy is used; empty for Simulated Annealing).

array([])
explored_ps ndarray

Probabilities of the explored points (only available when the backtracking strategy is used; empty for Simulated Annealing).

array([])
labels Optional[List[str]]

Labels to be used for each point dimension (used when plotting).

[]

Raises:

Type Description
AssertionError

if from_ shape != to_.shape

AssertionError

if dim(from_) != 1

AssertionError

if not 0 <= p <= 1

AssertionError

if path.shape[0] != path_ps.shape[0]

AssertionError

if bounds.shape[0] != from_.shape[0]

AssertionError

if explored.shape[0] != explored_ps.shape[0]

AssertionError

if len(labels) > 0 and len(labels) != bounds.shape[0]

Source code in mercury/explainability/explanations/counter_factual.py
def __init__(self,
             from_: 'np.ndarray',
             to_: 'np.ndarray',
             p: float,
             path: 'np.ndarray',
             path_ps: 'np.ndarray',
             bounds: 'np.ndarray',
             explored: 'np.ndarray' = np.array([]),
             explored_ps: 'np.ndarray' = np.array([]),
             labels: TP.Optional[TP.List[str]] = []) -> None:
    # Initial/end points
    assert from_.shape == to_.shape and from_.ndim == 1, 'Invalid dimensions'
    self.from_ = from_
    self.to_ = to_

    # Found solution probability
    assert p >= 0 and p <= 1, 'Invalid probability'
    self.p = p

    # Path followed till solution is found
    assert path.shape[0] == path_ps.shape[0], \
        'Invalid shape for path probabilities, got {} but expected {}'.format(path.shape[0], path_ps.shape[0])
    self.path = path
    self.path_ps = path_ps

    # Used bounds in the solution
    assert bounds.shape[0] == self.from_.shape[0], 'Invalid bounds shape'
    self.bounds = bounds

    assert explored.shape[0] == explored_ps.shape[0], \
        'Invalid shape for explored probabilities, got {} but expected {}'.format(explored.shape[0],
                                                                                  explored_ps.shape[0])
    self.explored = explored

    if labels is not None and len(labels) > 0:
        assert len(labels) == self.bounds.shape[0], 'Invalid number of labels'
    self.labels = labels
__verbose()

Internal debug information.

Source code in mercury/explainability/explanations/counter_factual.py
def __verbose(self):  # pragma: no cover
    """ Internal debug information. """

    print('Used bounds:')
    for i in range(self.bounds.shape[0]):
        label = self.labels[i] if self.labels else ''
        print('\t[{}] {}: [{}, {}]'.format(i, label, self.bounds[i][0], self.bounds[i][1]))
    print('Starting point: {}'.format(self.from_))
    print('Found solution: {} with probability {}'.format(self.to_, self.p))
    print('Changes:')
    for i in range(self.from_.shape[0]):
        if self.from_[i] != self.to_[i]:
            label = self.labels[i] if self.labels else ''
            print('\t[{}] {}: {} -> {}'.format(i, label, self.from_[i], self.to_[i]))
get_changes(relative=True)

Returns relative/absolute changes between initial and ending point.

Parameters:

Name Type Description Default
relative bool

True for relative changes, False for absolute changes.

True

Returns:

Type Description
ndarray

(np.ndarray) Relative or absolute changes for each feature.

Source code in mercury/explainability/explanations/counter_factual.py
def get_changes(self, relative=True) -> 'np.ndarray':
    """
    Returns relative/absolute changes between initial and ending point.

    Args:
        relative (bool):
            True for relative changes, False for absolute changes.

    Returns:
        (np.ndarray) Relative or absolute changes for each feature.
    """

    if relative:
        # Avoid divs by zero
        aux = self.from_.copy()
        aux[aux == 0.] = 1.
        return (self.to_.squeeze() - self.from_.squeeze()) * 100 / (np.sign(aux) * aux.squeeze())
    else:
        return self.to_.squeeze() - self.from_.squeeze()
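A worked example with hand-built values (illustrative numbers only): a two-feature starting point whose balance drops from 1000 to 250 in the found counterfactual.

import numpy as np

exp = CounterfactualBasicExplanation(
    from_=np.array([25.0, 1000.0]),          # starting point
    to_=np.array([25.0, 250.0]),             # found solution
    p=0.83,                                  # probability of the found solution
    path=np.array([[25.0, 1000.0], [25.0, 600.0], [25.0, 250.0]]),
    path_ps=np.array([0.35, 0.61, 0.83]),
    bounds=np.array([[18.0, 90.0], [0.0, 5000.0]]),
    labels=['age', 'balance'],
)
exp.get_changes(relative=False)  # array([   0., -750.])
exp.get_changes()                # array([  0., -75.]) -> balance dropped by 75%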
show(figsize=(12, 6), debug=False, path=None, backend='matplotlib')

Creates a plot with the explanation.

Parameters:

Name Type Description Default
figsize tuple

Width and height of the figure (inches if matplotlib backend is used, pixels for bokeh backend).

(12, 6)
debug bool

Display verbose information (debug mode).

False
path Optional[str]

If given, the figure is saved to this path instead of being displayed.

None
backend str

Plotting backend to use: 'matplotlib' or 'bokeh'.

'matplotlib'
Source code in mercury/explainability/explanations/counter_factual.py
def show(self, figsize: TP.Tuple[int, int] = (12, 6), debug: bool = False,
         path: TP.Optional[str] = None, backend='matplotlib') -> None:  # pragma: no cover
    """
    Creates a plot with the explanation.

    Args:
        figsize (tuple):
            Width and height of the figure (inches if matplotlib backend is used,
            pixels for bokeh backend).
        debug (bool):
            Display verbose information (debug mode).
        path (str, optional):
            If given, the figure is saved to this path instead of being displayed.
        backend (str):
            Plotting backend to use: 'matplotlib' or 'bokeh'.
    """

    def _show(from_: 'np.ndarray', to_: 'np.ndarray', backend='matplotlib',
              path: TP.Optional[str] = None, debug: bool = False) -> None:
        """ Backend specific show method. """

        if backend == 'matplotlib':
            # It seems we can't decouple figure from axes
            fig = plt.figure(figsize=figsize)

            # LIME-like hbars showing relative differences
            ax = plt.subplot2grid((2, 5), (0, 1))
            CounterfactualBasicExplanation.plot_butterfly(
                self.get_changes(relative=False), self.labels, ax,
                title='Absolute delta')

            ax = plt.subplot2grid((2, 5), (0, 3))
            CounterfactualBasicExplanation.plot_butterfly(
                self.get_changes(relative=True), self.labels, ax,
                title='Relative delta')

            # Probabilities
            ax = plt.subplot2grid((2, 5), (1, 0), colspan=5)
            xs = np.arange(len(self.path_ps))
            ys = self.path_ps
            cax = ax.scatter(xs, ys, c=ys)
            fig.colorbar(cax)
            ax.plot(xs, ys, '--', c='k', linewidth=.2, alpha=.3)
            ax.grid()
            ax.set_title('Visited itinerary')
            ax.set_xlabel('# Iteration')
            ax.set_ylabel('probability')
            plt.tight_layout()

            if path is not None:
                plt.savefig(path, format='pdf')
            else:
                plt.show()

        elif backend == 'bokeh':
            # LIME-like hbars showing relative differences
            values = self.get_changes()
            fig1 = BP.figure(plot_width=400, plot_height=300, y_range=self.labels,
                             x_range=(min(values), max(values)), x_axis_label='Relative change')
            colors = np.where(values <= 0, '#ff0000', '#00ff00')
            fig1.hbar(y=self.labels, height=0.75, right=values, fill_color=colors)

            # LIME-like hbars showing absolute differences
            values = self.get_changes(relative=False)
            fig2 = BP.figure(plot_width=400, plot_height=300, y_range=self.labels,
                             x_range=(min(values), max(values)), x_axis_label='Absolute change')
            colors = np.where(values <= 0, '#ff0000', '#00ff00')
            fig2.hbar(y=self.labels, height=0.75, right=values, fill_color=colors, line_color=None)

            # Probabilities
            fig3 = BP.figure(plot_width=800, plot_height=200, x_axis_label='Step', y_axis_label='p')
            xs = np.arange(self.path_ps.size)
            ys = self.path_ps
            color_mapper = LinearColorMapper(palette='Viridis256', low=min(self.path_ps),
                                             high=max(self.path_ps))
            color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
            fig3.circle(xs, ys, size=5, fill_color={'field': 'y', 'transform': color_mapper},
                        fill_alpha=.3, line_color=None)
            fig3.add_layout(color_bar, 'left')
            fig3.line(xs, ys, line_dash='dashed', line_alpha=.3, line_width=.2)

            row1 = row([fig1, fig2])
            row2 = row([fig3])
            lyt = column([row1, row2])

            if path is not None:
                BPIO.export_png(lyt, filename=path)
            else:
                BPIO.output_notebook(hide_banner=True)
                BP.show(lyt)

        else:
            raise ValueError('Unsupported backend')

        if debug:
            self.__verbose()

    _show(self.from_, self.to_, debug=debug, path=path, backend=backend)
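Continuing the exp object built in the get_changes example above, a usage sketch for show (both backends share the same call pattern):

exp.show()                                       # matplotlib: delta bars plus the visited itinerary
exp.show(backend='bokeh')                        # interactive bokeh layout, rendered in a notebook
exp.show(path='cf_explanation.pdf', debug=True)  # save to disk and print the debug report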

CounterfactualWithImportanceExplanation(explain_data, counterfactuals, importances, count_diffs, count_diffs_norm)

Bases: object

Extended Counterfactual Explanations

Parameters:

Name Type Description Default
explain_data DataFrame

A pandas DataFrame containing the observations for which an explanation has to be found.

required
counterfactuals List[dict]

A list containing the counterfactuals computed for each observation in explain_data.

required
importances List[Tuple]

A list of (feature, importance) tuples ranking features by how much they change across the counterfactuals.

required
count_diffs dict

A dictionary mapping each feature to the number of times it changes across the counterfactuals.

required
count_diffs_norm dict

The same counts as count_diffs, normalized.

required
Source code in mercury/explainability/explanations/counter_factual.py
def __init__(
        self,
        explain_data: pd.DataFrame,
        counterfactuals: TP.List[dict],
        importances: TP.List[TP.Tuple],
        count_diffs: dict,
        count_diffs_norm: dict
    ) -> None:
    self.explain_data = explain_data
    self.counterfactuals = counterfactuals
    self.importances = importances
    self.count_diffs = count_diffs
    self.count_diffs_norm = count_diffs_norm
interpret_explanations(n_important_features=3)

This method prints a report of the important features obtained.

Parameters:

Name Type Description Default
n_important_features int

The number of important features that will appear in the report. Defaults to 3.

3
Source code in mercury/explainability/explanations/counter_factual.py
def interpret_explanations(self, n_important_features: int = 3) -> str:
    """
    This method prints a report of the important features obtained.

    Args:
        n_important_features:
            The number of important features that will appear in the report.
            Defaults to 3.
    """

    importances_str = []
    for n in range(n_important_features):
        importance_str = [imp if isinstance(imp, str) else '{:.2f}'.format(imp) for imp in self.importances[n]]
        importances_str.append(importance_str)

    count_diffs_norm_str = []
    for n in range(n_important_features):
        count_diffs_i = list(self.count_diffs_norm.items())[n]
        count_diff_norm_str = '{} {:.2f}'.format(count_diffs_i[0], count_diffs_i[1])
        count_diffs_norm_str.append(count_diff_norm_str)

    interpretation = """The {} most important features and their importance values according to the first metric (amount features change) are:
    {}.

According to the second metric (times features change), these importances are:
    {}""".format(
        n_important_features,
        ' AND '.join([' '.join(imp_str) for imp_str in importances_str]),
        ' AND '.join(count_diffs_norm_str)
    )
    print(interpretation)
    return interpretation
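A minimal usage sketch with hand-built inputs (the values are illustrative; in practice they come from the counterfactual explainer):

import pandas as pd

exp = CounterfactualWithImportanceExplanation(
    explain_data=pd.DataFrame({'age': [25], 'balance': [1000.0]}),
    counterfactuals=[{'age': 25, 'balance': 250.0}],
    importances=[('balance', 0.62), ('age', 0.38)],
    count_diffs={'balance': 3, 'age': 2},
    count_diffs_norm={'balance': 0.6, 'age': 0.4},
)
exp.interpret_explanations(n_important_features=2)  # prints and returns the report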

partial_dependence

PartialDependenceExplanation(data)

This class holds the result of a Partial Dependence explanation and provides functionality for plotting those results via Partial Dependence Plots.

Parameters:

Name Type Description Default
data dict

Contains the result of the PartialDependenceExplainer. It must be in the form of: {'feature_name': {'values': [...], 'preds': [...], 'lower_quantile': [...], 'upper_quantile': [...]}, 'feature_name2': {'values': [...], 'preds': [...], 'lower_quantile': [...], 'upper_quantile': [...]}, ...}

required
Source code in mercury/explainability/explanations/partial_dependence.py
def __init__(self, data):
    self.data = data
__getitem__(key)

Gets the dependence data of the desired feature.

Parameters:

Name Type Description Default
key str

Name of the feature.

required
Source code in mercury/explainability/explanations/partial_dependence.py
def __getitem__(self, key:str):
    """
    Gets the dependence data of the desired feature.

    Args:
        key (str):
            Name of the feature.
    """
    return self.data[key]['values'], self.data[key]['preds']
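For example, with a hand-built result dictionary (shapes are illustrative):

import numpy as np

pdp = PartialDependenceExplanation({
    'age': {'values': np.linspace(18, 90, 20),
            'preds': np.linspace(0.2, 0.8, 20),
            'lower_quantile': np.linspace(0.1, 0.7, 20),
            'upper_quantile': np.linspace(0.3, 0.9, 20),
            'categorical': False},
})
values, preds = pdp['age']  # the value grid and the averaged predictions for 'age'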
plot(ncols=1, figsize=(15, 15), quantiles=False, filter_classes=None, **kwargs)

Plots a summary of all the partial dependences.

Parameters:

Name Type Description Default
ncols int

Number of columns of the summary. 1 as default.

1
quantiles bool or list

Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions have high or low dispersion. If this is a list of booleans, quantiles will be plotted filtered by class (i.e. quantiles[0] = class number 0).

False
filter_classes list

List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable is not categorical.

required
figsize tuple

Size of the plotted figure

(15, 15)
Source code in mercury/explainability/explanations/partial_dependence.py
def plot(self, ncols:int = 1, figsize:tuple = (15,15), quantiles:TP.Union[bool, list] = False, filter_classes:list = None, **kwargs):
    """
    Plots a summary of all the partial dependences.

    Args:
        ncols (int):
            Number of columns of the summary. 1 as default.
        quantiles (bool or list):
            Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions
            have high or low dispersion. If this is a list of booleans, quantiles
            will be plotted filtered by class (i.e. `quantiles[0]` = `class number 0`).
        filter_classes (list):
            List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable
            is not categorical.
        figsize (tuple):
            Size of the plotted figure
    """
    features = list(self.data.keys())

    fig, ax = plt.subplots(ceil(len(features) / ncols), ncols, figsize=figsize)

    for i, feat_name in enumerate(features):
        sbplt = ax[i] if ncols==1 or ncols==len(features) else ax[i // ncols, i % ncols]
        self.plot_single(feat_name, sbplt, quantiles=quantiles, filter_classes=filter_classes, **kwargs)
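A usage sketch with a hand-built two-feature result (illustrative numbers; note the 'categorical' flag that the plotting code reads):

import numpy as np
import matplotlib.pyplot as plt

xs = np.linspace(18, 90, 20)

def band(preds):
    # Helper for this sketch only: wraps predictions with +/-0.1 quantile bands.
    return {'values': xs, 'preds': preds,
            'lower_quantile': preds - 0.1, 'upper_quantile': preds + 0.1,
            'categorical': False}

pdp = PartialDependenceExplanation({
    'age': band(np.linspace(0.2, 0.8, 20)),
    'balance': band(np.linspace(0.9, 0.3, 20)),
})
pdp.plot(ncols=1, figsize=(8, 6), quantiles=True)
plt.show()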
plot_single(var_name, ax=None, quantiles=False, filter_classes=None, **kwargs)

Plots the partial dependence of a single variable.

Parameters:

Name Type Description Default
var_name str

Name of the desired variable to plot.

required
quantiles bool or list[bool]

Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions have high or low dispersion. If data doesn't contain the quantiles this parameter will be ignored.

False
filter_classes list

List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable is not categorical.

required
ax AxesSubplot

Axes object on which the data will be plotted.

None
Source code in mercury/explainability/explanations/partial_dependence.py
def plot_single(self, var_name: str, ax=None, quantiles:TP.Union[bool, list] = False, filter_classes:list = None, **kwargs):
    """
    Plots the partial dependence of a single variable.

    Args:
        var_name (str):
            Name of the desired variable to plot.
        quantiles (bool or list[bool]):
            Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions
            have high or low dispersion. If data doesn't contain the quantiles this parameter will be ignored.
        filter_classes (list):
            List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable
            is not categorical.
        ax (matplotlib.axes._subplots.AxesSubplot):
            Axes object on which the data will be plotted.
    """
    # If the user passes a single bool and the prediction data is multinomial, we convert
    # the single boolean to a mask array so the quantile range is only plotted over the
    # selected classes.
    if len(self.data[var_name]['preds'].shape)>=2:
        if type(quantiles) == list and len(quantiles) != self.data[var_name]['preds'].shape[1]:
            raise ValueError("len(quantiles) must be equal to the number of classes.")
        if type(quantiles) == bool:
            quantiles = [quantiles for i in range(self.data[var_name]['preds'].shape[1])]
    elif type(quantiles) == list and len(self.data[var_name]['preds'].shape)==1:
        quantiles = quantiles[0]

    if filter_classes is not None:
        filter_classes = np.where(filter_classes)[0].tolist()
    else:
        filter_classes = np.arange(self.data[var_name]['preds'].shape[-1]).tolist()
        if len(self.data[var_name]['preds'].shape) < 2:
            filter_classes = None

    ax = ax if ax else plt.gca()

    ax.set_title(var_name)
    ax.set_xlabel(f"{var_name} value")
    ax.set_ylabel("Avg model prediction")

    vals = np.array(self.data[var_name]['values'])
    int_locations = np.arange(len(vals))

    non_numerical_values = False
    # Check if variable is categorical. If so, plot bars
    if self.data[var_name]['categorical'] and not type(vals[0]) == float:
        bar_width = .2
        class_nb = 0 if not filter_classes else len(filter_classes)

        if type(vals[0]) == float or type(vals[0]) == int:
            ax.set_xticks(self.data[var_name]['values'])
        else:
            non_numerical_values = True
            bar_offsets = np.linspace(-bar_width, bar_width, num=class_nb) / class_nb
            ax.set_xticks(int_locations)
            ax.set_xticklabels(self.data[var_name]['values'])

        if class_nb == 0:
            # If prediction is a single scalar
            if non_numerical_values:
                ax.bar(int_locations, self.data[var_name]['preds'], width=bar_width, label='Prediction',**kwargs)
            else:
                ax.bar(vals, self.data[var_name]['preds'], width=bar_width, label='Prediction', **kwargs)

            if quantiles:
                ax.errorbar(
                    int_locations,
                    self.data[var_name]['preds'],
                    yerr=np.vstack([self.data[var_name]['lower_quantile'],
                                    self.data[var_name]['upper_quantile']]),
                    fmt='ko',
                    label='Quantiles',
                    **kwargs
                )

        else:
            # If prediction is multiclass
            for i in range(class_nb):
                if i in filter_classes:
                    if non_numerical_values:
                        ax.bar(int_locations + bar_offsets[i], self.data[var_name]['preds'][:,i],
                                width=bar_width / class_nb, label=f'Class {i}',**kwargs)
                    else:
                        ax.bar(vals, self.data[var_name]['preds'][:,i], width=bar_width / class_nb, label=f'Class {i}', **kwargs)

                if quantiles[i]:
                    ax.errorbar(
                        int_locations + bar_offsets[i],
                        self.data[var_name]['preds'][:, i],
                        yerr=np.vstack([self.data[var_name]['lower_quantile'][:,i],
                                        self.data[var_name]['upper_quantile'][:,i]]),
                        fmt='ko',
                        label=f'Quantiles {i}',
                        **kwargs
                    )

        if class_nb > 0:
            ax.legend()

    else:  # Variable is continuous

        # Check whether prediction data is multinomial
        if filter_classes:
            objs = ax.plot(vals, self.data[var_name]['preds'][:, filter_classes], **kwargs)
        else:
            objs = ax.plot(vals, self.data[var_name]['preds'], **kwargs)
        if len(self.data[var_name]['preds'].shape)>=2:
            labels = [f"Class: {i}" for i in range(self.data[var_name]['preds'].shape[1])]
            # Filter labels
            labels = [l for i, l in enumerate(labels) if i in filter_classes]
            # Show labels
            ax.legend(iter(objs), labels)
            for i in range(self.data[var_name]['preds'].shape[1]):
                if quantiles[i] and len(self.data[var_name]['lower_quantile']) > 0:
                    # Plot quantiles and a shaded band between them

                    # We will need the color assigned to each one of the lines so the
                    # shaded area also has that color. Since filtering can be done, we
                    # extract the line index as the minimum between the current class
                    # index and the maximum amount of lines on the canvas.
                    obj_index = min(i, len(objs) - 1)

                    # Actually plot the shaded area
                    ax.plot(vals, self.data[var_name]['lower_quantile'][:,i], ls='--', color=objs[obj_index].get_color(),**kwargs)
                    ax.plot(vals, self.data[var_name]['upper_quantile'][:,i], ls='--', color=objs[obj_index].get_color(), **kwargs)
                    ax.fill_between(vals,
                            self.data[var_name]['lower_quantile'][:,i], self.data[var_name]['upper_quantile'][:,i], alpha=.05)
        else:  # If target is not multinomial
            if quantiles and len(self.data[var_name]['lower_quantile']) > 0:
                # Plot quantiles and a shaded band between them
                ax.plot(vals, self.data[var_name]['lower_quantile'], ls='--', color=objs[0].get_color(),**kwargs)
                ax.plot(vals, self.data[var_name]['upper_quantile'], ls='--', color=objs[0].get_color(), **kwargs)
                ax.fill_between(vals, self.data[var_name]['lower_quantile'], self.data[var_name]['upper_quantile'], alpha=.05)
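A categorical usage sketch (hand-built values; the quantile arrays are left empty so no quantile bars are drawn):

import numpy as np
import matplotlib.pyplot as plt

pdp = PartialDependenceExplanation({
    'job': {'values': ['clerk', 'manager', 'retired'],
            'preds': np.array([0.3, 0.6, 0.4]),
            'lower_quantile': [], 'upper_quantile': [],
            'categorical': True},
})
pdp.plot_single('job')  # one bar per category, labelled with the category names
plt.show()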

shuffle_importance

FeatureImportanceExplanation(data, reverse=False)

This class holds the data related to the importance a given feature has for a model.

Parameters:

Name Type Description Default
data dict

Contains the result of the ShuffleImportanceExplainer. It must be in the form of: {'feature_name': 1.0, 'feature_name2': 2.3, ...}

required
reverse bool

Whether to reverse sort the features by increasing order (i.e. Worst performance (latest) = Smallest value). Default False (decreasing order).

False
Source code in mercury/explainability/explanations/shuffle_importance.py
def __init__(self, data:dict, reverse:bool = False):
    """
    This class holds the data related to the importance a given
    feature has for a model.

    Args:
        data (dict):
            Contains the result of the ShuffleImportanceExplainer. It must be in the
            form of: ::
                {
                    'feature_name': 1.0,
                    'feature_name2': 2.3, ...
                }

        reverse (bool):
            Whether to reverse sort the features by increasing order (i.e. Worst
            performance (latest) = Smallest value). Default False (decreasing order).
    """
    self.data = data
    self._sorted_features = sorted(list(data.items()), key=lambda i: i[1],
                                   reverse=not reverse)
__getitem__(key)

Gets the feature importance of the desired feature.

Parameters:

Name Type Description Default
key str

Name of the feature.

required
Source code in mercury/explainability/explanations/shuffle_importance.py
def __getitem__(self, key:str)->float:
    """
    Gets the feature importance of the desired feature.

    Args:
        key (str): Name of the feature.
    """
    return self.data[key]
get_importances()

Returns a list of tuples (feature, importance) sorted by importances.

Source code in mercury/explainability/explanations/shuffle_importance.py
def get_importances(self)->list:
    """ Returns a list of tuples (feature, importance) sorted by importances.
    """
    return self._sorted_features
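For example:

fi = FeatureImportanceExplanation({'age': 0.12, 'balance': 0.45, 'job': 0.03})
fi.get_importances()  # [('balance', 0.45), ('age', 0.12), ('job', 0.03)]
fi['balance']         # 0.45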
plot(ax=None, figsize=(15, 15), limit_axis_x=False, **kwargs)

Plots a summary of the importances for each feature.

Parameters:

Name Type Description Default
ax Axes

Axes object on which the data will be plotted. If None, the current axes are used.

None
figsize tuple

Size of the plotted figure

(15, 15)
limit_axis_x bool

Whether to limit the x axis to the range between the minimum and maximum feature values.

False
Source code in mercury/explainability/explanations/shuffle_importance.py
def plot(self, ax: "matplotlib.axes.Axes" = None,  # noqa:F821
         figsize: tuple = (15, 15), limit_axis_x=False, **kwargs) -> "matplotlib.axes.Axes":  # noqa:F821
    """
    Plots a summary of the importances for each feature

    Args:
        ax (matplotlib.axes.Axes): Axes on which to plot; if None, the current axes are used.
        figsize (tuple): Size of the plotted figure
        limit_axis_x (bool): Whether to limit the x axis to the range between the minimum and maximum feature values
    """
    ax = ax if ax else plt.gca()

    feature_names = [i[0] for i in self._sorted_features]
    feature_values = [i[1] for i in self._sorted_features]
    ax.barh(feature_names, feature_values)

    if limit_axis_x:
        ax.set_xlim(min(feature_values), max(feature_values))

    return ax
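A usage sketch with hand-built importances:

import matplotlib.pyplot as plt

fi = FeatureImportanceExplanation({'age': 0.12, 'balance': 0.45, 'job': 0.03})
fi.plot(limit_axis_x=True)  # one horizontal bar per feature, sorted by importance
plt.show()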