This notebook demonstrates the use of the responsibleai_text API to assess a Hugging Face transformers text classification model trained on the DBPedia dataset. It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model.
The following section examines the code necessary to create the dataset and model. It then generates insights using the responsibleai_text API that can be visually analyzed.
The following section can be skipped. It loads a dataset and downloads a pre-trained model for illustrative purposes.
First, we import all necessary dependencies.
import datasets
import pandas as pd
import zipfile
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
pipeline)
from raiutils.common.retries import retry_function
try:
from urllib import urlretrieve
except ImportError:
from urllib.request import urlretrieve
Next, we load the DBPedia dataset from Hugging Face datasets. Note that we use only 6 examples plus 8 additional known error instances here, since it can take some time to compute explanations, especially on CPU. You can increase NUM_TEST_SAMPLES to 100 or more to get a more interesting dashboard.
# Bump up the number of examples to 100 or greater to view more
# information, but it may take longer to compute
NUM_TEST_SAMPLES = 6
def load_dataset(split):
dataset = datasets.load_dataset("DeveloperOats/DBPedia_Classes", split=split)
return pd.DataFrame({"text": dataset["text"], "l1": dataset["l1"]})
pd_valid_data = load_dataset("test")
def rename_label_column(dataset):
dataset["label"] = dataset["l1"]
dataset = dataset.drop(columns="l1")
return dataset
pd_valid_data = rename_label_column(pd_valid_data)
START_INDEX = 0
test_data = pd_valid_data[:NUM_TEST_SAMPLES]
Add some known error instances to make the data more interesting.
error_indices = [101, 319, 391, 414, 644, 894, 1078, 1209]
test_data = pd.concat([test_data, pd_valid_data.iloc[error_indices]]).reset_index(drop=True)
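Optionally, take a quick look at the assembled test set to confirm its size and label distribution; this is a small sanity check, not part of the original workflow.
# Inspect the combined test set (NUM_TEST_SAMPLES samples plus the 8 known error instances)
print(test_data.shape)
print(test_data['label'].value_counts())
test_data.head()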
Fetch a Hugging Face transformers model pre-trained on the DBPedia dataset.
DBPEDIA_MODEL_NAME = "dbpedia_model"
NUM_LABELS = 9
class FetchModel(object):
def __init__(self):
pass
def fetch(self):
zipfilename = DBPEDIA_MODEL_NAME + '.zip'
url = ('https://publictestdatasets.blob.core.windows.net/models/' +
DBPEDIA_MODEL_NAME + '.zip')
urlretrieve(url, zipfilename)
with zipfile.ZipFile(zipfilename, 'r') as unzip:
unzip.extractall(DBPEDIA_MODEL_NAME)
def retrieve_dbpedia_model():
fetcher = FetchModel()
action_name = "Model download"
err_msg = "Failed to download model"
max_retries = 4
retry_delay = 60
retry_function(fetcher.fetch, action_name, err_msg,
max_retries=max_retries,
retry_delay=retry_delay)
model = AutoModelForSequenceClassification.from_pretrained(
DBPEDIA_MODEL_NAME, num_labels=NUM_LABELS)
return model
model = retrieve_dbpedia_model()
Load the tokenizer and build a prediction pipeline around the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
device = -1
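# Optional: switch to GPU inference automatically when a CUDA device is available.
# This assumes torch is installed (transformers already requires it).
import torch
if torch.cuda.is_available():
    device = 0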
if device >= 0:
model = model.cuda()
# build a pipeline object to do predictions
pred = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
device=device,
    return_all_scores=True  # deprecated in newer transformers releases; use top_k=None there
)
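As an optional sanity check, the pipeline returns one list of per-class score dictionaries for each input text; the example below simply scores the first row of the test set.
# Score a single document; each entry pairs a class label with its probability
sample_scores = pred(test_data['text'].iloc[:1].tolist())
print(sample_scores)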
Define the encoded classes, which need to come from the original trained model, and display the number of errors on the test dataset.
from ml_wrappers import wrap_model
wrapped_model = wrap_model(pred, test_data, 'text_classification')
encoded_classes = ['Agent', 'Device', 'Event', 'Place', 'Species',
'SportsSeason', 'TopicalConcept', 'UnitOfWork',
'Work']
labels = [encoded_classes.index(y) for y in test_data['label'].tolist()]
print("number of errors on test dataset: " + str(sum(wrapped_model.predict(test_data['text'].tolist()) != labels)))
from responsibleai_text import RAITextInsights, ModelTask
from raiwidgets import ResponsibleAIDashboard
To use the Responsible AI Dashboard, initialize a RAITextInsights object, upon which the different components can be loaded.
RAITextInsights accepts the model, the test dataset, the target column name, the classes, and the task type as its arguments.
rai_insights = RAITextInsights(pred, test_data,
"label",
classes=encoded_classes,
task_type=ModelTask.TEXT_CLASSIFICATION)
Add the components of the toolbox for model assessment.
rai_insights.explainer.add()
rai_insights.error_analysis.add()
Once all the desired components have been loaded, compute insights on the test set.
rai_insights.compute()
Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab.
ResponsibleAIDashboard(rai_insights)
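If you want to revisit the analysis later without recomputing, the insights can typically be serialized to a directory and reloaded; this assumes your installed responsibleai_text version exposes the same save/load interface as the tabular RAIInsights.
# Persist the computed insights and reload them in a later session (assumed API)
rai_insights.save('dbpedia_rai_insights')
loaded_insights = RAITextInsights.load('dbpedia_rai_insights')
ResponsibleAIDashboard(loaded_insights)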