This tutorial demonstrates how to use the library for robustness evaluation of explanations of text classification models.
For this purpose, we use a pre-trained DistilBERT model from Hugging Face and the GLUE/SST-2 dataset.
This is not a working example yet and is meant for demonstration purposes only. For this demo, we use a (yet) unreleased version of Quantus.
Author: Artem Sereda
from __future__ import annotations
# Use an unreleased version of Quantus.
!pip install 'quantus @ git+https://github.com/aaarrti/Quantus.git@nlp-domain' --no-deps
!pip install transformers datasets nlpaug tf_explain tensorflow_probability
import numpy as np
import pandas as pd
from datasets import load_dataset
import tensorflow as tf
from functools import partial
import logging
from typing import NamedTuple, List, Any
from transformers import AutoTokenizer, TFDistilBertForSequenceClassification, TFPreTrainedModel, PreTrainedTokenizerFast
import quantus.nlp as qn
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
# Suppress debug logs.
logging.getLogger('absl').setLevel(logging.WARNING)
tf.config.list_physical_devices()
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
Metal device set to: AMD Radeon Pro 560
All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification. All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.
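As an optional sanity check, the checkpoint's own label mapping can be inspected via its config; it should line up with the CLASS_NAMES list defined further below.
# Optional sanity check: inspect the checkpoint's label mapping.
# Expected for this checkpoint: {0: 'NEGATIVE', 1: 'POSITIVE'}.
print(model.config.id2label)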
BATCH_SIZE = 8
dataset = load_dataset("sst2")['test']
x_batch = dataset['sentence'][:BATCH_SIZE]
WARNING:datasets.builder:Found cached dataset sst2 (/Users/artemsereda/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)
Run an example inference and display the model's predictions. Since the SST-2 test split ships no gold labels, we use the model's own predictions as targets for the explanations.
CLASS_NAMES = ['negative', 'positive']
def decode_labels(y_batch: np.ndarray, class_names: List[str]) -> List[str]:
"""A helper function to map integer labels to human-readable class names."""
return [class_names[i] for i in y_batch]
# Run tokenizer.
tokens = tokenizer(x_batch, padding='longest', return_tensors='tf')
logits = model(**tokens).logits
y_batch = tf.argmax(tf.nn.softmax(logits), axis=1).numpy()
# Show the x, y data.
pd.DataFrame([x_batch, decode_labels(y_batch, CLASS_NAMES)]).T
| | 0 | 1 |
|---|---|---|
| 0 | uneasy mishmash of styles and genres . | negative |
| 1 | this film 's relationship to actual tension is... | negative |
| 2 | by the end of no such thing the audience , lik... | positive |
| 3 | director rob marshall went out gunning to make... | positive |
| 4 | lathan and diggs have considerable personal ch... | positive |
| 5 | a well-made and often lovely depiction of the ... | positive |
| 6 | none of this violates the letter of behan 's b... | negative |
| 7 | although it bangs a very cliched drum at times... | positive |
There are not many XAI libraries for NLP out there, so here we rely entirely on our own implementations of explanation methods. In this section we write functions to visualise our explanations.
def plot_textual_heatmap(explanations: List[qn.TokenSalience]):
"""
Plots attributions over a batch of text sequence explanations.
References:
- https://stackoverflow.com/questions/74046734/plot-text-saliency-map-in-jupyter-notebook
Parameters
----------
explanations: List of Named tuples (tokens, salience) containing batch of explanations.
Returns
-------
plot: matplotlib.pyplot object, which will be automatically rendered by jupyter.
"""
h_len = len(explanations)
v_len = len(explanations[0].tokens)
tokens = np.asarray([i.tokens for i in explanations]).reshape(-1)
colors = np.asarray([i.salience for i in explanations]).reshape(-1)
fig, axes = plt.subplots(h_len, v_len, figsize=(v_len, h_len*0.5), gridspec_kw=dict(left=0., right=1.))
for i, ax in enumerate(axes.ravel()):
        rect = plt.Rectangle((0, 0), 1, 1, color=(1., 1 - colors[i], 1 - colors[i]))
ax.add_patch(rect)
ax.text(0.5, 0.5, tokens[i], ha='center', va='center')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
ax = fig.add_axes([0, 0.05, 1 , 0.9], fc=[0, 0, 0, 0])
for axis in ['left', 'right']:
ax.spines[axis].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
return plt
Write out functions to generate explanations using two baseline methods: Gradient Norm and Integrated Gradients.
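As a quick reminder of what the implementations below approximate (standard formulations; the code makes some simplifications, e.g. a zero baseline and no multiplication by the input-minus-baseline term): for a target logit $f_y$ and token embedding $e_i$,

$$\mathrm{GradNorm}_i(x) = \bigl\lVert \nabla_{e_i} f_y(x) \bigr\rVert_2, \qquad \mathrm{IG}_i(x) = (e_i - e_i') \odot \int_0^1 \nabla_{e_i} f_y\bigl(x' + \alpha\,(x - x')\bigr)\, d\alpha,$$

where $x'$ is a baseline input (here, all-zero embeddings). The resulting attributions are reduced to one score per token and normalized to $[0, 1]$ for plotting.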
@tf.function(jit_compile=True)
def normalize(x: tf.Tensor) -> tf.Tensor:
"""
Normalize attribution values to comply with RGB standards.
- Take absolute values.
- Scale attribution scores, so that maximum value is 1.
Parameters
----------
x: 1D tensor containing attribution scores.
Returns
-------
x: 1D tensor containing normalized attribution scores.
"""
    x_abs = tf.abs(x)
    x_max = tf.reduce_max(x_abs)
    return x_abs / x_max
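A quick sanity check on arbitrary values (purely illustrative) shows the effect of normalize:
# |[-2, 1, 4]| = [2, 1, 4]; dividing by the maximum (4) gives [0.5, 0.25, 1.0].
print(normalize(tf.constant([-2.0, 1.0, 4.0])).numpy())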
def explain_gradient_norm(
model: TFPreTrainedModel,
token_ids: tf.Tensor,
attention_mask: tf.Tensor,
target: int,
tokenizer: PreTrainedTokenizerFast
) -> qn.TokenSalience:
"""
Computes token attribution score using the Gradient Norm method for a single point.
Parameters
----------
model:
Huggingface model, which is subject to explanation.
token_ids:
1D Array of token ids.
attention_mask:
1D array of attention mask.
target:
Predicted label.
tokenizer:
Huggingface tokenizer used to convert input_ids back to plain text tokens.
Returns
-------
a: quantus.nlp.TokenSalience
Named tuple (tokens, salience), with tokens and their respective attribution scores.
"""
# Convert tokens to embeddings.
embeddings = model.distilbert.get_input_embeddings()(input_ids=token_ids)
with tf.GradientTape() as tape:
tape.watch(embeddings)
logits = model(None,
inputs_embeds=embeddings,
attention_mask=attention_mask
).logits
logits_for_label = tf.gather(logits, axis=1, indices=target)
# Compute gradients of logits with respect to embeddings.
grads = tape.gradient(logits_for_label, embeddings)
# Compute L2 norm of gradients.
grad_norm = tf.linalg.norm(grads, axis=-1)
with tf.device('cpu'):
scores = normalize(grad_norm[0]).numpy()
return qn.TokenSalience(tokenizer.convert_ids_to_tokens(token_ids), scores)
def explain_gradient_norm_batch(
model: TFPreTrainedModel,
inputs: List[str],
targets: np.ndarray,
tokenizer: PreTrainedTokenizerFast
) -> List[qn.TokenSalience]:
"""
Computes token attribution score using the Gradient Norm method for batch.
Parameters
----------
model:
Huggingface model, which is subject to explanation.
inputs:
List of plain text inputs.
targets:
1D array of predicted labels.
tokenizer:
Huggingface tokenizer used to convert input_ids back to plain text tokens.
Returns
-------
a_batch: List of quantus.nlp.TokenSalience.
List of named tuples (tokens, salience), with tokens and their respective attribution scores.
"""
"""A wrapper around explain_gradient_norm which allows calling it on batch"""
tokens = tokenizer(inputs, return_tensors='tf', padding='longest')
batch_size = len(targets)
return [
explain_gradient_norm(model, tokens['input_ids'][i], tokens['attention_mask'][i], targets[i], tokenizer)
for i in range(batch_size)
]
@tf.function(jit_compile=True)
def get_interpolated_inputs(
baseline: tf.Tensor,
target: tf.Tensor,
num_steps: int
) -> tf.Tensor:
"""
Gets num_step linearly interpolated inputs from baseline to target.
Reference: https://github.com/PAIR-code/lit/blob/main/lit_nlp/components/gradient_maps.py#L238
Returns
-------
interpolated_inputs: [num_steps + 1, num_tokens, emb_size]
"""
baseline = tf.cast(baseline, dtype=tf.float64)
target = tf.cast(target, dtype=tf.float64)
    delta = target - baseline  # [num_tokens, emb_size]
    # Create interpolation coefficients of shape [num_steps + 1, 1, 1],
    # where scales[i] is the i-th step from tf.linspace.
    scales = tf.linspace(0, 1, num_steps + 1)[:, tf.newaxis, tf.newaxis]
    shape = (num_steps + 1,) + delta.shape
    # [num_steps + 1, num_tokens, emb_size]
    deltas = scales * tf.broadcast_to(delta, shape)
interpolated_inputs = baseline + deltas
return interpolated_inputs
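To make the shapes concrete, here is a toy illustration (made-up 2-token, 2-dimensional "embeddings"): with num_steps=2 we get three points on the straight line from the baseline to the target.
# Toy example: interpolate from a zero baseline to a (2 tokens x 2 dims) target.
toy_target = tf.constant([[2.0, 4.0], [6.0, 8.0]])
toy_path = get_interpolated_inputs(tf.zeros_like(toy_target), toy_target, num_steps=2)
print(toy_path.numpy())
# Roughly:
# [[[0. 0.] [0. 0.]]
#  [[1. 2.] [3. 4.]]
#  [[2. 4.] [6. 8.]]]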
def explain_int_grad(
model: TFPreTrainedModel,
token_ids: tf.Tensor,
attention_mask: tf.Tensor,
target: int,
tokenizer: PreTrainedTokenizerFast,
num_steps: int
) -> qn.TokenSalience:
"""
Computes token attribution score using the Integrated Gradients method for a single point.
Parameters
----------
model:
Huggingface model, which is subject to explanation.
token_ids:
1D Array of token ids.
attention_mask:
1D array of attention mask.
target:
Predicted label.
tokenizer:
Huggingface tokenizer used to convert input_ids back to plain text tokens.
num_steps:
Number of interpolation steps between the zero baseline and the input embeddings.
Returns
-------
a: quantus.nlp.TokenSalience
Named tuple (tokens, salience), with tokens and their respective attribution scores.
"""
# Convert tokens to embeddings.
embeddings = model.distilbert.get_input_embeddings()(input_ids=token_ids)[0]
baseline = tf.zeros_like(embeddings)
# Generate interpolation from 0 to embeddings.
with tf.device('cpu'):
interpolated_embeddings = get_interpolated_inputs(baseline, embeddings, num_steps)
interpolated_embeddings = tf.cast(interpolated_embeddings, tf.float32)
interpolated_attention_mask = tf.stack([attention_mask for i in range(num_steps + 1)])
with tf.GradientTape() as tape:
tape.watch(interpolated_embeddings)
logits = model(None,
inputs_embeds=interpolated_embeddings,
attention_mask=interpolated_attention_mask,
).logits
logits_for_label = tf.gather(logits, axis=1, indices=target)
# Compute gradients of logits with respect to interpolations.
grads = tape.gradient(logits_for_label, interpolated_embeddings)
    # Integrate gradients over the interpolation steps (axis 0),
    # then reduce over the embedding dimension to get one score per token.
    int_grad = tfp.math.trapz(tfp.math.trapz(grads, axis=0))
with tf.device('cpu'):
scores = normalize(int_grad).numpy()
return qn.TokenSalience(tokenizer.convert_ids_to_tokens(token_ids), scores)
def explain_int_grad_batch(
model: TFPreTrainedModel,
inputs: List[str],
targets: np.ndarray,
tokenizer: PreTrainedTokenizerFast,
num_steps: int = 10
) -> List[qn.TokenSalience]:
"""
Computes token attribution score using the Integrated Gradients method for batch.
Parameters
----------
model:
Huggingface model, which is subject to explanation.
inputs:
List of plain text inputs.
targets:
1D array of predicted labels.
tokenizer:
Huggingface tokenizer used to convert input_ids back to plain text tokens.
num_steps: int.
Number of interpolations steps, default=10.
Returns
-------
a_batch: List of quantus.nlp.TokenSalience.
List of named tuples (tokens, salience), with tokens and their respective attribution scores.
"""
tokens = tokenizer(inputs, return_tensors='tf', padding='longest')
batch_size = len(targets)
return [
explain_int_grad(model, tokens['input_ids'][i], tokens['attention_mask'][i], targets[i], tokenizer, num_steps)
for i in range(batch_size)
]
# Create functions which match the signature required by Quantus.
explain_gradient_norm_func = partial(explain_gradient_norm_batch, tokenizer=tokenizer)
explain_int_grad_func = partial(explain_int_grad_batch, tokenizer=tokenizer)
# Visualise GradNorm.
a_batch_grad_norm = explain_gradient_norm_func(model, x_batch[2:5], y_batch[2:5])
plot_textual_heatmap(a_batch_grad_norm)
# Visualise Integrated Gradients explanations.
a_batch_int_grad = explain_int_grad_func(model, x_batch[2:5], y_batch[2:5])
plot_textual_heatmap(a_batch_int_grad)
For this example, we compute the Average and Max Sensitivity metrics.
# This is a workaround to account for hardcoded attribute access in the library.
class ModelTuple(NamedTuple):
model: Any
tokenizer: Any
# This is also a workaround for the same hardcoded attribute access.
model_stub = ModelTuple(model, tokenizer)
model_stub.model.bert = model.distilbert
model_stub.model.bert.embeddings.word_embeddings = model.distilbert.embeddings.weight
Average Sensitivity captures the average change in explanations under slight perturbations of the input.
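For orientation, these metrics follow the sensitivity measures of Yeh et al. (2019). Writing $\Phi(f, x)$ for the explanation of input $x$ and $x'$ for a slightly perturbed input (here produced by a spelling change), roughly:

$$\mathrm{SENS}_{avg} = \mathop{\mathbb{E}}_{x' \sim \mathcal{P}(x)}\bigl[\lVert \Phi(f, x') - \Phi(f, x) \rVert\bigr], \qquad \mathrm{SENS}_{max} = \max_{x' \sim \mathcal{P}(x)} \lVert \Phi(f, x') - \Phi(f, x) \rVert,$$

up to the exact norm and normalisation Quantus applies internally. Lower values mean more robust explanations.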
# Instantiate metric.
avg_sensitivity = qn.AvgSensitivity()
# Evaluate avg sensitivity for Gradient Norm.
avg_sensitivity_grad_norm = avg_sensitivity(
model=model_stub,
x_batch=x_batch,
y_batch=y_batch,
perturb_func=qn.change_spelling,
explain_func=explain_gradient_norm_func,
).mean()
# Evaluate avg sensitivity for Integrated Gradients.
avg_sensitivity_int_grad = avg_sensitivity(
model=model_stub,
x_batch=x_batch,
y_batch=y_batch,
perturb_func=qn.change_spelling,
explain_func=explain_int_grad_func
).mean()
Maximum Sensitivity captures the maximal change in explanations under slight perturbations of the input.
# Instantiate metric.
max_sensitivity = qn.MaxSensitivity()
# Evaluate max sensitivity metric for Gradient Norm.
max_sensitivity_grad_norm = max_sensitivity(
model=model_stub,
x_batch=x_batch,
y_batch=y_batch,
perturb_func=qn.change_spelling,
explain_func=explain_gradient_norm_func,
).mean()
# Evaluate max sensitivity metric for Integrated Gradients.
max_sensitivity_int_grad = max_sensitivity(
model=model_stub,
x_batch=x_batch,
y_batch=y_batch,
perturb_func=qn.change_spelling,
explain_func=explain_int_grad_func
).mean()
Display the results in tabular form. Lower sensitivity values indicate explanations that change less under small input perturbations, i.e. more robust explanations.
# Reformat the results.
all_results = np.asarray([
[
avg_sensitivity_grad_norm,
avg_sensitivity_int_grad
],
[
max_sensitivity_grad_norm,
max_sensitivity_int_grad
]
])
# Print out the evaluation outcome!
pd.DataFrame(
all_results,
columns=['Gradient Norm', 'Integrated Gradients'],
index=['Average Sensitivity', 'Max Sensitivity']
)
| | Gradient Norm | Integrated Gradients |
|---|---|---|
| Average Sensitivity | 0.126966 | 0.236630 |
| Max Sensitivity | 0.259777 | 0.190411 |