This notebook demonstrates how to use the Responsible AI Widgets' Error Analysis dashboard to understand a model trained on the multiclass wine dataset. It trains a scikit-learn classifier to predict the type of wine, then explores the model's errors and explanations:
# %pip install --upgrade interpret-community
# %pip install --upgrade raiwidgets
from sklearn.datasets import load_wine
from sklearn import svm
# Imports for the MimicExplainer (global surrogate) with a LightGBM explainable model
from interpret.ext.blackbox import MimicExplainer
from interpret.ext.glassbox import LGBMExplainableModel
wine = load_wine()
X = wine['data']
y = wine['target']
classes = wine['target_names']
feature_names = wine['feature_names']
# Split data into train and test
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
# Train an SVM classifier; probability=True enables the probability estimates used by the explainers
clf = svm.SVC(gamma=0.001, C=100., probability=True)
model = clf.fit(X_train, y_train)
# Notice the model makes a fair number of errors
print("number of errors on test dataset: " + str(sum(model.predict(X_test) != y_test)))
from raiwidgets import ErrorAnalysisDashboard
predictions = model.predict(X_test)
ErrorAnalysisDashboard(dataset=X_test, true_y=y_test, features=feature_names, pred_y=predictions)
from interpret_community.common.constants import ModelTask
# Train the LightGBM surrogate model using MimicExplainer
model_task = ModelTask.Classification
# augment_data=True oversamples the initialization examples to improve the
# surrogate model's fidelity to the original model
explainer = MimicExplainer(model, X_train, LGBMExplainableModel,
                           augment_data=True, max_num_of_augmentations=10,
                           features=feature_names, classes=classes, model_task=model_task)
Explain overall model predictions (global explanation)
# Passing in the test dataset for evaluation examples - note it must be a representative sample of the original data
# X_train can be passed as well, but explanations take longer to compute with more examples (though they may be more accurate)
global_explanation = explainer.explain_global(X_test)
# Print out a dictionary that holds the sorted feature importance names and values
print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))
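The explanation object also exposes the ranked importances as parallel lists, which can be easier to work with than the dictionary (a short sketch using the interpret-community explanation API):
# Feature names and importance values, sorted from most to least important
sorted_global_names = global_explanation.get_ranked_global_names()
sorted_global_values = global_explanation.get_ranked_global_values()
print(list(zip(sorted_global_names, sorted_global_values))[:5])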
from raiwidgets import ErrorAnalysisDashboard
ErrorAnalysisDashboard(global_explanation, model, dataset=X_test, true_y=y_test)
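Beyond the global view, the same explainer can produce local explanations for individual predictions, for example the misclassified test instances (a hedged sketch assuming the explainer and arrays defined above):
import numpy as np
# Indices of the test examples the model got wrong
error_idx = np.where(model.predict(X_test) != y_test)[0]
# Local (per-instance) explanation for just those examples
local_explanation = explainer.explain_local(X_test[error_idx])
# For multiclass models these are ranked per class and per instance
sorted_local_names = local_explanation.get_ranked_local_names()
sorted_local_values = local_explanation.get_ranked_local_values()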