# export
import numpy as np

def logit(p):
    """
    Compute the log-odds of a probability ``p``.

    Returns ``log(p / (1 - p))``. ``p`` may be a scalar or an array-like;
    for array input the result is computed element-wise.

    Every value of ``p`` must lie strictly between 0 and 1, otherwise an
    ``AssertionError`` is raised (the check is an ``assert`` statement, so
    it is skipped when Python runs with ``-O``).
    """
    p = np.asarray(p, dtype=float)
    # Element-wise domain check; a NaN compares False and is rejected too.
    assert np.all((p > 0) & (p < 1)), "p must lie strictly between 0 and 1"
    # Parenthesise the denominator: `p / 1 - p` would evaluate to 0.
    odds = p / (1 - p)
    # Indexing with () unwraps a 0-d result to a NumPy scalar and leaves
    # higher-dimensional arrays untouched.
    return np.log(odds)[()]

def expit(a):
    """
    The reverse of logit: map log-odds ``a`` to a probability.

    Computes ``1 / (1 + exp(-a))`` for a scalar or array-like ``a``
    (element-wise for arrays).

    The mathematical result lies strictly inside (0, 1), but in floating
    point it saturates to exactly 0.0 or 1.0 for large ``|a|``. Those
    saturated values are returned as-is rather than rejected, so the
    function is defined for every finite input.
    """
    a = np.asarray(a, dtype=float)
    # exp(-|a|) lies in (0, 1], so neither branch below can overflow.
    z = np.exp(-np.abs(a))
    # For a >= 0 use 1 / (1 + e^-a); for a < 0 use the algebraically equal
    # e^a / (1 + e^a), where e^a == z.
    p = np.where(a >= 0, 1 / (1 + z), z / (1 + z))
    # Indexing with () unwraps a 0-d result to a NumPy scalar and leaves
    # higher-dimensional arrays untouched.
    return p[()]

# export
from fastcore.basics import store_attr
import shap

class Shap:
    """
    Fit `model` on (X, y) and expose SHAP plots that explain it.

    `X` is sliced with `.iloc` and its `.columns` are used for labels, so it
    is expected to behave like a pandas DataFrame.
    """

    def __init__(self, X, y, model, n_samples=1000):
        # store_attr() saves the constructor arguments as attributes of self;
        # it must run before anything else touches the locals.
        store_attr()
        model.fit(X, y)
        # The first `n_samples` rows are used both as the data handed to the
        # explainer and as the rows that get explained.
        head = X.iloc[:n_samples]
        self.samples = head
        self.explainer = shap.Explainer(model, head)
        self.shap_values = self.explainer(head)

    def force_plot(self, n_plots=1):
        """
        Build a SHAP force plot for the first `n_plots` stored samples.

        Uses the explainer's `expected_value` as base value and a logit link.
        """
        first = slice(None, n_plots)
        return shap.force_plot(
            self.explainer.expected_value,
            self.shap_values.values[first],
            self.samples.iloc[first],
            link="logit",
        )

    def waterfall_plot(self, id: int):
        """
        Build the SHAP waterfall plot for the stored sample at position `id`.
        """
        return shap.plots.waterfall(self.shap_values[id])

    def bar_plot(self, id: int, figsize=None):
        """
        Bar-plot the SHAP values of the stored sample at position `id`.

        Bars are ordered by decreasing absolute SHAP value and labelled
        "<feature value> = <column name>".
        """
        import pandas as pd

        explanation = self.shap_values[id]
        contributions = explanation.values
        labels = [
            f"{v} = {n}" for n, v in zip(self.X.columns, explanation.data)
        ]
        frame = pd.DataFrame(
            {"shap": contributions, "shap_abs": abs(contributions)},
            index=labels,
        )
        ranked = frame.sort_values("shap_abs", ascending=False)
        return ranked.plot.bar(y="shap", figsize=figsize)

# Demo: load the dataset returned by shap.datasets.adult() as features X and
# labels y, then inspect both.
X, y = shap.datasets.adult()
print(y)
# Bare expression with no effect when run as a script; presumably kept from
# a notebook cell, where it would display X -- TODO confirm.
X

from lightgbm.sklearn import LGBMClassifier

# Constructing Shap fits the classifier on (X, y) and computes SHAP values
# for the first n_samples (default 1000) rows of X.
model = LGBMClassifier()
sh = Shap(X, y, model)

# NOTE(review): show_doc is not defined or imported in the visible part of
# this file -- presumably supplied by the notebook tooling; confirm.
show_doc(Shap.waterfall_plot)

# Waterfall plot for the first stored sample.
sh.waterfall_plot(id=0)

import plotly.io as pio

# Print the name of plotly's current default renderer.
print(pio.renderers.default)

# export
from fastcore.foundation import patch

def _fold_tail(df, keep, label):
    """
    Keep the first `keep` rows of `df` and fold the rest into one summary row.

    The summary row is labelled `label` in the `feature_val` column and
    carries the sum of the folded `shap` values; columns it does not set
    (such as `data`) are left as missing values. The result has a fresh
    0..n-1 index.
    """
    import pandas as pd

    rest = df.iloc[keep:].shap.sum()
    summary = pd.DataFrame(
        [
            dict(
                feature_val=label,
                shap=rest,
                shap_abs=abs(rest),
                shap_positive=(rest > 0),
            )
        ]
    )
    # DataFrame.append was removed in pandas 2.0; pd.concat is the
    # supported way to add rows.
    return pd.concat([df.iloc[:keep], summary], ignore_index=True)


@patch
def wf_plot(self: Shap, id: int, title="", n_feature=0, n_pos=0, n_neg=0):
    """
    Display the SHAP values of one stored sample in a plotly waterfall plot.

    Features whose SHAP value is exactly 0 are dropped. Bars are labelled
    "<feature value> = <column name>" and ordered by decreasing SHAP value.

    Params:
    - id: Position of the sample within the stored SHAP values.
    - title: Title of the figure.
    - n_feature: Number of features with the largest absolute SHAP values to
      show; the remaining ones are summed into a single "Others" bar. Takes
      effect only when 0 < n_feature < number of non-zero features, and then
      `n_pos` / `n_neg` are ignored.
    - n_pos: Number of features with positive SHAP values to show; the
      remaining positive ones are summed into "Others+". Applied only when
      at least two positive features would be folded.
    - n_neg: Number of features with negative SHAP values to show; the
      remaining negative ones are summed into "Others-". Applied only when
      at least two negative features would be folded.

    Returns the plotly `Figure`.
    """
    import pandas as pd
    import plotly.graph_objects as go

    row = self.shap_values[id]
    df = pd.DataFrame(
        dict(
            data=row.data,
            feature_val=[f"{v} = {n}" for n, v in zip(self.X.columns, row.data)],
            shap=row.values,
            shap_abs=abs(row.values),
            shap_positive=(row.values > 0),
        ),
        index=self.X.columns,
    )
    # Zero contributions would only add empty bars.
    df = df[df.shap != 0]
    df = df.sort_values("shap_abs", ascending=False)

    if 0 < n_feature < df.shape[0]:
        df = _fold_tail(df, n_feature, "Others")
    else:
        df_pos = df[df.shap_positive].sort_values("shap_abs", ascending=False)
        # The "- 1" skips folding when only a single feature would end up
        # in the summary bar.
        if 0 < n_pos < df_pos.shape[0] - 1:
            df_pos = _fold_tail(df_pos, n_pos, "Others+")
        df_neg = df[~df.shap_positive].sort_values("shap_abs", ascending=False)
        if 0 < n_neg < df_neg.shape[0] - 1:
            df_neg = _fold_tail(df_neg, n_neg, "Others-")
        df = pd.concat([df_pos, df_neg])

    # Largest positive contribution first, most negative last.
    df = df.sort_values("shap", ascending=False)
    fig = go.Figure(
        go.Waterfall(
            x=df.feature_val,
            textposition="outside",
            y=df.shap,
            connector={"line": {"color": "green", "width": 1}},
            decreasing={"marker": {"color": "indianred"}},
            increasing={"marker": {"color": "deepskyblue"}},
        )
    )
    fig.update_layout(title=title)
    return fig

# NOTE(review): show_doc is not defined or imported in the visible part of
# this file -- presumably supplied by the notebook tooling; confirm.
show_doc(Shap.wf_plot)

# Waterfall of every non-zero SHAP feature for the stored sample at
# position 3.
sh.wf_plot(3)

# Limit the positive and the negative side to 3 features each; when enough
# features remain they are folded into "Others+" / "Others-".
sh.wf_plot(3, n_pos=3, n_neg=3, title="Features' impact")

# When n_feature is positive and smaller than the number of non-zero
# features, it takes precedence and n_pos / n_neg are ignored.
sh.wf_plot(3, n_feature=8, n_pos=3, n_neg=3)

show_doc(Shap.bar_plot)

# Bar plot of the SHAP values of the first stored sample.
sh.bar_plot(id=0, figsize=(5,3))

show_doc(Shap.force_plot)

# Force plot for the first stored sample (n_plots defaults to 1).
sh.force_plot()

# Force plot over the first 1000 stored samples.
sh.force_plot(1000)