First, install the `interpret` and `gamchanger` packages.
# Install the `interpret` (InterpretML) and `gamchanger` packages.
# !pip install --upgrade interpret gamchanger
We will train a simple Explainable Boosting Machine (EBM) model to predict whether an individual's income is above 50K, using the census dataset.
import pandas as pd
import numpy as np
import gamchanger as gc
from json import load
from sklearn.model_selection import train_test_split
from interpret.glassbox import ExplainableBoostingClassifier
df = pd.read_csv(
"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
header=None)
df.columns = [
"Age", "WorkClass", "fnlwgt", "Education", "EducationNum",
"MaritalStatus", "Occupation", "Relationship", "Race", "Gender",
"CapitalGain", "CapitalLoss", "HoursPerWeek", "NativeCountry", "Income"
]
train_cols = df.columns[0:-1]
label = df.columns[-1]
X = df[train_cols]
y = df[label].apply(lambda x: 0 if x == " <=50K" else 1)  # Encode the response: 0 for <=50K, 1 for >50K
seed = 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)
ebm = ExplainableBoostingClassifier(random_state=seed, n_jobs=-1)
ebm.fit(X_train, y_train)
ExplainableBoostingClassifier(feature_names=['Age', 'WorkClass', 'fnlwgt', 'Education', 'EducationNum', 'MaritalStatus', 'Occupation', 'Relationship', 'Race', 'Gender', 'CapitalGain', 'CapitalLoss', 'HoursPerWeek', 'NativeCountry', 'Relationship x HoursPerWeek', 'Age x Relationship', 'EducationNum x Occupation', 'MaritalStatus x HoursPerWeek', 'Occupation x Relationship', 'Occ... feature_types=['continuous', 'categorical', 'continuous', 'categorical', 'continuous', 'categorical', 'categorical', 'categorical', 'categorical', 'categorical', 'continuous', 'continuous', 'continuous', 'categorical', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction'], n_jobs=-1, random_state=1)
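The label encoding above matches the exact string `" <=50K"`, including the leading space that the raw file keeps after each comma. As a minimal sketch of a more forgiving encoding, the snippet below strips whitespace before comparing. The `income` series is a hypothetical miniature of the `Income` column, not real data from the file.

```python
import pandas as pd

# Hypothetical miniature of the raw "Income" column; the values keep the
# space that follows each comma separator in the CSV.
income = pd.Series([" <=50K", " >50K", " <=50K", " >50K"], name="Income")

# Strip whitespace before comparing so the encoding does not depend on
# how the CSV happened to be parsed.
y_demo = (income.str.strip() == ">50K").astype(int)

print(y_demo.tolist())  # [0, 1, 0, 1]
```

With the notebook's own data, the same expression would be applied to `df[label]` in place of `income`.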
Then we can start to investigate, validate, and edit the trained EBM model using GAM Changer.
GAM Changer expects a trained EBM model and, optionally, sample data (features and labels).
GAM Changer uses the sample data to compute model performance, feature correlation, etc. We suggest always providing sample data, unless the data is sensitive and you plan to share GAM Changer visualizations/outputs with external collaborators.
We recommend generating the sample data from the validation set (if you have one) or from a large subset of the training set. GAM Changer can support 10k+ data points (the upper bound is the browser's memory limit), and it can provide real-time feedback when the sample size is below about 3k.
# Randomly sample 2000 points from the training set for GAM Changer
rand_indexes = np.random.choice(len(y_train), 2000, replace=False)  # replace=False avoids duplicate rows
X_sample = X_train.to_numpy()[rand_indexes]
y_sample = y_train.to_numpy()[rand_indexes]
gc.visualize(ebm, X_sample, y_sample)
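The sampling step can also be made reproducible and safe for small datasets. The helper below is a sketch, not part of GAM Changer: `sample_rows` is a hypothetical name, and it assumes `X` and `y` can be converted with `np.asarray`. It draws distinct rows with a seeded generator and caps the sample at the dataset size.

```python
import numpy as np

def sample_rows(X, y, n=2000, seed=1):
    """Draw up to n distinct rows from X and y using a seeded generator."""
    X = np.asarray(X)
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    size = min(n, len(y))  # never ask for more rows than exist
    idx = rng.choice(len(y), size=size, replace=False)
    return X[idx], y[idx]

# Tiny illustrative arrays: row i of X_demo is [2*i, 2*i + 1], y_demo[i] is i.
X_demo = np.arange(20).reshape(10, 2)
y_demo = np.arange(10)

X_s, y_s = sample_rows(X_demo, y_demo, n=4, seed=1)
print(X_s.shape, y_s.shape)
```

In this notebook, the equivalent call would be `X_sample, y_sample = sample_rows(X_train, y_train)`.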
Suppose you have made some edits to your EBM model and saved a `.gamchanger` file by clicking the save button on the bottom right. Let's load the new model in Python!
Unfortunately, we cannot automatically transfer the new model back to Python (a security restriction of computational notebooks). Instead, we will use Python to load the downloaded `.gamchanger` file.
For cloud-based notebooks (e.g., Colab), you need to upload the `.gamchanger` file to the working directory so that you can load it from your notebook.
For example, here we upload a previous editing history, `edit-6-14-2022.gamchanger`, to the root directory using the left panel, then load it with the code in the cell below.
# Read the saved edits from file
# To load from `~/Downloads` instead, use the line below (requires `import os`)
# gc_dict = load(open(os.path.join(os.path.expanduser('~'), 'Downloads/edit-6-14-2022.gamchanger'), 'r'))
# Load from the current working directory `./`; the `with` block closes the file afterwards
with open('./edit-6-14-2022.gamchanger', 'r') as f:
    gc_dict = load(f)
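Since the file is read with `json.load`, it can be checked with ordinary JSON tooling before the edits are applied. The sketch below is an illustration under stated assumptions: `load_edit_file` and `summarize` are hypothetical helpers, and the demo file holds placeholder content rather than the real `.gamchanger` layout, which is not assumed here.

```python
import json
import tempfile
from pathlib import Path

def load_edit_file(path):
    """Load a JSON edit file, failing early with a clear message if it is missing."""
    path = Path(path).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"No edit file at {path}; upload or move it first.")
    with path.open("r") as f:
        return json.load(f)

def summarize(obj):
    """Describe the top level of a loaded JSON object without assuming a schema."""
    if isinstance(obj, dict):
        return {key: type(value).__name__ for key, value in obj.items()}
    return type(obj).__name__

# Placeholder content only: this is NOT the real .gamchanger structure.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.gamchanger"
    demo_path.write_text(json.dumps({"placeholder_list": [], "placeholder_dict": {}}))
    demo = load_edit_file(demo_path)

print(summarize(demo))
```

Calling `summarize(gc_dict)` on the loaded edits gives a quick view of what the file contains without relying on any particular key names.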
# gc.get_edited_model will return a copy of your original EBM where edits are applied
new_ebm = gc.get_edited_model(ebm, gc_dict)
new_ebm
ExplainableBoostingClassifier(feature_names=['Age', 'WorkClass', 'fnlwgt', 'Education', 'EducationNum', 'MaritalStatus', 'Occupation', 'Relationship', 'Race', 'Gender', 'CapitalGain', 'CapitalLoss', 'HoursPerWeek', 'NativeCountry', 'Relationship x HoursPerWeek', 'Age x Relationship', 'EducationNum x Occupation', 'MaritalStatus x HoursPerWeek', 'Occupation x Relationship', 'Occ... feature_types=['continuous', 'categorical', 'continuous', 'categorical', 'continuous', 'categorical', 'categorical', 'categorical', 'categorical', 'categorical', 'continuous', 'continuous', 'continuous', 'categorical', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction', 'interaction'], n_jobs=-1, random_state=1)
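After loading the edited model, a natural next step is to check how the edits changed held-out performance. The following is a minimal sketch rather than a GAM Changer feature: `accuracy` and `compare_models` are hypothetical helpers that only assume each model has a `predict` method, and `_Constant` is a stand-in model used purely for the demo.

```python
import numpy as np

def accuracy(model, X, y):
    """Fraction of rows where model.predict(X) equals y."""
    return float(np.mean(np.asarray(model.predict(X)) == np.asarray(y)))

def compare_models(original, edited, X, y):
    """Return the accuracy of both models and the change caused by the edits."""
    acc_orig = accuracy(original, X, y)
    acc_edit = accuracy(edited, X, y)
    return {"original": acc_orig, "edited": acc_edit, "delta": acc_edit - acc_orig}

class _Constant:
    """Stand-in model for illustration: always predicts the same class."""
    def __init__(self, value):
        self.value = value

    def predict(self, X):
        return np.full(len(X), self.value)

X_demo = np.zeros((4, 2))
y_demo = np.array([0, 0, 0, 1])

result = compare_models(_Constant(0), _Constant(1), X_demo, y_demo)
print(result)
```

With the objects from this notebook, the call would be `compare_models(ebm, new_ebm, X_test, y_test)`. Accuracy is only one possible metric and was chosen here for brevity.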