#!/usr/bin/env python
# coding: utf-8

# # Getting Started!
# [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/understandable-machine-intelligence-lab/Quantus/main?labpath=tutorials%2FTutorial_Getting_Started.ipynb)
#
# This notebook shows how to get started with Quantus, using a very simple example. For this purpose, we use a LeNet model and the MNIST dataset.
#
# - Make sure to have GPUs enabled to speed up computation.

# In[15]:


from IPython.display import clear_output
get_ipython().system('pip install torch torchvision captum quantus seaborn')
clear_output()


# In[16]:


import pathlib

import numpy as np
import pandas as pd
import quantus
import torch
import torchvision
from captum.attr import *
import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

# Enable GPU if available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clear_output()


# ## 1) Preliminaries

# ### 1.1 Load datasets
#
# We first load a batch of input-output pairs, for which we generate explanations that we then evaluate.

# In[17]:


# Load datasets and make loaders.
test_samples = 24
transformer = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(root='./sample_data', train=True, transform=transformer, download=True)
test_set = torchvision.datasets.MNIST(root='./sample_data', train=False, transform=transformer, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, pin_memory=True)  # num_workers=4,
test_loader = torch.utils.data.DataLoader(test_set, batch_size=200, pin_memory=True)

# Load a batch of inputs and outputs to use for evaluation.
x_batch, y_batch = next(iter(test_loader))
x_batch, y_batch = x_batch.to(device), y_batch.to(device)


# In[18]:


# Plot some inputs!
nr_images = 5
fig, axes = plt.subplots(nrows=1, ncols=nr_images, figsize=(nr_images*3, int(nr_images*2/3)))
for i in range(nr_images):
    # ToTensor() already scales pixels to [0, 1], so we can plot directly.
    axes[i].imshow(x_batch[i].cpu().numpy().reshape(28, 28), vmin=0.0, vmax=1.0, cmap="gray")
    axes[i].title.set_text(f"MNIST class - {y_batch[i].item()}")
    axes[i].axis("off")
plt.show()


# ### 1.2 Train a LeNet model
#
# (or any other model of choice). Network architecture from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch.

# In[19]:


class LeNet(torch.nn.Module):
    """Network architecture from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch."""

    def __init__(self):
        super().__init__()
        self.conv_1 = torch.nn.Conv2d(1, 6, 5)
        self.pool_1 = torch.nn.MaxPool2d(2, 2)
        self.relu_1 = torch.nn.ReLU()
        self.conv_2 = torch.nn.Conv2d(6, 16, 5)
        self.pool_2 = torch.nn.MaxPool2d(2, 2)
        self.relu_2 = torch.nn.ReLU()
        self.fc_1 = torch.nn.Linear(256, 120)
        self.relu_3 = torch.nn.ReLU()
        self.fc_2 = torch.nn.Linear(120, 84)
        self.relu_4 = torch.nn.ReLU()
        self.fc_3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool_1(self.relu_1(self.conv_1(x)))
        x = self.pool_2(self.relu_2(self.conv_2(x)))
        x = x.view(x.shape[0], -1)
        x = self.relu_3(self.fc_1(x))
        x = self.relu_4(self.fc_2(x))
        x = self.fc_3(x)
        return x
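# Why does `fc_1` take 256 input features? Two 5x5 convolutions and two 2x2
# max-pools shrink a 28x28 MNIST image to 16 feature maps of size 4x4, i.e.
# 16 * 4 * 4 = 256. A quick shape check, added here purely for illustration
# (not part of the original tutorial):
with torch.no_grad():
    _net = LeNet()
    _x = torch.zeros(1, 1, 28, 28)  # Dummy MNIST-shaped input.
    _x = _net.pool_1(_net.relu_1(_net.conv_1(_x)))  # -> (1, 6, 12, 12)
    _x = _net.pool_2(_net.relu_2(_net.conv_2(_x)))  # -> (1, 16, 4, 4)
    print(f"Flattened feature size: {_x.flatten(start_dim=1).shape[1]}")  # 256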
# Load model architecture.
model = LeNet()
print(f"\nModel architecture: {model.eval()}\n")


# In[20]:


def train_model(model,
                train_data: torch.utils.data.DataLoader,
                test_data: torch.utils.data.DataLoader,
                device: torch.device,
                epochs: int = 20,
                criterion: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer: torch.optim.Optimizer = None,
                evaluate: bool = False):
    """Train torch model."""
    # Build the default optimizer here, rather than in the signature, so it is
    # bound to the model being trained and not to a global at definition time.
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    model.train()
    for epoch in range(epochs):
        for images, labels in train_data:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # Evaluate model!
        if evaluate:
            predictions, labels = evaluate_model(model, test_data, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
            print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")

    return model


def evaluate_model(model, data, device):
    """Evaluate torch model."""
    model.eval()
    logits = torch.Tensor().to(device)
    targets = torch.LongTensor().to(device)

    with torch.no_grad():
        for images, labels in data:
            images, labels = images.to(device), labels.to(device)
            logits = torch.cat([logits, model(images)])
            targets = torch.cat([targets, labels])

    return torch.nn.functional.softmax(logits, dim=1), targets


# In[21]:


# Train and evaluate model.
model = train_model(model=model.to(device),
                    train_data=train_loader,
                    test_data=test_loader,
                    device=device,
                    epochs=1,
                    criterion=torch.nn.CrossEntropyLoss().to(device),
                    optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
                    evaluate=True)

# Model to GPU and eval mode.
model.to(device)
model.eval()

# Check test set performance.
predictions, labels = evaluate_model(model, test_loader, device)
test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
print(f"Model test accuracy: {(100 * test_acc):.2f}%")


# ### 1.3 Generate explanations
#
# There exist multiple ways to generate explanations for neural network models, e.g., using the `captum` or `innvestigate` libraries. In this example, we rely on the `quantus.explain` functionality (a simple wrapper around `captum`), but feel free to use whatever approach or library you'd like to create your explanations.

# In[22]:


# Generate normalised Saliency and Integrated Gradients attributions of the first batch of the test set.
a_batch_saliency = quantus.normalise_func.normalise_by_negative(Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1).cpu().numpy())
a_batch_intgrad = quantus.normalise_func.normalise_by_negative(IntegratedGradients(model).attribute(inputs=x_batch, target=y_batch, baselines=torch.zeros_like(x_batch)).sum(axis=1).cpu().numpy())

# Save x_batch and y_batch as numpy arrays that will be used to call metric instances.
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Quick assert.
assert all(isinstance(obj, np.ndarray) for obj in [x_batch, y_batch, a_batch_saliency, a_batch_intgrad])
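# Since `quantus.explain` is a thin wrapper around `captum`, the same kind of attributions can also be produced through it directly. A minimal sketch; the keyword arguments mirror the `explain_func_kwargs` passed to the metric calls below:

# In[ ]:


# Generate Saliency attributions via the quantus.explain wrapper instead of calling captum directly.
a_batch_saliency_alt = quantus.explain(model, x_batch, y_batch, method="Saliency", device=device)
print(f"Attribution batch shape: {a_batch_saliency_alt.shape}")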
# Visualise attributions given the model and pairs of input-output.

# In[23]:


# Plot explanations!
nr_images = 3
fig, axes = plt.subplots(nrows=nr_images, ncols=3, figsize=(nr_images*2.5, int(nr_images*3)))
for i in range(nr_images):
    axes[i, 0].imshow(x_batch[i].reshape(28, 28), vmin=0.0, vmax=1.0, cmap="gray")
    axes[i, 0].title.set_text(f"MNIST digit {y_batch[i].item()}")
    axes[i, 0].axis("off")
    axes[i, 1].imshow(a_batch_saliency[i], cmap="seismic")
    axes[i, 1].title.set_text("Saliency")
    axes[i, 1].axis("off")
    a = axes[i, 2].imshow(a_batch_intgrad[i], cmap="seismic")
    axes[i, 2].title.set_text("Integrated Gradients")
    axes[i, 2].axis("off")
plt.tight_layout()
plt.show()


# ## 2) Quantitative evaluation using Quantus
#
# We can evaluate our explanations against a variety of quantitative criteria. As a motivating example, we test the Max-Sensitivity (Yeh et al., 2019) of the explanations, which measures how much an explanation maximally changes when the input is subjected to slight perturbations:
#
# $$SENS_{MAX}(\Phi, f, x; r) = \max_{\|x' - x\| \leq r} \|\Phi(f, x') - \Phi(f, x)\|$$
#
# where $\Phi$ is the explanation function, $f$ the model, $x$ the input and $r$ the perturbation radius. Lower scores are better.

# In[24]:


# Define metric for evaluation.
metric_init = quantus.MaxSensitivity(nr_samples=10,
                                     lower_bound=0.1,
                                     norm_numerator=quantus.norm_func.fro_norm,
                                     norm_denominator=quantus.norm_func.fro_norm,
                                     perturb_func=quantus.perturb_func.uniform_noise,
                                     similarity_func=quantus.similarity_func.difference,
                                     disable_warnings=True,
                                     normalise=True,
                                     abs=True)


# In[25]:


# Return Max-Sensitivity scores for Saliency.
scores_saliency = metric_init(model=model,
                              x_batch=x_batch,
                              y_batch=y_batch,
                              a_batch=a_batch_saliency,
                              device=device,
                              explain_func=quantus.explain,
                              explain_func_kwargs={"method": "Saliency", "softmax": False})


# In[26]:


# Return Max-Sensitivity scores for Integrated Gradients.
scores_intgrad = metric_init(model=model,
                             x_batch=x_batch,
                             y_batch=y_batch,
                             a_batch=a_batch_intgrad,
                             device=device,
                             explain_func=quantus.explain,
                             explain_func_kwargs={"method": "IntegratedGradients"})


# In[27]:


print(f"Max-Sensitivity scores by Yeh et al., 2019\n"
      f"\n • Saliency = {np.mean(scores_saliency):.2f} ({np.std(scores_saliency):.2f})."
      f"\n • Integrated Gradients = {np.mean(scores_intgrad):.2f} ({np.std(scores_intgrad):.2f}).")


# In[28]:


# Use the quantus.evaluate functionality of Quantus to do a more comprehensive quantification.
metrics = {"max-Sensitivity": metric_init}
xai_methods = {"Saliency": a_batch_saliency,
               "IntegratedGradients": a_batch_intgrad}

results = quantus.evaluate(metrics=metrics,
                           xai_methods=xai_methods,
                           model=model.cpu(),
                           x_batch=x_batch,
                           y_batch=y_batch,
                           agg_func=np.mean,
                           explain_func_kwargs={"method": "IntegratedGradients", "device": device},
                           call_kwargs={"0": {}})
df = pd.DataFrame(results)
df
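# The DataFrame above holds one aggregated score per explanation method and metric. To compare the methods at a glance, a quick bar plot; this is a sketch that assumes the methods-as-columns layout produced by `pd.DataFrame(results)` above, which may differ across Quantus versions:

# In[ ]:


# Bar plot of the aggregated max-Sensitivity scores per explanation method (lower is better).
ax = df.T.plot(kind="bar", legend=True, rot=0)
ax.set_ylabel("max-Sensitivity score")
plt.tight_layout()
plt.show()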