# default_exp mli

# hide
# IPython line magics written as calls so this file is valid Python
# (equivalent to `%load_ext autoreload`, `%autoreload 2`, `%matplotlib inline`).
get_ipython().run_line_magic("load_ext", "autoreload")
get_ipython().run_line_magic("autoreload", "2")
from nbdev.showdoc import *
get_ipython().run_line_magic("matplotlib", "inline")

import lightgbm as lgb
lgb.__version__

import plotly
plotly.__version__

import shap
shap.initjs()
shap.__version__

# export
import numpy as np


def logit(p):
    """
    Compute log-odds of p, i.e. log(p / (1 - p)).

    `p` must lie strictly between 0 and 1 (checked with an assert).
    """
    assert 0 < p < 1
    # Parenthesize the denominator: `p / 1 - p` would evaluate to p - p == 0.
    return np.log(p / (1 - p))


def expit(a):
    """
    The reverse of logit: map log-odds `a` back to a probability 1 / (1 + exp(-a)).

    Asserts that the result lies strictly between 0 and 1.
    """
    p = 1 / (1 + np.exp(-a))
    # NOTE(review): in float64, 1 + exp(-a) rounds to exactly 1.0 for large
    # positive a (roughly a > 36), and exp(-a) overflows for very negative a,
    # so p becomes 1.0 / 0.0 and this assert fails for such inputs.
    assert 0 < p < 1
    return p


# export
from fastcore.basics import store_attr
import shap


class Shap:
    """
    Fit a model for X, y, then explain it using SHAP plots.

    The first `n_samples` rows of `X` serve both as the explainer's background
    data and as the rows whose SHAP values are computed (`self.shap_values`).
    """

    def __init__(self, X, y, model, n_samples=1000):
        # store_attr() saves every argument as an attribute (self.X, self.y, ...).
        store_attr()
        model.fit(X, y)
        self.samples = samples = X.iloc[:n_samples]
        self.explainer = shap.Explainer(model, samples)
        self.shap_values = self.explainer(samples)

    def force_plot(self, n_plots=1):
        """
        Display a SHAP force plot for the first `n_plots` explained samples.
        """
        plot = shap.force_plot(
            self.explainer.expected_value,
            self.shap_values.values[:n_plots],
            self.samples.iloc[:n_plots],
            link="logit",
        )
        return plot

    def waterfall_plot(self, id: int):
        """
        Display the SHAP waterfall plot of sample number `id`.

        Returns whatever `shap.plots.waterfall` returns.
        """
        plot = shap.plots.waterfall(self.shap_values[id])
        return plot

    def bar_plot(self, id: int, figsize=None):
        """
        Display the shap values of sample number `id` in a bar plot,
        ordered by decreasing absolute SHAP value.
        """
        import pandas as pd

        sv = self.shap_values[id]
        df = pd.DataFrame(
            dict(
                shap=sv.values,
                shap_abs=abs(sv.values),
            ),
            index=[f"{v} = {n}" for n, v in zip(self.X.columns, sv.data)],
        )
        plot = df.sort_values("shap_abs", ascending=False).plot.bar(
            y="shap", figsize=figsize
        )
        return plot


X, y = shap.datasets.adult()
print(y)
X

from lightgbm.sklearn import LGBMClassifier

model = LGBMClassifier()
sh = Shap(X, y, model)

show_doc(Shap.waterfall_plot)
sh.waterfall_plot(id=0)

import plotly.io as pio

print(pio.renderers.default)

# export
from fastcore.foundation import patch


def _collapse_tail(df, n_keep, label):
    """
    Keep the first `n_keep` rows of `df` and merge the remaining rows into one
    summary row labelled `label` whose `shap` value is their sum.

    The summary row has no `data` value (NaN); the result gets a fresh
    0..n integer index.
    """
    import pandas as pd

    others = df.iloc[n_keep:].shap.sum()
    summary = pd.DataFrame(
        [
            dict(
                feature_val=label,
                shap=others,
                shap_abs=abs(others),
                shap_positive=(others > 0),
            )
        ]
    )
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    return pd.concat([df.iloc[:n_keep], summary], ignore_index=True)


@patch
def wf_plot(self: Shap, id: int, title="", n_feature=0, n_pos=0, n_neg=0):
    """
    Display shap values in a waterfall plot.

    Params:
    - id: Position of the sample to explain within `self.shap_values`.
    - title: Title of the figure.
    - n_feature: Number of features with the largest absolute SHAP values to
      show in the plot; the remaining ones are merged into one "Others" bar.
      When 0 < n_feature < number of features with a non-zero SHAP value,
      this option is used and `n_pos` / `n_neg` are ignored.
    - n_pos: Number of features with positive SHAP values to show in the plot;
      the remaining positive ones are merged into one "Others+" bar (only when
      at least two features would be merged).
    - n_neg: Number of features with negative SHAP values to show in the plot;
      the remaining negative ones are merged into one "Others-" bar (only when
      at least two features would be merged).

    Features whose SHAP value is exactly 0 are dropped. Returns a plotly
    `go.Figure`.
    """
    import pandas as pd
    import plotly.graph_objects as go

    sv = self.shap_values[id]
    df = pd.DataFrame(
        dict(
            data=sv.data,
            # NOTE(review): the label reads "<value> = <feature name>";
            # confirm this order is intended.
            feature_val=[f"{v} = {n}" for n, v in zip(self.X.columns, sv.data)],
            shap=sv.values,
            shap_abs=abs(sv.values),
            shap_positive=(sv.values > 0),
        ),
        index=self.X.columns,
    )
    df = df[df.shap != 0]
    df = df.sort_values("shap_abs", ascending=False)

    if 0 < n_feature < df.shape[0]:
        df = _collapse_tail(df, n_feature, "Others")
    else:
        df_pos = df[df.shap_positive].sort_values("shap_abs", ascending=False)
        if 0 < n_pos < df_pos.shape[0] - 1:
            df_pos = _collapse_tail(df_pos, n_pos, "Others+")
        df_neg = df[~df.shap_positive].sort_values("shap_abs", ascending=False)
        if 0 < n_neg < df_neg.shape[0] - 1:
            df_neg = _collapse_tail(df_neg, n_neg, "Others-")
        df = pd.concat([df_pos, df_neg])

    # Order bars from the largest positive to the most negative contribution.
    df = df.sort_values("shap", ascending=False)

    fig = go.Figure(
        go.Waterfall(
            orientation="v",
            x=df.feature_val,
            textposition="outside",
            y=df.shap,
            connector={"line": {"color": "green", "width": 1}},
            decreasing={"marker": {"color": "indianred"}},
            increasing={"marker": {"color": "deepskyblue"}},
        )
    )
    fig.update_layout(title=title)
    # y tick labels are hidden; dtick only sets the grid-line spacing.
    fig.update_yaxes(showticklabels=False, dtick=0.5)
    return fig


show_doc(Shap.wf_plot)
sh.wf_plot(3)
sh.wf_plot(3, n_pos=3, n_neg=3, title="Features' impact")
sh.wf_plot(3, n_feature=8, n_pos=3, n_neg=3)

show_doc(Shap.bar_plot)
sh.bar_plot(id=0, figsize=(5, 3))

show_doc(Shap.force_plot)
sh.force_plot()
sh.force_plot(1000)

# hide
from nbdev.export import notebook2script

notebook2script()