We use bert-base-uncased as the model and SST-2 as the dataset in this example. More models can be found in the PaddleNLP Model Zoo.
PaddleNLP is required to run this notebook and is easy to install:
pip install setuptools_scm
pip install --upgrade paddlenlp
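InterpretDL itself is also assumed to be installed; if it is not, it can likewise be installed from PyPI:
pip install interpretdl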
import paddle
import paddlenlp
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
MODEL_NAME = "bert-base-uncased"
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=2)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
from paddlenlp.datasets import load_dataset
train_ds, dev_ds, test_ds = load_dataset(
    "glue", name='sst-2', splits=["train", "dev", "test"]
)
[2021-11-04 16:50:48,431] [    INFO] - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased.pdparams
W1104 16:50:48.433992 22865 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.2
W1104 16:50:48.439213 22865 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2021-11-04 16:50:58,691] [    INFO] - Already cached /root/.paddlenlp/models/bert-base-uncased/bert-base-uncased-vocab.txt
INFO:paddle.utils.download:unique_endpoints {'10.255.126.17:35174'}
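Before training, it is worth printing one record to confirm the data loaded correctly. This is just a quick sanity check, not part of the pipeline; the exact field names depend on the PaddleNLP version:
# Each split behaves like a list of dict records; inspect the sizes and the first training example.
print(len(train_ds), len(dev_ds), len(test_ds))
print(train_ds[0])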
# training the model and save to save_dir
# only needs to run once.
# total steps ~2100 (1 epoch)
from assets.utils import training_model
training_model(model, tokenizer, train_ds, dev_ds, save_dir=f'assets/sst-2-{MODEL_NAME}')
# global step 2100, epoch: 1, batch: 2100, loss: 0.22977, acc: 0.91710
# eval loss: 0.20062, accu: 0.91972
# Load the trained model.
state_dict = paddle.load(f'assets/sst-2-{MODEL_NAME}/model_state.pdparams')
model.set_dict(state_dict)
from assets.utils import predict
reviews = [
    "it 's a charming and often affecting journey . ",
    'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . ',
    'this one is definitely one to skip , even for horror movie fanatics . ',
    'in its best moments , resembles a bad high school production of grease , without benefit of song . '
]
data = [{"text": r} for r in reviews]
label_map = {0: 'negative', 1: 'positive'}
batch_size = 32
results = predict(
    model, data, tokenizer, label_map, batch_size=batch_size)
for idx, text in enumerate(data):
    print('Data: {} \t Label: {}'.format(text, results[idx]))
Data: {'text': "it 's a charming and often affecting journey . "}   Label: positive
Data: {'text': 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '}   Label: positive
Data: {'text': 'this one is definitely one to skip , even for horror movie fanatics . '}   Label: positive
Data: {'text': 'in its best moments , resembles a bad high school production of grease , without benefit of song . '}   Label: negative
import interpretdl as it
import numpy as np
from assets.utils import convert_example, aggregate_subwords_and_importances
from paddlenlp.data import Stack, Tuple, Pad
from interpretdl.data_processor.visualizer import VisualizationTextRecord, visualize_text
def preprocess_fn(data):
    examples = []
    if not isinstance(data, list):
        data = [data]
    for text in data:
        input_ids, segment_ids = convert_example(
            text,
            tokenizer,
            max_seq_length=128,
            is_test=True
        )
        examples.append((input_ids, segment_ids))
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment ids
    ): fn(samples)
    input_ids, segment_ids = batchify_fn(examples)
    return paddle.to_tensor(input_ids, stop_gradient=False), paddle.to_tensor(segment_ids, stop_gradient=False)
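As a quick check (a sketch reusing the `data` list built for `predict` above), `preprocess_fn` should return two tensors with one row per review, padded to the longest sequence in the batch:
# Sanity check: both tensors should have shape [4, padded_len] for the four reviews.
example_ids, example_segments = preprocess_fn(data)
print(example_ids.shape, example_segments.shape)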
ig = it.IntGradNLPInterpreter(model, device='gpu:0')
pred_labels, pred_probs, avg_gradients = ig.interpret(
    preprocess_fn(data),
    steps=50,
    return_pred=True)
true_labels = [1, 1, 0, 0]
recs = []
for i in range(avg_gradients.shape[0]):
    subwords = " ".join(tokenizer._tokenize(data[i]['text'])).split(' ')
    subword_importances = avg_gradients[i]
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances) / np.linalg.norm(word_importances)
    pred_label = pred_labels[i]
    pred_prob = pred_probs[i, pred_label]
    true_label = true_labels[i]
    interp_class = pred_label
    if interp_class == 0:
        word_importances = -word_importances
    recs.append(
        VisualizationTextRecord(words, word_importances, true_label,
                                pred_label, pred_prob, interp_class)
    )
visualize_text(recs)
# The interactive visualization is not rendered on GitHub. It is an HTML table with the columns
# True Label | Predicted Label (Prob) | Target Label | Word Importance, one row per review,
# with each word highlighted according to its importance.
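When the rendered table is unavailable, the importances can still be read as plain text. A minimal sketch, reusing the `words` and `word_importances` left over from the last loop iteration, prints the five words with the largest absolute importance:
# Plain-text fallback: show the top-5 words by absolute importance for the last review.
top = sorted(zip(words, word_importances), key=lambda t: abs(t[1]), reverse=True)[:5]
for w, imp in top:
    print('{}: {:+.3f}'.format(w, imp))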
true_labels = [1, 1, 0, 0]
recs = []
lime = it.LIMENLPInterpreter(model, device='gpu:0')
for i, review in enumerate(data):
    pred_class, pred_prob, lime_weights = lime.interpret(
        review,
        preprocess_fn,
        num_samples=1000,
        batch_size=32,
        unk_id=tokenizer.convert_tokens_to_ids('[UNK]'),
        pad_id=tokenizer.convert_tokens_to_ids('[PAD]'),
        return_pred=True)
    # subwords of the current review
    subwords = " ".join(tokenizer._tokenize(review['text'])).split(' ')
    interp_class = list(lime_weights.keys())[0]
    # drop the [CLS] and [SEP] positions and keep only the weights
    subword_importances = [t[1] for t in lime_weights[interp_class][1:-1]]
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances) / np.linalg.norm(word_importances)
    true_label = true_labels[i]
    if interp_class == 0:
        word_importances = -word_importances
    rec = VisualizationTextRecord(
        words,
        word_importances,
        true_label,
        pred_class[0],
        pred_prob[0],
        interp_class
    )
    recs.append(rec)
visualize_text(recs)
# As above, the interactive visualization (True Label | Predicted Label (Prob) | Target Label | Word Importance)
# is not rendered on GitHub.
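The raw LIME output can also be inspected directly. As a sketch using the `lime_weights` dict left from the last loop iteration (with the same `[1:-1]` slicing as above to skip the special tokens), each subword can be printed next to its weight:
# lime_weights maps the interpreted class to per-token entries whose second element is the weight.
for sw, t in zip(subwords, lime_weights[interp_class][1:-1]):
    print('{}: {:+.4f}'.format(sw, t[1]))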
ig = it.GradShapNLPInterpreter(model, device='gpu:0')
pred_labels, pred_probs, avg_gradients = ig.interpret(
    preprocess_fn(data),
    n_samples=10,
    noise_amount=0.1,
    return_pred=True)
true_labels = [1, 1, 0, 0]
recs = []
for i in range(avg_gradients.shape[0]):
    subwords = " ".join(tokenizer._tokenize(data[i]['text'])).split(' ')
    subword_importances = avg_gradients[i]
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances) / np.linalg.norm(word_importances)
    pred_label = pred_labels[i]
    pred_prob = pred_probs[i, pred_label]
    true_label = true_labels[i]
    interp_class = pred_label
    if interp_class == 0:
        word_importances = -word_importances
    recs.append(
        VisualizationTextRecord(words, word_importances, true_label,
                                pred_label, pred_prob, interp_class)
    )
visualize_text(recs)
# As above, the interactive visualization (True Label | Predicted Label (Prob) | Target Label | Word Importance)
# is not rendered on GitHub.