This notebook shows some different types of analysis that can be conducted with Quantus, going from qualitative analysis to quantitative analysis and sensitivity analysis.
For this purpose, we use a pre-trained PyTorch MobileNet V3 model and ImageNet dataset.
from IPython.display import clear_output
!pip install torch torchvision captum quantus
clear_output()
import pathlib
import random
import copy
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
from captum.attr import *
import quantus
# Apply seaborn's default theme to every figure in this notebook.
sns.set()
# Prefer the first CUDA device; fall back to the CPU when no GPU is present.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
We have prepared a small subset of ImageNet images which can be downloaded at: https://github.com/understandable-machine-intelligence-lab/Quantus/tree/main/tutorials/assets/imagenet_samples/. Please make sure to download the contents of the folder, that is, the inputs `x_batch.pt`, the outputs `y_batch.pt` and the segmentation masks `s_batch.pt`. (A description of how to download the full dataset can be found here: https://image-net.org/download.php.)
If you use Google Colab, you might want to add the following:
# Mount Google Drive.
# Only needed on Google Colab, where the sample tensors are read from Drive below.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# Mount Google Drive.
# Colab-only: makes path_to_files (under "drive/MyDrive/...") reachable from the notebook.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
Mounted at /content/drive
# Adjust this path.
path_to_files = "drive/MyDrive/Projects/quantus/tutorials/assets/imagenet_samples"
# Load test data directly onto `device`. `map_location` also remaps tensors that
# were saved from a GPU, so loading works on CPU-only machines too.
data_dir = pathlib.Path(path_to_files)
x_batch = torch.load(data_dir / "x_batch.pt", map_location=device)
y_batch = torch.load(data_dir / "y_batch.pt", map_location=device)
s_batch = torch.load(data_dir / "s_batch.pt", map_location=device)
print(f"{len(x_batch)} matches found.")
17 matches found.
# Plot some inputs!
nr_images = min(5, len(x_batch))
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
# squeeze=False keeps `axes` 2-D even for a single image; the height is kept >= 1 inch.
fig, axes = plt.subplots(
    nrows=1,
    ncols=nr_images,
    figsize=(nr_images * 3, max(1, int(nr_images * 2 / 3))),
    squeeze=False,
)
for i, ax in enumerate(axes[0]):
    # Undo the ImageNet normalisation and move channels last for imshow.
    image = np.moveaxis(
        quantus.normalise_func.denormalise(
            x_batch[i].cpu().numpy(), mean=imagenet_mean, std=imagenet_std
        ),
        0,
        -1,
    )
    # Denormalised pixels can fall slightly outside [0, 1]. Clip before the uint8
    # cast, otherwise out-of-range values wrap around and show up as speckles.
    ax.imshow((np.clip(image, 0.0, 1.0) * 255).astype(np.uint8))
    ax.title.set_text(f"ImageNet class - {y_batch[i].item()}")
    ax.axis("off")
plt.show()
In this example we load a pre-trained MobileNet V3 model but it goes without saying that any model works.
def evaluate_model(model, data, device):
    """Return the top-1 accuracy of a torch classifier over `data`.

    Args:
        model: torch module mapping a batch of inputs to class logits. It is put
            in eval mode but not moved; it must already live on `device`.
        data: iterable yielding ``(inputs, _, labels)`` triples; the middle
            element is ignored.
        device: device each batch is moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float, or NaN when `data`
        yields no samples.
    """
    model.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for images, _, labels in data:
            images, labels = images.to(device), labels.to(device)
            # Count hits per batch instead of concatenating every batch's logits,
            # which re-copied the growing tensor on each step.
            predictions = model(images).argmax(dim=1)
            n_correct += int((predictions == labels).sum().item())
            n_total += labels.numel()
    if n_total == 0:
        return float("nan")
    return n_correct / n_total
# Load pre-trained MobileNet V3 model.
# `pretrained=True` is deprecated since torchvision 0.13 (see the warning in the
# output); the weights enum selects the same ImageNet checkpoint explicitly.
model = torchvision.models.mobilenet_v3_small(
    weights=torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
)
model = model.to(device)
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead. f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, " /usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V3_Small_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg)
There exist multiple ways to generate explanations for neural network models, e.g., using the `captum` or `innvestigate` libraries. In this example, we rely on customised explainer functions: thin wrappers around captum's attribution methods, plus a custom FusionGrad implementation. All necessary source code is pasted below.
def explainer_wrapper(**kwargs):
    """Dispatch to one of this notebook's explainers based on ``kwargs["method"]``.

    All keyword arguments (including ``method``) are forwarded unchanged to the
    selected explainer, whose result is returned.

    Raises:
        KeyError: if no ``method`` keyword is given.
        ValueError: if ``method`` does not name a supported explainer.
    """
    method = kwargs["method"]
    if method == "Saliency":
        return saliency_explainer(**kwargs)
    if method == "IntegratedGradients":
        return intgrad_explainer(**kwargs)
    if method == "FusionGrad":
        return fusiongrad_explainer(**kwargs)
    if method == "GradientShap":
        return gradshap_explainer(**kwargs)
    raise ValueError(
        f"Pick an explanation function that exists; got {method!r}, expected one of "
        "'Saliency', 'IntegratedGradients', 'FusionGrad' or 'GradientShap'."
    )
def saliency_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.ndarray:
    """Wrapper around captum's Saliency implementation.

    Args:
        model: torch model to explain; moved to ``kwargs["device"]`` and put in eval mode.
        inputs: torch tensor, or array-like that is reshaped to
            ``(-1, nr_channels, img_size, img_size)``.
        targets: target class indices (tensor or array-like).
        abs: forwarded to captum's ``Saliency.attribute``.
        normalise: if True, apply ``quantus.normalise_func.normalise_by_negative``.
        **kwargs: optional ``device``, ``nr_channels`` (default 3) and
            ``img_size`` (default 224).

    Returns:
        numpy array of shape ``(nr_samples, img_size, img_size)`` holding the
        attributions summed over the channel axis.
    """
    device = kwargs.get("device", None)
    img_size = kwargs.get("img_size", 224)

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(device)
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(-1, kwargs.get("nr_channels", 3), img_size, img_size)
            .to(device)
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 224, 224)."

    explanation = (
        Saliency(model)
        .attribute(inputs, targets, abs=abs)
        .sum(axis=1)
        .reshape(-1, img_size, img_size)
        .detach()
        .cpu()
        .numpy()
    )

    gc.collect()
    torch.cuda.empty_cache()

    # Normalise only after converting to numpy. normalise_by_negative is annotated to
    # take np.ndarray, and feeding it the torch tensor broke normalise=True.
    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)
    return explanation
def intgrad_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.ndarray:
    """Wrapper around captum's Integrated Gradients implementation.

    Attributions use an all-zeros baseline and the "riemann_trapezoid" rule, and
    are summed over the channel axis.

    Args:
        model: torch model to explain; moved to ``kwargs["device"]`` and put in eval mode.
        inputs: torch tensor, or array-like that is reshaped to
            ``(-1, nr_channels, img_size, img_size)``.
        targets: target class indices (tensor or array-like).
        abs: accepted for a uniform explainer signature; not used here.
        normalise: if True, apply ``quantus.normalise_func.normalise_by_negative``.
        **kwargs: optional ``device``, ``nr_channels`` (default 3), ``img_size``
            (default 224) and ``n_steps`` (integration steps, default 10).

    Returns:
        numpy array of shape ``(nr_samples, img_size, img_size)``.
    """
    device = kwargs.get("device", None)
    img_size = kwargs.get("img_size", 224)

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(device)
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(-1, kwargs.get("nr_channels", 3), img_size, img_size)
            .to(device)
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 224, 224)."

    explanation = (
        IntegratedGradients(model)
        .attribute(
            inputs=inputs,
            target=targets,
            baselines=torch.zeros_like(inputs),
            n_steps=kwargs.get("n_steps", 10),
            method="riemann_trapezoid",
        )
        .sum(axis=1)
        .reshape(-1, img_size, img_size)
        .detach()
        .cpu()
        .numpy()
    )

    gc.collect()
    torch.cuda.empty_cache()

    # Normalise only after converting to numpy. normalise_by_negative is annotated to
    # take np.ndarray, and feeding it the torch tensor broke normalise=True.
    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)
    return explanation
def gradshap_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.ndarray:
    """Wrapper around captum's GradientShap implementation.

    Uses an all-zeros baseline; attributions are summed over the channel axis.

    Args:
        model: torch model to explain; moved to ``kwargs["device"]`` and put in eval mode.
        inputs: torch tensor, or array-like that is reshaped to
            ``(-1, nr_channels, img_size, img_size)``.
        targets: target class indices (tensor or array-like).
        abs: accepted for a uniform explainer signature; not used here.
        normalise: if True, apply ``quantus.normalise_func.normalise_by_negative``.
        **kwargs: optional ``device``, ``nr_channels`` (default 3) and
            ``img_size`` (default 224).

    Returns:
        numpy array of shape ``(nr_samples, img_size, img_size)``.
    """
    device = kwargs.get("device", None)
    img_size = kwargs.get("img_size", 224)

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(device)
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(-1, kwargs.get("nr_channels", 3), img_size, img_size)
            .to(device)
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 224, 224)."

    baselines = torch.zeros_like(inputs).to(device)
    explanation = (
        GradientShap(model)
        .attribute(inputs=inputs, target=targets, baselines=baselines)
        .sum(axis=1)
        .reshape(-1, img_size, img_size)
        .detach()
        .cpu()
        .numpy()
    )

    gc.collect()
    torch.cuda.empty_cache()

    # Normalise only after converting to numpy. normalise_by_negative is annotated to
    # take np.ndarray, and feeding it the torch tensor broke normalise=True.
    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)
    return explanation
def fusiongrad_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.ndarray:
    """FusionGrad explainer built on captum's Saliency.

    Averages Saliency maps over ``n`` noisy copies of the model weights and, for
    each of those, ``m`` noisy copies of the inputs. The model's weights are
    restored afterwards, so the caller's model is left unchanged.

    Args:
        model: torch model to explain; moved to ``kwargs["device"]`` and put in eval mode.
        inputs: torch tensor, or array-like that is reshaped to
            ``(-1, nr_channels, img_size, img_size)``.
        targets: target class indices (tensor or array-like).
        abs: forwarded to captum's ``Saliency.attribute``.
        normalise: if True, apply ``quantus.normalise_func.normalise_by_negative``.
        **kwargs:
            std, mean: Gaussian weight-noise parameters (defaults 0.5 and 1.0).
            n: number of weight samples (default 10).
            m: number of input samples per weight sample (default 10).
            sg_std, sg_mean: Gaussian input-noise parameters (defaults 0.5 and 0.0).
            posterior_mean: state dict the weights are reset to before each
                perturbation; defaults to a snapshot of ``model``'s current weights.
            noise_type: "multiplicative" (default) or "additive".
            clip: if True, clip the noisy inputs to [0, 1].
            device, nr_channels (default 3), img_size (default 224).

    Returns:
        numpy array of shape ``(nr_samples, img_size, img_size)``, or
        ``(img_size, img_size)`` when a single input is given.

    Raises:
        ValueError: if ``noise_type`` is neither "additive" nor "multiplicative".
    """
    std = kwargs.get("std", 0.5)
    mean = kwargs.get("mean", 1.0)
    n = kwargs.get("n", 10)
    m = kwargs.get("m", 10)
    sg_std = kwargs.get("sg_std", 0.5)
    sg_mean = kwargs.get("sg_mean", 0.0)
    posterior_mean = kwargs.get("posterior_mean", None)
    noise_type = kwargs.get("noise_type", "multiplicative")
    clip = kwargs.get("clip", False)
    device = kwargs.get("device", None)
    img_size = kwargs.get("img_size", 224)

    # Fail fast. Previously a bad value printed a message on every sample and then
    # silently explained the unperturbed model.
    if noise_type not in ("additive", "multiplicative"):
        raise ValueError(
            "Set NoiseGrad attribute 'noise_type' to either 'additive' or 'multiplicative' (str)."
        )

    # Creates a normal (also called Gaussian) distribution for the weight noise.
    distribution = torch.distributions.normal.Normal(
        loc=torch.as_tensor(mean, dtype=torch.float),
        scale=torch.as_tensor(std, dtype=torch.float),
    )

    # Set model in evaluate mode.
    model.to(device)
    model.eval()

    # Sampling overwrites the weights in place. Snapshot them so they can be restored:
    # callers such as Quantus metrics keep using this same model object for their own
    # forward passes. Without a given posterior mean, the snapshot also serves as the
    # weights every sample is drawn around (previously None crashed load_state_dict).
    original_state = copy.deepcopy(model.state_dict())
    if posterior_mean is None:
        posterior_mean = original_state

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(-1, kwargs.get("nr_channels", 3), img_size, img_size)
            .to(device)
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 224, 224)."

    def _sample_model():
        """Reset the weights to `posterior_mean`, then add/multiply Gaussian noise in place."""
        model.load_state_dict(posterior_mean)
        if std == 0.0:
            return
        with torch.no_grad():
            for param in model.parameters():
                noise = distribution.sample(param.size()).to(param.device)
                if noise_type == "additive":
                    param.add_(noise)
                else:
                    param.mul_(noise)

    saliency = Saliency(model)
    # Keep a running sum of the n * m maps instead of an (n, m, batch, H, W) buffer.
    # With n = m = 25 that buffer held 625 full explanation batches at once.
    explanation_sum = torch.zeros((inputs.shape[0], img_size, img_size))
    try:
        for _ in range(n):
            _sample_model()
            for _ in range(m):
                inputs_noisy = inputs + torch.randn_like(inputs) * sg_std + sg_mean
                if clip:
                    inputs_noisy = torch.clip(inputs_noisy, min=0.0, max=1.0)
                explanation_sum += (
                    saliency.attribute(inputs_noisy, targets, abs=abs)
                    .sum(axis=1)
                    .reshape(-1, img_size, img_size)
                    .detach()
                    .cpu()
                )
    finally:
        # Always hand the caller back its unperturbed model.
        model.load_state_dict(original_state)

    explanation = (explanation_sum / (n * m)).numpy()
    # Preserve the original output contract: a single input yields a 2-D map.
    if inputs.shape[0] == 1:
        explanation = explanation[0]

    gc.collect()
    torch.cuda.empty_cache()

    # Normalise only after converting to numpy; normalise_by_negative is annotated to
    # take np.ndarray.
    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)
    return explanation
# Produce explanations, emptying the cache in between so every method fits in memory.
# The explainers move the model to `device` themselves.
imagenet_weights = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1

# Saliency.
gc.collect()
torch.cuda.empty_cache()
a_batch_saliency = saliency_explainer(
    model=model, inputs=x_batch, targets=y_batch, device=device
)

# GradShap.
gc.collect()
torch.cuda.empty_cache()
a_batch_gradshap = gradshap_explainer(
    model=model, inputs=x_batch, targets=y_batch, device=device
)

# Integrated Gradients.
gc.collect()
torch.cuda.empty_cache()
a_batch_intgrad = intgrad_explainer(
    model=model, inputs=x_batch, targets=y_batch, device=device
)

# FusionGrad perturbs its model's weights while sampling, so it gets a separate
# pretrained copy. That copy's clean weights serve as the posterior mean.
gc.collect()
torch.cuda.empty_cache()
fusiongrad_model = torchvision.models.mobilenet_v3_small(weights=imagenet_weights).to(device)
posterior_mean = copy.deepcopy(fusiongrad_model.state_dict())
a_batch_fusiongrad = fusiongrad_explainer(
    model=fusiongrad_model,
    inputs=x_batch,
    targets=y_batch,
    posterior_mean=posterior_mean,
    mean=1.0,
    std=0.5,
    sg_mean=0.0,
    sg_std=0.5,
    n=25,
    m=25,
    noise_type="multiplicative",
    device=device,
)

# Collect the explanations by method name.
explanations = {
    "Saliency": a_batch_saliency,
    "GradientShap": a_batch_gradshap,
    "IntegratedGradients": a_batch_intgrad,
    "FusionGrad": a_batch_fusiongrad,
}
/usr/local/lib/python3.7/dist-packages/captum/_utils/gradient.py:59: UserWarning: Input Tensor 0 did not already require gradients, required_grads has been set automatically. "required_grads has been set automatically." % index
Visualise the attributions, given the model and pairs of inputs and outputs.
# Fixed sample index for reproducible figures; use
# random.randint(0, len(x_batch) - 1) to inspect a random sample instead.
index = 10
fig, axes = plt.subplots(nrows=1, ncols=1 + len(explanations), figsize=(15, 8))

image = np.moveaxis(
    quantus.normalise_func.denormalise(
        x_batch[index].cpu().numpy(),
        mean=np.array([0.485, 0.456, 0.406]),
        std=np.array([0.229, 0.224, 0.225]),
    ),
    0,
    -1,
)
# Clip explicitly: denormalised pixels can fall slightly outside [0, 1], which
# imshow would otherwise clip with a warning.
axes[0].imshow(np.clip(image, 0.0, 1.0))
axes[0].title.set_text(f"ImageNet class {y_batch[index].item()}")
axes[0].axis("off")

for ax, (name, attribution) in zip(axes[1:], explanations.items()):
    ax.imshow(
        quantus.normalise_func.normalise_by_negative(attribution[index].reshape(224, 224)),
        cmap="seismic",
        vmin=-1.0,
        vmax=1.0,
    )
    ax.title.set_text(f"{name}")
    ax.axis("off")
In the following sections, we analyse the set of explanations under different perspectives:
# Plotting configs: one display colour and one display label per explanation
# method, listed in the same order.
colours_order = ["#008080", "#FFA500", "#124E78", "#d62728"]
methods_order = ["Saliency (SA)", "Integrated\nGradients (IG)", "GradientShap (GS)", "FusionGrad (FG)"]
# Hide tick marks and tick labels on both axes for the image grids.
plt.rcParams.update(
    {
        "ytick.left": False,
        "ytick.labelleft": False,
        "xtick.bottom": False,
        "xtick.labelbottom": False,
    }
)
include_titles = True
# Plot explanations: the input image, then each method's map framed in its colour.
index = 1
ncols = 1 + len(explanations)
fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(15, int(ncols) * 3))

# x_batch lives on `device`; move the sample to host memory before the numpy-based
# denormalisation (passing a CUDA tensor directly fails).
image = np.moveaxis(
    quantus.normalise_func.denormalise(
        x_batch[index].cpu().numpy(),
        mean=np.array([0.485, 0.456, 0.406]),
        std=np.array([0.229, 0.224, 0.225]),
    ),
    0,
    -1,
)
axes[0].imshow(np.clip(image, 0.0, 1.0))
if include_titles:
    axes[0].set_title(f"ImageNet class {y_batch[index].item()}", fontsize=14)
axes[0].axis("off")

for ax, method_label, colour in zip(axes[1:], methods_order, colours_order):
    # "Integrated\nGradients (IG)" -> "IntegratedGradients", i.e. the key in `explanations`.
    xai = method_label.split("(")[0].replace(" ", "").replace("\n", "")
    ax.imshow(
        quantus.normalise_func.normalise_by_negative(explanations[xai][index].reshape(224, 224)),
        cmap="seismic",
        vmin=-1.0,
        vmax=1.0,
    )
    if include_titles:
        ax.set_title(f"{method_label}", fontsize=14)
    # Frame configs: hide the axes but keep the spines, which form the coloured frame.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_color(colour)
        ax.spines[side].set_linewidth(5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
From this visualisation, it is hard to decipher which explanation method may be preferred or most helpful in the task of explaining the ImageNet class.
Second, we use Quantus to quantitatively assess the different explanation methods on various evaluation criteria.
# Plotting specifics.
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
# Source code: https://matplotlib.org/stable/gallery/specialty_plots/radar_chart.html.
def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it under the
    projection name 'radar'.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    Returns
    -------
    numpy.ndarray
        The `num_vars` evenly spaced spoke angles, in radians.
    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default."""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default.

            Returns the list of lines, like ``Axes.plot`` (this override used to
            return None).
            """
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            """Append the first point to the end of `line` so the polygon closes."""
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels, angles=None):
            """Label the spokes; `angles` (degrees) defaults to the spoke angles."""
            if angles is None:
                angles = np.degrees(theta)
            self.set_thetagrids(angles=angles, labels=labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped."""
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    # Relies on a private matplotlib attribute to render each
                    # radial gridline with num_vars segments.
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta
# Define XAI methods and metrics.
xai_methods = list(explanations.keys())

# Options shared by every metric below: no normalisation, aggregation with
# np.mean (return_aggregate=True), and Quantus' usage warnings silenced.
_shared_metric_kwargs = dict(
    normalise=False,
    aggregate_func=np.mean,
    return_aggregate=True,
    disable_warnings=True,
)

metrics = {
    "Robustness": quantus.AvgSensitivity(
        nr_samples=10,
        lower_bound=0.2,
        norm_numerator=quantus.norm_func.fro_norm,
        norm_denominator=quantus.norm_func.fro_norm,
        perturb_func=quantus.perturb_func.uniform_noise,
        similarity_func=quantus.similarity_func.difference,
        abs=False,
        **_shared_metric_kwargs,
    ),
    "Faithfulness": quantus.FaithfulnessCorrelation(
        nr_runs=10,
        subset_size=224,
        perturb_baseline="black",
        perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        similarity_func=quantus.similarity_func.correlation_pearson,
        abs=False,
        **_shared_metric_kwargs,
    ),
    "Localisation": quantus.RelevanceRankAccuracy(
        abs=False,
        **_shared_metric_kwargs,
    ),
    "Complexity": quantus.Sparseness(
        abs=True,
        **_shared_metric_kwargs,
    ),
    "Randomisation": quantus.RandomLogit(
        num_classes=1000,
        similarity_func=quantus.similarity_func.ssim,
        abs=True,
        **_shared_metric_kwargs,
    ),
}
# Retrieve stored 'dummy' data.
# Scores from an earlier run: one single-element list per (method, metric).
# The evaluation cell below reassigns `results`, so these values are only used
# when that (slow) cell is skipped.
# NOTE(review): FusionGrad's Robustness (~5729) and Randomisation (~0.006) are
# orders of magnitude away from the other methods; they may stem from the model
# being left perturbed by the FusionGrad explainer. Consider regenerating them.
results = {
    "Saliency": {
        "Robustness": [0.023706467548275694],
        "Faithfulness": [0.06749252841918861],
        "Localisation": [0.5122173871263156],
        "Complexity": [0.5503504513474646],
        "Randomisation": [0.8064057449830752],
    },
    "GradientShap": {
        "Robustness": [0.034456219962414575],
        "Faithfulness": [0.04583139237937677],
        "Localisation": [0.5046252238901434],
        "Complexity": [0.6088842825604118],
        "Randomisation": [0.7366918283019923],
    },
    "IntegratedGradients": {
        "Robustness": [0.02690529538428082],
        "Faithfulness": [-0.08233498797221532],
        "Localisation": [0.5071891576864163],
        "Complexity": [0.6107274870736773],
        "Randomisation": [0.7915174357470123],
    },
    "FusionGrad": {
        "Robustness": [5728.689449534697],
        "Faithfulness": [0.0004617348733465858],
        "Localisation": [0.493903066639846],
        "Complexity": [0.5320443076081544],
        "Randomisation": [0.006262477424033301],
    },
}
# Or, run quantification analysis!
# Load the pretrained network once. Every metric gets a fresh deep copy, and its
# clean weights double as FusionGrad's posterior mean. (`weights=True` is
# deprecated in torchvision, as the warning in the output shows.)
imagenet_weights = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
base_model = torchvision.models.mobilenet_v3_small(weights=imagenet_weights).to(device)
posterior_mean = copy.deepcopy(base_model.state_dict())

results = {method: {} for method in xai_methods}
for method in xai_methods:
    for metric, metric_func in metrics.items():
        print(f"Evaluating {metric} of {method} method.")
        gc.collect()
        torch.cuda.empty_cache()
        # Get scores and append results.
        scores = metric_func(
            model=copy.deepcopy(base_model),
            x_batch=x_batch,
            y_batch=y_batch,
            a_batch=None,
            s_batch=s_batch,
            device=device,
            explain_func=explainer_wrapper,
            explain_func_kwargs={
                "method": method,
                # Only read via load_state_dict (which copies), so one snapshot can be shared.
                "posterior_mean": posterior_mean,
                "mean": 1.0,
                "std": 0.5,
                "sg_mean": 0.0,
                "sg_std": 0.5,
                "n": 25,
                "m": 25,
                "noise_type": "multiplicative",
                "device": device,
            },
        )
        results[method][metric] = scores
        # Empty cache.
        gc.collect()
        torch.cuda.empty_cache()
Evaluating Robustness of Saliency method.
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V3_Small_Weights.DEFAULT` to get the most up-to-date weights. /usr/local/lib/python3.7/dist-packages/captum/_utils/gradient.py:59: UserWarning: Input Tensor 0 did not already require gradients, required_grads has been set automatically.
Evaluating Faithfulness of Saliency method. Evaluating Localisation of Saliency method. Evaluating Complexity of Saliency method. Evaluating Randomisation of Saliency method. Evaluating Robustness of GradientShap method. Evaluating Faithfulness of GradientShap method. Evaluating Localisation of GradientShap method. Evaluating Complexity of GradientShap method. Evaluating Randomisation of GradientShap method. Evaluating Robustness of IntegratedGradients method. Evaluating Faithfulness of IntegratedGradients method. Evaluating Localisation of IntegratedGradients method. Evaluating Complexity of IntegratedGradients method. Evaluating Randomisation of IntegratedGradients method. Evaluating Robustness of FusionGrad method. Evaluating Faithfulness of FusionGrad method. Evaluating Localisation of FusionGrad method. Evaluating Complexity of FusionGrad method. Evaluating Randomisation of FusionGrad method.
# Postprocessing of scores: to get how the different explanation methods rank across criteria.
# Ranks run from 1 (worst) to len(xai_methods) (best), so larger is better in the radar chart.
results_agg = {
    method: {metric: np.mean(results[method][metric]) for metric in metrics}
    for method in xai_methods
}
df = pd.DataFrame.from_dict(results_agg).T

# Metrics where a lower score is better; all others are ranked higher-is-better.
# NOTE(review): RandomLogit reports the similarity between explanations for the true
# class and a random class, so lower may be better there too. Confirm against
# Quantus' documented score direction before adding "Randomisation" here.
lower_is_better = {"Robustness"}

# Rank the raw, signed scores. The previous `.abs()` let a negative Faithfulness
# correlation (e.g. the stored IntegratedGradients score of -0.08) outrank positive
# ones. Robustness stays in the last column to match the radar-chart labels.
ranked_columns = [column for column in df.columns if column != "Robustness"] + ["Robustness"]
df_normalised_rank = pd.DataFrame(
    {column: df[column].rank(ascending=column not in lower_is_better) for column in ranked_columns}
)
df_normalised_rank
Faithfulness | Localisation | Complexity | Randomisation | Robustness | |
---|---|---|---|---|---|
Saliency | 3.0 | 4.0 | 2.0 | 4.0 | 4.0 |
GradientShap | 2.0 | 2.0 | 3.0 | 2.0 | 2.0 |
IntegratedGradients | 4.0 | 3.0 | 4.0 | 3.0 | 3.0 |
FusionGrad | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
# Plotting configs.
sns.set(font_scale=1.8)
# 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and the old
# name is gone from recent releases; use whichever this installation provides.
white_style = (
    "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
)
plt.style.use(white_style)
plt.rcParams['ytick.labelleft'] = True
plt.rcParams['xtick.labelbottom'] = True
include_titles = True
include_legend = True
# NOTE(review): sns.set() with no arguments re-applies seaborn's default theme,
# which appears to undo font_scale=1.8 and the white style chosen above. It is kept
# so the figure looks as before; confirm which configuration is intended.
sns.set()

# Make spider graph!
spoke_labels = df_normalised_rank.columns.values
rank_values = df_normalised_rank.to_numpy()
theta = radar_factory(len(spoke_labels), frame='polygon')

fig, ax = plt.subplots(figsize=(11, 11), subplot_kw=dict(projection='radar'))
fig.subplots_adjust(top=0.85, bottom=0.05)
# One closed polygon per method. Fill with the line's colour so each shaded area
# matches its legend entry.
for colour, method, d in zip(colours_order, df_normalised_rank.index, rank_values):
    ax.plot(theta, d, label=method, color=colour, linewidth=5.0)
    ax.fill(theta, d, color=colour, alpha=0.15)

# Set labels.
if include_titles:
    ax.set_varlabels(labels=['Faithfulness', 'Localisation', '\nComplexity', '\nRandomisation', 'Robustness'])
else:
    ax.set_varlabels(labels=[])

ax.set_rgrids(np.arange(0, df_normalised_rank.values.max() + 0.5), labels=[])

# Set a title.
ax.set_title("Quantus: Summary of Explainer Quantification", position=(0.5, 1.1), ha='center', fontsize=15)

# Put a legend to the right of the current axis.
if include_legend:
    ax.legend(loc='upper left', bbox_to_anchor=(1, 0.5))

plt.tight_layout()
Third, we investigate how much different parameters influence the evaluation outcome, i.e., how the different explanation methods rank.
We use Faithfulness Correlation by Bhatt et al., 2020 for this example.
# Let's list the default parameters of the metric.
# Constructing the metric without disable_warnings=True also prints its usage
# warnings (see output); the params dict is displayed as the cell's result.
quantus.FaithfulnessCorrelation().get_params
Warnings and information: (1) The Faithfulness Correlation metric is likely to be sensitive to the choice of baseline value 'perturb_baseline', size of subset |S| 'subset_size' and the number of runs (for each input and explanation pair) 'nr_runs'. (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome. (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance). (4) For further information, see original publication: Bhatt, Umang, Adrian Weller, and José MF Moura. 'Evaluating and aggregating feature-based model explanations.' arXiv preprint arXiv:2005.00631 (2020). (5) To disable these warnings set 'disable_warnings' = True when initialising the metric.
{'abs': False, 'normalise': True, 'return_aggregate': True, 'aggregate_func': <function numpy.mean(a, axis=None, dtype=None, out=None, keepdims=<no value>)>, 'normalise_func': <function quantus.helpers.normalise_func.normalise_by_negative(a: numpy.ndarray, normalise_axes: Union[Sequence[int], NoneType] = None, **kwargs) -> numpy.ndarray>, 'normalise_func_kwargs': {}, 'a_axes': None, 'perturb_func': <function quantus.helpers.perturb_func.baseline_replacement_by_indices(arr: <built-in function array>, indices: Union[int, Sequence[int], Tuple[<built-in function array>]], indexed_axes: Sequence[int], perturb_baseline: Union[float, int, str, <built-in function array>], **kwargs) -> <built-in function array>>, 'perturb_func_kwargs': {'perturb_baseline': 'black'}, 'similarity_func': <function quantus.helpers.similarity_func.correlation_pearson(a: <built-in function array>, b: <built-in function array>, **kwargs) -> float>, 'nr_runs': 100, 'subset_size': 224}
# Define some parameter settings to evaluate.
baseline_strategies = ["mean", "uniform"]
subset_sizes = np.array([2, 52, 102])
sim_funcs = {"pearson": quantus.similarity_func.correlation_pearson, "spearman": quantus.similarity_func.correlation_spearman}

result = {
    "Faithfulness score": [],
    "Method": [],
    "Similarity function": [],
    "Baseline strategy": [],
    "Subset size": [],
}

# The metric is fed numpy arrays; convert once instead of on every iteration.
x_batch_np = x_batch.cpu().numpy()
y_batch_np = y_batch.cpu().numpy()
# Use the configured device: `model.cuda()` fails on CPU-only machines.
model = model.to(device)

# Score explanations!
for b in baseline_strategies:
    for s in subset_sizes:
        for method, attr in explanations.items():
            for sim, sim_func in sim_funcs.items():
                metric = quantus.FaithfulnessCorrelation(
                    abs=True,
                    normalise=True,
                    return_aggregate=True,
                    disable_warnings=True,
                    aggregate_func=np.mean,
                    normalise_func=quantus.normalise_func.normalise_by_negative,
                    nr_runs=10,
                    perturb_baseline=b,
                    perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
                    similarity_func=sim_func,
                    subset_size=s,
                )
                score = metric(model=model, x_batch=x_batch_np, y_batch=y_batch_np, a_batch=attr, device=device)
                result["Method"].append(method)
                result["Baseline strategy"].append(b.capitalize())
                result["Subset size"].append(s)
                result["Faithfulness score"].append(score[0])
                result["Similarity function"].append(sim)

df = pd.DataFrame(result)
df.head()
Faithfulness score | Method | Similarity function | Baseline strategy | Subset size | |
---|---|---|---|---|---|
0 | 0.122536 | Saliency | pearson | Mean | 2 |
1 | 0.042424 | Saliency | spearman | Mean | 2 |
2 | -0.059269 | GradientShap | pearson | Mean | 2 |
3 | 0.005348 | GradientShap | spearman | Mean | 2 |
4 | -0.010177 | IntegratedGradients | pearson | Mean | 2 |
# Rank the methods within every (baseline, subset size, similarity function)
# setting; pandas' default ascending rank gives the highest score the highest rank.
df["Rank"] = (
    df.groupby(['Baseline strategy', 'Subset size', 'Similarity function'])["Faithfulness score"]
    .rank()
)

# Smaller adjustments: drop any "Unnamed..." columns, then capitalise the column names.
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
df.columns = [str(column).capitalize() for column in df.columns]
df.head()
Faithfulness score | Method | Similarity function | Baseline strategy | Subset size | Rank | |
---|---|---|---|---|---|---|
0 | 0.122536 | Saliency | pearson | Mean | 2 | 4.0 |
1 | 0.042424 | Saliency | spearman | Mean | 2 | 4.0 |
2 | -0.059269 | GradientShap | pearson | Mean | 2 | 1.0 |
3 | 0.005348 | GradientShap | spearman | Mean | 2 | 2.0 |
4 | -0.010177 | IntegratedGradients | pearson | Mean | 2 | 2.0 |
# Group by rank: percentage of settings in which each method obtained each rank.
df_view = (
    df.groupby(["Method"])["Rank"]
    .value_counts(normalize=True)
    .mul(100)
    .reset_index(name='Percentage')
    .round(2)
)

# Four dummy bars, each 100% of a single rank, that act as a colour key in the plot.
# `DataFrame.append` was removed in pandas 2.0, so rows are combined with `pd.concat`.
legend_rows = pd.DataFrame(
    {
        "Method": ["Method A", "Method B", "Method C", "Method D"],
        "Rank": [1.0, 2.0, 3.0, 4.0],
        "Percentage": [100, 100, 100, 100],
    }
)
df_view = pd.concat([df_view, legend_rows], ignore_index=True)

# Reorder the methods for plotting purposes: key bars first, then the explainers.
plot_order = [
    "Method A", "Method B", "Method C", "Method D",
    "Saliency", "GradientShap", "IntegratedGradients", "FusionGrad",
]
df_view_ordered = pd.concat(
    [df_view.loc[df_view["Method"] == name] for name in plot_order],
    ignore_index=True,
)
df_view_ordered
Method | Rank | Percentage | |
---|---|---|---|
0 | Method A | 1.0 | 100 |
1 | Method B | 2.0 | 100 |
2 | Method C | 3.0 | 100 |
3 | Method D | 4.0 | 100 |
4 | Saliency | 4.0 | 41.67 |
5 | Saliency | 1.0 | 33.33 |
6 | Saliency | 2.0 | 16.67 |
7 | Saliency | 3.0 | 8.33 |
8 | GradientShap | 3.0 | 41.67 |
9 | GradientShap | 2.0 | 33.33 |
10 | GradientShap | 1.0 | 16.67 |
11 | GradientShap | 4.0 | 8.33 |
12 | IntegratedGradients | 2.0 | 33.33 |
13 | IntegratedGradients | 1.0 | 25.0 |
14 | IntegratedGradients | 4.0 | 25.0 |
15 | IntegratedGradients | 3.0 | 16.67 |
16 | FusionGrad | 3.0 | 33.33 |
17 | FusionGrad | 1.0 | 25.0 |
18 | FusionGrad | 4.0 | 25.0 |
19 | FusionGrad | 2.0 | 16.67 |
# Plot results: stacked rank frequencies per method (A-D are the colour key).
fig, ax = plt.subplots(figsize=(6.5, 5))
ax = sns.histplot(
    x='Method',
    hue='Rank',
    weights='Percentage',
    multiple='stack',
    data=df_view_ordered,
    shrink=0.6,
    palette="colorblind",
    legend=False,
    ax=ax,
)
ax.spines["right"].set_visible(False)
ax.spines['top'].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_ylabel('Frequency of rank', fontsize=15)
ax.set_xlabel('')
# Pin one tick per category before relabelling. Relabelling automatic ticks raised
# "FixedFormatter should only be used together with FixedLocator".
tick_labels = ["A", "B", "C", "D", "SAL", "GS", "IG", "FG"]
ax.set_xticks(range(len(tick_labels)))
ax.set_xticklabels(tick_labels)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=4, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd", "4th"])
# Separates the colour-key bars (A-D) from the real methods.
plt.axvline(x=3.5, ymax=0.95, color='black', linestyle='-')
plt.tight_layout()
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator
Contrary to the intuition that the ranking should be consistent across different metric parameterisations, the rankings differ significantly between the experimental settings.