__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2022"
This notebook is an experimental extension of the CS224u course code. It focuses on the Integrated Gradients method for feature attribution, with comparisons to the "inputs $\times$ gradients" method. To run the notebook, first install the Captum library:
!pip install captum
Requirement already satisfied: captum in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (0.5.0)
Requirement already satisfied: matplotlib in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (3.4.3)
Requirement already satisfied: torch>=1.6 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (1.10.0)
Requirement already satisfied: numpy in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (1.20.3)
Requirement already satisfied: typing-extensions in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from torch>=1.6->captum) (3.10.0.2)
Requirement already satisfied: pillow>=6.2.0 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (8.4.0)
Requirement already satisfied: cycler>=0.10 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (0.10.0)
Requirement already satisfied: python-dateutil>=2.7 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (2.8.2)
Requirement already satisfied: pyparsing>=2.2.1 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (3.0.4)
Requirement already satisfied: kiwisolver>=1.0.1 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (1.3.1)
Requirement already satisfied: six in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from cycler>=0.10->matplotlib->captum) (1.16.0)
This is not currently a required installation (but it will be in future years).
For both implementations, the `forward` method of `model` is used. `X` is an $(m \times n)$ tensor of inputs, and the return value is an $(m \times n)$ tensor of attributions. Use `targets=None` for models with scalar outputs; otherwise, supply a `LongTensor` giving a label for each example.
import torch
def grad_x_input(model, X, targets=None):
"""Implementation using PyTorch directly."""
X.requires_grad = True
y = model(X)
    # With targets, select each example's target-class output:
    y = y if targets is None else y[list(range(len(y))), targets]
(grads, ) = torch.autograd.grad(y.unbind(), X)
return grads * X
from captum.attr import InputXGradient
def captum_grad_x_input(model, X, target):
"""Captum-based implementation."""
X.requires_grad = True
amod = InputXGradient(model)
return amod.attribute(X, target=target)
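As a quick sanity check, the two implementations should agree. Here is a minimal sketch using a hypothetical toy model (`toy_model` and `X_toy` are illustrative names); it has a scalar output per example, so the target arguments can be left as `None`:
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 1)
X_toy = torch.randn(3, 4)

# Both should return the same (3 x 4) tensor of attributions:
manual_attrs = grad_x_input(toy_model, X_toy.clone())
captum_attrs = captum_grad_x_input(toy_model, X_toy.clone(), target=None)

torch.allclose(manual_attrs, captum_attrs)  # expect True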
import numpy as np
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
from captum.attr import InputXGradient
class SelectivityAssessor(nn.Module):
"""Model used by Sundararajan et al, section 2.1 to show that
input * gradients violates their selectivity axiom.
"""
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
def forward(self, X):
return 1.0 - self.relu(1.0 - X)
sel_mod = SelectivityAssessor()
Simple inputs with just one feature (elementwise, the model computes $1 - \mathrm{ReLU}(1 - x) = \min(x, 1)$):
X_sel = torch.FloatTensor([[0.0], [2.0]])
The outputs for our two examples differ:
sel_mod(X_sel)
tensor([[0.], [1.]])
However, `InputXGradient` assigns the same importance to the feature across the two examples, violating selectivity. For $x = 2$, we have $1 - x < 0$, so the ReLU is in its flat region and the gradient is 0; input $\times$ gradient is therefore 0 even though the output differs from the baseline output:
captum_grad_x_input(sel_mod, X_sel, target=None)
tensor([[0.], [-0.]], grad_fn=<MulBackward0>)
Integrated gradients addresses the problem by averaging gradients across all interpolated representations between the baseline and the actual input:
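In symbols, the attribution that integrated gradients assigns to feature $i$ of input $x$, relative to baseline $x'$, is

$$\mathrm{IG}_i(x) = (x_i - x'_i) \times \int_{0}^{1} \frac{\partial F\big(x' + \alpha(x - x')\big)}{\partial x_i}\, d\alpha$$

where $F$ is the model. In practice, the integral is approximated by a finite sum over interpolation steps, as in the toy implementation below.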
ig_sel = IntegratedGradients(sel_mod)
sel_baseline = torch.FloatTensor([[0.0]])
ig_sel.attribute(X_sel, sel_baseline)
tensor([[0.], [1.]], dtype=torch.float64, grad_fn=<MulBackward0>)
A toy implementation to help bring out what is happening:
def ig_reference_implementation(model, x, base, m=50):
vals = []
for k in range(m):
# Interpolated representation:
        xx = base + (k / m) * (x - base)
# Gradient for the interpolated example:
xx.requires_grad = True
y = model(xx)
(grads, ) = torch.autograd.grad(y.unbind(), xx)
vals.append(grads)
return (1 / m) * torch.cat(vals).sum(axis=0) * (x - base)
ig_reference_implementation(sel_mod, torch.FloatTensor([[2.0]]), sel_baseline)
tensor([[1.]])
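Integrated gradients also satisfies the completeness axiom: the attributions for an example sum to the difference between the model's output at that example and at the baseline. Captum's `return_convergence_delta` option reports the residual of this identity, which helps in checking that enough interpolation steps were used:
ig_attrs, delta = ig_sel.attribute(
    X_sel, sel_baseline, return_convergence_delta=True)
# Per-example attribution sums should match
# sel_mod(X_sel) - sel_mod(sel_baseline), with `delta` the residual:
ig_attrs.sum(dim=1), delta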
from captum.attr import IntegratedGradients
from sklearn.datasets import make_classification
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import torch
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
X_cls, y_cls = make_classification(
n_samples=5000,
n_classes=3,
n_features=5,
n_informative=3,
n_redundant=0,
random_state=42)
The classification problem has two uninformative features:
mutual_info_classif(X_cls, y_cls)
array([0.20138107, 0.02833358, 0.11584416, 0. , 0. ])
X_cls_train, X_cls_test, y_cls_train, y_cls_test = train_test_split(X_cls, y_cls)
classifier = TorchShallowNeuralClassifier()
_ = classifier.fit(X_cls_train, y_cls_train)
Stopping after epoch 449. Training loss did not improve more than tol=1e-05. Final error is 1.3419027030467987.
cls_preds = classifier.predict(X_cls_test)
accuracy_score(y_cls_test, cls_preds)
0.8568
classifier_ig = IntegratedGradients(classifier.model)
classifier_baseline = torch.zeros(1, X_cls_train.shape[1])
Integrated gradients with respect to the actual labels:
classifier_attrs = classifier_ig.attribute(
torch.FloatTensor(X_cls_test),
classifier_baseline,
target=torch.LongTensor(y_cls_test))
Average attribution is low for the two uninformative features:
classifier_attrs.mean(axis=0)
tensor([ 0.6544, 0.6739, 0.7057, -0.0173, -0.0059], dtype=torch.float64)
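To summarize overall feature importance, the mean absolute attribution is a common choice. This is a sketch, not part of the original analysis; the two uninformative features should come last in the ranking:
mean_abs_attrs = classifier_attrs.abs().mean(axis=0)
mean_abs_attrs.argsort(descending=True)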
from collections import Counter
from captum.attr import IntegratedGradients
from nltk.corpus import stopwords
from operator import itemgetter
import os
from sklearn.metrics import classification_report
import torch
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
import sst
SST_HOME = os.path.join("data", "sentiment")
Bag-of-words featurization with stopword removal to make this a little easier to study:
stopwords = set(stopwords.words('english'))
def phi(text):
return Counter([w for w in text.lower().split() if w not in stopwords])
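For example (the sentence is illustrative; tokens are lowercased, and NLTK English stopwords like "the" and "was" are dropped):
phi("The movie was strangely compelling .")
# Counter({'movie': 1, 'strangely': 1, 'compelling': 1, '.': 1})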
def fit_mlp(X, y):
mod = TorchShallowNeuralClassifier(early_stopping=True)
mod.fit(X, y)
return mod
experiment = sst.experiment(
sst.train_reader(SST_HOME),
phi,
fit_mlp,
sst.dev_reader(SST_HOME))
Stopping after epoch 24. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 1.3742991983890533
              precision    recall  f1-score   support

    negative      0.629     0.696     0.661       428
     neutral      0.295     0.100     0.150       229
    positive      0.625     0.773     0.691       444

    accuracy                          0.603      1101
   macro avg      0.516     0.523     0.500      1101
weighted avg      0.558     0.603     0.567      1101
Trained model:
sst_classifier = experiment['model']
Captum needs to have labels as indices rather than strings:
sst_classifier.classes_
['negative', 'neutral', 'positive']
y_sst_test = [sst_classifier.classes_.index(label)
for label in experiment['assess_datasets'][0]['y']]
sst_preds = [sst_classifier.classes_.index(label)
for label in experiment['predictions'][0]]
Our featurized test set:
X_sst_test = experiment['assess_datasets'][0]['X']
Feature names to help with analyses:
fnames = experiment['train_dataset']['vectorizer'].get_feature_names()
Integrated gradients:
sst_ig = IntegratedGradients(sst_classifier.model)
All-0s baseline, which for these count features corresponds to the empty text:
sst_baseline = torch.zeros(1, experiment['train_dataset']['X'].shape[1])
Attributions with respect to the model's predictions:
sst_attrs = sst_ig.attribute(
torch.FloatTensor(X_sst_test),
sst_baseline,
target=torch.LongTensor(sst_preds))
Helper functions for error analysis:
def error_analysis(gold=1, predicted=2):
    """Mean attributions over the test examples with gold label index
    `gold` that the model predicted as label index `predicted`."""
    err_ind = [i for i, (g, p) in enumerate(zip(y_sst_test, sst_preds))
               if g == gold and p == predicted]
    attr_lookup = create_attr_lookup(sst_attrs[err_ind])
    return attr_lookup, err_ind

def create_attr_lookup(attrs):
    """Pair feature names with mean attributions, sorted largest-first."""
    mu = attrs.mean(axis=0).detach().numpy()
    return sorted(zip(fnames, mu), key=itemgetter(1), reverse=True)
sst_attrs_lookup, sst_err_ind = error_analysis(gold=1, predicted=2)
sst_attrs_lookup[:5]
[(',', 0.04512846003808179), ('.', 0.03875377384651548), ('film', 0.036562292947638124), ('fun', 0.02995531556022619), ('best', 0.015621606617978723)]
Error analysis for a specific example:
ex_ind = sst_err_ind[0]
experiment['assess_datasets'][0]['raw_examples'][ex_ind]
'No one goes unindicted here , which is probably for the best .'
ex_attr_lookup = create_attr_lookup(sst_attrs[ex_ind:ex_ind+1])
[(f, a) for f, a in ex_attr_lookup if a != 0]
[('best', 0.43364394640626314), (',', 0.04500691178712216), ('.', 0.03940604247146967), ('probably', 0.03321118433841792), ('one', 0.008722432294266332), ('goes', -0.03914730368530946)]
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
hf_weights_name = 'cardiffnlp/twitter-roberta-base-sentiment'
hf_tokenizer = AutoTokenizer.from_pretrained(hf_weights_name)
hf_model = AutoModelForSequenceClassification.from_pretrained(hf_weights_name)
def hf_predict_one_proba(text):
input_ids = hf_tokenizer.encode(
text, add_special_tokens=True, return_tensors='pt')
hf_model.eval()
with torch.no_grad():
logits = hf_model(input_ids)[0]
preds = F.softmax(logits, dim=1)
    # Leave the model in eval mode so that later attribution runs are
    # deterministic (dropout disabled):
    return preds.squeeze(0)
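A quick illustration (the sentence is arbitrary; for this model, index 0 is negative, 1 is neutral, and 2 is positive):
hf_predict_one_proba("I loved this movie!")
# A length-3 tensor of probabilities over (negative, neutral, positive).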
def hf_ig_encodings(text):
pad_id = hf_tokenizer.pad_token_id
cls_id = hf_tokenizer.cls_token_id
sep_id = hf_tokenizer.sep_token_id
input_ids = hf_tokenizer.encode(text, add_special_tokens=False)
base_ids = [pad_id] * len(input_ids)
input_ids = [cls_id] + input_ids + [sep_id]
base_ids = [cls_id] + base_ids + [sep_id]
return torch.LongTensor([input_ids]), torch.LongTensor([base_ids])
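To see what the baseline looks like (the input is illustrative): the content tokens are replaced by the padding token, while the special tokens are kept in place:
ex_ids, ex_base = hf_ig_encodings("It was great.")
print(hf_tokenizer.convert_ids_to_tokens(ex_ids[0].tolist()))
print(hf_tokenizer.convert_ids_to_tokens(ex_base[0].tolist()))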
def hf_ig_analyses(text2class):
data = []
for text, true_class in text2class.items():
score_vis = hf_ig_analysis_one(text, true_class)
data.append(score_vis)
viz.visualize_text(data)
def hf_ig_analysis_one(text, true_class):
# Option to look at different layers:
# layer = model.roberta.encoder.layer[0]
# layer = model.roberta.embeddings.word_embeddings
layer = hf_model.roberta.embeddings
def ig_forward(inputs):
return hf_model(inputs).logits
ig = LayerIntegratedGradients(ig_forward, layer)
input_ids, base_ids = hf_ig_encodings(text)
attrs, delta = ig.attribute(
input_ids,
base_ids,
target=true_class,
return_convergence_delta=True)
    # Sum attributions across the embedding dimension, then center
    # and L2-normalize them for each representation in `layer`:
scores = attrs.sum(dim=-1).squeeze(0)
scores = (scores - scores.mean()) / scores.norm()
# Intuitive tokens to help with analysis:
raw_input = hf_tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])
# RoBERTa-specific clean-up:
raw_input = [x.strip("Ġ") for x in raw_input]
# Predictions for comparisons:
pred_probs = hf_predict_one_proba(text)
pred_class = pred_probs.argmax()
score_vis = viz.VisualizationDataRecord(
word_attributions=scores,
pred_prob=pred_probs.max(),
pred_class=pred_class,
true_class=true_class,
attr_class=None,
attr_score=attrs.sum(),
raw_input_ids=raw_input,
convergence_score=delta)
return score_vis
hf_ig_analyses({
"They said it would be great, and they were right.": 2,
"They said it would be great, and they were wrong.": 0,
"They were right to say it would be great.": 2,
"They were wrong to say it would be great.": 0,
"They said it would be stellar, and they were correct.": 2,
"They said it would be stellar, and they were incorrect.": 0})
[Captum renders an HTML table with columns True Label, Predicted Label, Attribution Label, Attribution Score, and Word Importance, giving a token-level heatmap over `<s> ... </s>` for each of the six sentences above.]