This notebook demonstrates how to use the Responsible AI Widgets' Error Analysis dashboard to understand a regression model trained on the California housing dataset. The goal of this sample notebook is to predict housing prices with scikit-learn and to explore the model's errors and explanations:
# %pip install --upgrade interpret-community
# %pip install --upgrade raiwidgets
from sklearn.datasets import fetch_california_housing
from sklearn import svm
# Imports for SHAP MimicExplainer with LightGBM surrogate model
from interpret.ext.blackbox import MimicExplainer
from interpret.ext.glassbox import LGBMExplainableModel
housing = fetch_california_housing()
X = housing['data']
y = housing['target']
feature_names = housing['feature_names']
# Split data into train and test
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
clf = svm.SVR(gamma=0.001, C=100.)
model = clf.fit(X_train, y_train)
# Note that the model makes a fair amount of error on the test set
print("average abs error on test dataset: " + str(abs(model.predict(X_test) - y_test).mean()))
from raiwidgets import ErrorAnalysisDashboard
predictions = model.predict(X_test)
ErrorAnalysisDashboard(dataset=X_test, true_y=y_test, features=feature_names, pred_y=predictions, model_task='regression')
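Alongside the interactive dashboard, the test rows with the largest absolute errors can also be listed programmatically; the snippet below is a rough sketch that uses only numpy and the predictions array computed above.
# Sketch: print the 5 test rows with the largest absolute error (assumes numpy, which ships with scikit-learn)
import numpy as np
abs_errors = np.abs(predictions - y_test)
worst_idx = np.argsort(abs_errors)[-5:][::-1]
for i in worst_idx:
    print("row {}: predicted {:.2f}, actual {:.2f}, abs error {:.2f}".format(i, predictions[i], y_test[i], abs_errors[i]))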
from interpret_community.common.constants import ModelTask
# Train the LightGBM surrogate model using MimicExplainer
model_task = ModelTask.Regression
explainer = MimicExplainer(model, X_train, LGBMExplainableModel,
                           augment_data=True, max_num_of_augmentations=10,
                           features=feature_names, model_task=model_task)
Explain overall model predictions (global explanation)
# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data
# X_train can be passed as well, but with more examples the explanations take longer to compute, although they may be more accurate
global_explanation = explainer.explain_global(X_test)
# Print out a dictionary that holds the sorted feature importance names and values
print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))
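Since the returned dictionary is already sorted by importance, the top features can be pulled out directly; a minimal sketch follows, where top_k is a cutoff introduced here only for illustration.
# Sketch: show only the top_k most important features (top_k is a hypothetical cutoff)
top_k = 3
importance_dict = global_explanation.get_feature_importance_dict()
for name, value in list(importance_dict.items())[:top_k]:
    print("{}: {:.4f}".format(name, value))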
from raiwidgets import ErrorAnalysisDashboard
ErrorAnalysisDashboard(global_explanation, model, dataset=X_test, true_y=y_test, model_task='regression')
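Explanations for individual predictions (local explanations) can be generated with the same surrogate explainer; the sketch below uses interpret-community's explain_local on the first five test rows, and the exact accessor names are assumptions based on that library's documented API.
# Sketch: explain the first five rows of the test set (local explanation)
local_explanation = explainer.explain_local(X_test[0:5])
# Sorted feature importance names and values per explained row
sorted_local_importance_names = local_explanation.get_ranked_local_names()
sorted_local_importance_values = local_explanation.get_ranked_local_values()
print(sorted_local_importance_names[0])
print(sorted_local_importance_values[0])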