!pip install catboost==1.2.8
!pip install shap==0.47.2
!pip install igraph==0.11.8
!pip install shapflex==0.0.2
!pip install causal-learn==0.1.4.1
!pip install pandas==1.5.3
!pip install numpy==1.24.4
Requirement already satisfied: catboost in /usr/local/lib/python3.11/dist-packages (1.2.8) Requirement already satisfied: graphviz in /usr/local/lib/python3.11/dist-packages (from catboost) (0.20.3) Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from catboost) (3.10.0) Requirement already satisfied: numpy<3.0,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from catboost) (1.24.4) Requirement already satisfied: pandas>=0.24 in /usr/local/lib/python3.11/dist-packages (from catboost) (1.5.3) Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from catboost) (1.15.3) Requirement already satisfied: plotly in /usr/local/lib/python3.11/dist-packages (from catboost) (5.24.1) Requirement already satisfied: six in /usr/local/lib/python3.11/dist-packages (from catboost) (1.17.0) Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=0.24->catboost) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=0.24->catboost) (2025.2) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (1.3.2) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (4.58.0) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (1.4.8) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (24.2) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (11.2.1) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->catboost) (3.2.3) 
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.11/dist-packages (from plotly->catboost) (9.1.2) Requirement already satisfied: shap in /usr/local/lib/python3.11/dist-packages (0.47.2) Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from shap) (1.24.4) Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from shap) (1.15.3) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (from shap) (1.6.1) Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from shap) (1.5.3) Requirement already satisfied: tqdm>=4.27.0 in /usr/local/lib/python3.11/dist-packages (from shap) (4.67.1) Requirement already satisfied: packaging>20.9 in /usr/local/lib/python3.11/dist-packages (from shap) (24.2) Requirement already satisfied: slicer==0.0.8 in /usr/local/lib/python3.11/dist-packages (from shap) (0.0.8) Requirement already satisfied: numba>=0.54 in /usr/local/lib/python3.11/dist-packages (from shap) (0.60.0) Requirement already satisfied: cloudpickle in /usr/local/lib/python3.11/dist-packages (from shap) (3.1.1) Requirement already satisfied: typing-extensions in /usr/local/lib/python3.11/dist-packages (from shap) (4.13.2) Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.11/dist-packages (from numba>=0.54->shap) (0.43.0) Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.11/dist-packages (from pandas->shap) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->shap) (2025.2) Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->shap) (1.5.0) Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->shap) (3.6.0) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from 
python-dateutil>=2.8.1->pandas->shap) (1.17.0) Requirement already satisfied: igraph in /usr/local/lib/python3.11/dist-packages (0.11.8) Requirement already satisfied: texttable>=1.6.2 in /usr/local/lib/python3.11/dist-packages (from igraph) (1.7.0) Requirement already satisfied: shapflex in /usr/local/lib/python3.11/dist-packages (0.0.2) Requirement already satisfied: causal-learn in /usr/local/lib/python3.11/dist-packages (0.1.4.1) Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from causal-learn) (1.24.4) Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from causal-learn) (1.15.3) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (from causal-learn) (1.6.1) Requirement already satisfied: graphviz in /usr/local/lib/python3.11/dist-packages (from causal-learn) (0.20.3) Requirement already satisfied: statsmodels in /usr/local/lib/python3.11/dist-packages (from causal-learn) (0.14.4) Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from causal-learn) (1.5.3) Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from causal-learn) (3.10.0) Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from causal-learn) (3.4.2) Requirement already satisfied: pydot in /usr/local/lib/python3.11/dist-packages (from causal-learn) (3.0.4) Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from causal-learn) (4.67.1) Requirement already satisfied: momentchi2 in /usr/local/lib/python3.11/dist-packages (from causal-learn) (0.1.8) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (1.3.2) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in 
/usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (4.58.0) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (1.4.8) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (24.2) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (11.2.1) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (3.2.3) Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->causal-learn) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->causal-learn) (2025.2) Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->causal-learn) (1.5.0) Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn->causal-learn) (3.6.0) Requirement already satisfied: patsy>=0.5.6 in /usr/local/lib/python3.11/dist-packages (from statsmodels->causal-learn) (1.0.1) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib->causal-learn) (1.17.0) Requirement already satisfied: pandas==1.5.3 in /usr/local/lib/python3.11/dist-packages (1.5.3) Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.11/dist-packages (from pandas==1.5.3) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==1.5.3) (2025.2) Requirement already satisfied: numpy>=1.21.0 in /usr/local/lib/python3.11/dist-packages (from pandas==1.5.3) (1.24.4) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.1->pandas==1.5.3) 
(1.17.0) Requirement already satisfied: numpy==1.24.4 in /usr/local/lib/python3.11/dist-packages (1.24.4)
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from catboost import CatBoostClassifier

# Load the prepared dataset. The code below uses .copy(), .columns and .drop on
# it, so it is presumably a pandas DataFrame -- confirm against the pickle.
# NOTE(review): pickle.load can execute arbitrary code; only load pickles you
# produced yourself. The `with` block guarantees the file handle is closed.
with open("df_causal_discovery.p", "rb") as pickle_file:
    df = pickle.load(pickle_file)

data_to_explain = df.copy()
outcome_name = 'greaterThan50k'
# Positional index of the outcome column (currently unused in the visible cells).
outcome_col = data_to_explain.columns.get_loc(outcome_name)

# Features are every column except the outcome (dropped by name, not position).
X, y = data_to_explain.drop(outcome_name, axis=1), data_to_explain[outcome_name].values

# 10k boosting rounds; CatBoost picks the learning rate automatically
# (the training log reports 0.005266). Progress is printed every 100 rounds.
model = CatBoostClassifier(iterations=10000)
model.fit(X, y, verbose=100)
def predict_function(model, data_to_explain, class_index=0):
    """Return one class's predicted probability as a single-column DataFrame.

    Passed to ``shapFlex_plus`` as the prediction callback; the returned column
    is the quantity whose Shapley values are estimated.

    Parameters
    ----------
    model : fitted classifier exposing ``predict_proba`` (here a CatBoostClassifier).
    data_to_explain : feature rows to score.
    class_index : int, default 0
        Column of ``predict_proba`` to return. ``predict_proba`` columns follow
        ``model.classes_``, so the default 0 is the first class. If the target is
        0/1 or False/True, that is the *negative* class and the attributions
        describe P(not greaterThan50k). Pass ``class_index=1`` to explain the
        positive class instead.

    Returns
    -------
    pandas.DataFrame
        One column (labelled ``0``) and one row per input row.
    """
    probabilities = model.predict_proba(data_to_explain)
    # Index with a list so the slice stays 2-D with shape (n_rows, 1).
    return pd.DataFrame(probabilities[:, [class_index]])
Learning rate set to 0.005266 0: learn: 0.6892780 total: 9ms remaining: 1m 30s 100: learn: 0.4915075 total: 739ms remaining: 1m 12s 200: learn: 0.4419287 total: 1.47s remaining: 1m 11s 300: learn: 0.4269281 total: 2.18s remaining: 1m 10s 400: learn: 0.4211700 total: 2.97s remaining: 1m 11s 500: learn: 0.4183080 total: 3.71s remaining: 1m 10s 600: learn: 0.4165378 total: 4.44s remaining: 1m 9s 700: learn: 0.4153002 total: 5.15s remaining: 1m 8s 800: learn: 0.4144293 total: 5.87s remaining: 1m 7s 900: learn: 0.4137281 total: 6.6s remaining: 1m 6s 1000: learn: 0.4131087 total: 7.3s remaining: 1m 5s 1100: learn: 0.4125886 total: 9.61s remaining: 1m 17s 1200: learn: 0.4121477 total: 12.5s remaining: 1m 31s 1300: learn: 0.4117578 total: 14.5s remaining: 1m 37s 1400: learn: 0.4114092 total: 15.3s remaining: 1m 33s 1500: learn: 0.4110725 total: 16s remaining: 1m 30s 1600: learn: 0.4107218 total: 16.7s remaining: 1m 27s 1700: learn: 0.4103984 total: 17.4s remaining: 1m 24s 1800: learn: 0.4100679 total: 18.1s remaining: 1m 22s 1900: learn: 0.4097523 total: 18.8s remaining: 1m 20s 2000: learn: 0.4094076 total: 19.6s remaining: 1m 18s 2100: learn: 0.4090354 total: 20.3s remaining: 1m 16s 2200: learn: 0.4086606 total: 21.8s remaining: 1m 17s 2300: learn: 0.4083136 total: 23.6s remaining: 1m 18s 2400: learn: 0.4079803 total: 24.3s remaining: 1m 16s 2500: learn: 0.4076614 total: 25s remaining: 1m 15s 2600: learn: 0.4073819 total: 26.3s remaining: 1m 14s 2700: learn: 0.4070933 total: 27s remaining: 1m 12s 2800: learn: 0.4068058 total: 27.7s remaining: 1m 11s 2900: learn: 0.4065017 total: 28.4s remaining: 1m 9s 3000: learn: 0.4062223 total: 29.1s remaining: 1m 7s 3100: learn: 0.4059145 total: 29.9s remaining: 1m 6s 3200: learn: 0.4056258 total: 30.6s remaining: 1m 4s 3300: learn: 0.4053371 total: 31.3s remaining: 1m 3s 3400: learn: 0.4050275 total: 32s remaining: 1m 2s 3500: learn: 0.4047438 total: 32.7s remaining: 1m 3600: learn: 0.4044669 total: 33.5s remaining: 59.6s 3700: 
learn: 0.4042065 total: 35.3s remaining: 1m 3800: learn: 0.4039553 total: 36.6s remaining: 59.8s 3900: learn: 0.4037135 total: 37.4s remaining: 58.5s 4000: learn: 0.4034771 total: 38.8s remaining: 58.2s 4100: learn: 0.4032248 total: 40.9s remaining: 58.9s 4200: learn: 0.4029768 total: 41.6s remaining: 57.5s 4300: learn: 0.4027417 total: 42.4s remaining: 56.1s 4400: learn: 0.4024949 total: 43.1s remaining: 54.8s 4500: learn: 0.4022685 total: 43.8s remaining: 53.5s 4600: learn: 0.4020343 total: 44.5s remaining: 52.3s 4700: learn: 0.4018102 total: 45.3s remaining: 51s 4800: learn: 0.4015990 total: 46s remaining: 49.8s 4900: learn: 0.4013678 total: 47.7s remaining: 49.6s 5000: learn: 0.4011609 total: 49.3s remaining: 49.3s 5100: learn: 0.4009639 total: 50s remaining: 48s 5200: learn: 0.4007454 total: 50.7s remaining: 46.8s 5300: learn: 0.4005560 total: 51.5s remaining: 45.6s 5400: learn: 0.4003578 total: 52.2s remaining: 44.5s 5500: learn: 0.4001564 total: 53s remaining: 43.3s 5600: learn: 0.3999603 total: 53.7s remaining: 42.2s 5700: learn: 0.3997656 total: 54.5s remaining: 41.1s 5800: learn: 0.3995780 total: 55.2s remaining: 39.9s 5900: learn: 0.3994030 total: 55.9s remaining: 38.8s 6000: learn: 0.3992027 total: 56.7s remaining: 37.8s 6100: learn: 0.3990184 total: 57.5s remaining: 36.7s 6200: learn: 0.3988182 total: 58.2s remaining: 35.6s 6300: learn: 0.3986289 total: 59s remaining: 34.6s 6400: learn: 0.3984530 total: 1m 1s remaining: 34.7s 6500: learn: 0.3982771 total: 1m 2s remaining: 33.6s 6600: learn: 0.3981018 total: 1m 3s remaining: 32.6s 6700: learn: 0.3979345 total: 1m 3s remaining: 31.5s 6800: learn: 0.3977810 total: 1m 4s remaining: 30.4s 6900: learn: 0.3976136 total: 1m 5s remaining: 29.4s 7000: learn: 0.3974574 total: 1m 6s remaining: 28.3s 7100: learn: 0.3972888 total: 1m 6s remaining: 27.3s 7200: learn: 0.3971263 total: 1m 7s remaining: 26.3s 7300: learn: 0.3969797 total: 1m 8s remaining: 25.3s 7400: learn: 0.3968290 total: 1m 9s remaining: 24.3s 7500: 
learn: 0.3966850 total: 1m 9s remaining: 23.3s 7600: learn: 0.3965330 total: 1m 10s remaining: 22.3s 7700: learn: 0.3963747 total: 1m 11s remaining: 21.3s 7800: learn: 0.3962246 total: 1m 12s remaining: 20.6s 7900: learn: 0.3960729 total: 1m 14s remaining: 19.8s 8000: learn: 0.3959197 total: 1m 15s remaining: 18.8s 8100: learn: 0.3957832 total: 1m 16s remaining: 17.8s 8200: learn: 0.3956439 total: 1m 16s remaining: 16.8s 8300: learn: 0.3955016 total: 1m 17s remaining: 15.9s 8400: learn: 0.3953539 total: 1m 18s remaining: 14.9s 8500: learn: 0.3952175 total: 1m 18s remaining: 13.9s 8600: learn: 0.3950890 total: 1m 19s remaining: 13s 8700: learn: 0.3949550 total: 1m 20s remaining: 12s 8800: learn: 0.3948148 total: 1m 21s remaining: 11.1s 8900: learn: 0.3946763 total: 1m 21s remaining: 10.1s 9000: learn: 0.3945449 total: 1m 22s remaining: 9.18s 9100: learn: 0.3944083 total: 1m 23s remaining: 8.24s 9200: learn: 0.3942704 total: 1m 24s remaining: 7.31s 9300: learn: 0.3941349 total: 1m 25s remaining: 6.44s 9400: learn: 0.3939961 total: 1m 27s remaining: 5.57s 9500: learn: 0.3938612 total: 1m 28s remaining: 4.63s 9600: learn: 0.3937271 total: 1m 28s remaining: 3.69s 9700: learn: 0.3936023 total: 1m 29s remaining: 2.76s 9800: learn: 0.3934668 total: 1m 30s remaining: 1.83s 9900: learn: 0.3933373 total: 1m 31s remaining: 910ms 9999: learn: 0.3932004 total: 1m 31s remaining: 0us
from shapflex.shapflex import shapFlex_plus

# Explain the first 300 rows using exactly the feature columns the model was
# trained on. X drops the outcome by name, so this no longer relies on the
# outcome being the last column of data_to_explain.
explain, reference = X.iloc[:300], X
# NOTE(review): `reference` is not passed to shapFlex_plus below, so it is
# currently unused -- check whether shapFlex_plus accepts a reference/background set.

exmpl_of_test = shapFlex_plus(
    explain,
    model,
    predict_function,
    target_features=pd.Series([
        "age", "inRelationship",
        "hours-per-week", "hasGraduateDegree",
        "isFemale", "isWhite",
    ]),
)
result = exmpl_of_test.forward()
Reshape the shapFlex output into an `Explanation` object that SHAP's beeswarm plot can handle.
import shap
from shap import Explanation  # public alias of the private shap._explanation class

# shapFlex returns a long table with one row per (explained instance, feature).
feature_names = np.asarray(result['feature_name'].unique())
n_features = len(feature_names)

# The reshape assumes rows are grouped instance by instance, with the features
# in the same order each time. Fail loudly instead of plotting scrambled values.
feature_grid = result['feature_name'].to_numpy().reshape(-1, n_features)
if not (feature_grid == feature_names).all():
    raise ValueError("shapFlex result is not ordered instance-by-instance")

values = result['shap_effect'].to_numpy().reshape(-1, n_features)
if values.shape[0] != explain.shape[0]:
    raise ValueError("number of explained rows does not match shapFlex output")

base_values = np.full(explain.shape[0], result['shap_effect intercept'].iloc[0])
# Order the raw feature values like the SHAP columns, so beeswarm colours each
# feature's points by that feature's own values.
data = explain[list(feature_names)].to_numpy()

shap_values_shapflex = Explanation(
    values, base_values=base_values, data=data, feature_names=feature_names
)
shap.plots.beeswarm(shap_values=shap_values_shapflex)
Step 1: Run FCI without prior knowledge.
from causallearn.search.ConstraintBased.FCI import fci

# Step 1: run FCI on the raw data matrix with no background knowledge.
# The result is indexed later as G[0] (graph, used for its nodes).
data_matrix = df.to_numpy()
G = fci(data_matrix)
0%| | 0/7 [00:00<?, ?it/s]
X1 --> X4 X1 --> X7 X2 --> X4 X7 --> X4 X7 --> X5
Step 2: Add background knowledge: forbid some edges and require one, then rerun FCI.
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

# Nodes come from the unconstrained FCI graph; nodes[k] is variable X{k+1},
# i.e. the k-th column of df.
nodes = G[0].get_nodes()

# (cause_index, effect_index) pairs whose directed edge FCI must not produce,
# applied in the same order as before.
forbidden_pairs = [
    (2, 0), (1, 2), (4, 5), (5, 4), (5, 0), (0, 5),
    (3, 4), (6, 4), (3, 5), (2, 4), (1, 4),
]
bc = BackgroundKnowledge()
for cause_idx, effect_idx in forbidden_pairs:
    bc.add_forbidden_by_node(nodes[cause_idx], nodes[effect_idx])
# The single edge that must be present: X2 -> X6.
bc.add_required_by_node(nodes[1], nodes[5])

# Rerun FCI with the background knowledge; G2[1] holds the resulting edges.
G2 = fci(df.to_numpy(), background_knowledge=bc)
0%| | 0/7 [00:00<?, ?it/s]
Starting BK Orientation. Orienting edge (Knowledge): X1 --> X3 Orienting edge (Knowledge): X3 --> X2 Orienting edge (Knowledge): X2 --> X6 Orienting edge (Knowledge): X5 --> X3 Orienting edge (Knowledge): X5 --> X4 Orienting edge (Knowledge): X6 --> X4 Orienting edge (Knowledge): X5 --> X7 Finishing BK Orientation. Starting BK Orientation. Orienting edge (Knowledge): X1 --> X3 Orienting edge (Knowledge): X3 --> X2 Orienting edge (Knowledge): X2 --> X6 Orienting edge (Knowledge): X5 --> X3 Orienting edge (Knowledge): X5 --> X4 Orienting edge (Knowledge): X6 --> X4 Orienting edge (Knowledge): X5 --> X7 Finishing BK Orientation. X3 --> X2 X2 --> X4 X2 --> X6 X2 --> X7 X3 --> X7 X6 --> X4 X4 --> X7
# Turn the FCI edge list into a cause/effect table keyed by real column names.
nodes = [
    [edge.get_node1().get_name(), edge.get_node2().get_name()]
    for edge in G2[1]
]
# causal-learn labels variables X1..Xn in df's column order; derive n from df
# instead of hard-coding 7 columns.
names = {f'X{i}': column for i, column in enumerate(df.columns.values, start=1)}

# Explicit column labels keep this working when FCI returns no edges at all.
edge_labels = pd.DataFrame(nodes, columns=['cause', 'effect'])
# NOTE(review): every edge is read as node1 -> node2. FCI returns a PAG, so some
# edges may be o->, o-o or <-> rather than directed; confirm they should all be
# treated as causal here.
causal = pd.DataFrame()
causal['cause'] = edge_labels['cause'].map(lambda label: names[label])
causal['effect'] = edge_labels['effect'].map(lambda label: names[label])
We don't want the outcome (target) node in our graph. Also, although we forbade the edge 'hours-per-week' $\rightarrow$ 'isFemale', it still appears in the graph, so we delete it manually.
# Remove every edge that touches the outcome: the explained features must not
# have the target as a cause or an effect.
touches_outcome = causal[['cause', 'effect']].eq(outcome_name).any(axis=1)
# FCI still returned 'hours-per-week' -> 'isFemale' despite the background
# knowledge. Remove it by value: the previous positional drop(index[6]) removed
# a different row and left this edge in the table.
unwanted_edge = (causal['cause'] == 'hours-per-week') & (causal['effect'] == 'isFemale')
causal_without_objective = causal[~(touches_outcome | unwanted_edge)].reset_index(drop=True)
import networkx as nx

# Plot the pruned causal graph (outcome removed) as a directed graph,
# then show the first few edges as a table.
fig = plt.figure(figsize=(15, 13))
causal_graph = nx.from_pandas_edgelist(
    causal_without_objective,
    source='cause',
    target='effect',
    create_using=nx.DiGraph,
)
nx.draw_networkx(causal_graph, font_size=18, font_color='r', arrowsize=30)
print(causal_without_objective.head())
cause effect 0 age hasGraduateDegree 1 age inRelationship 2 hasGraduateDegree hours-per-week 3 hours-per-week inRelationship 4 hours-per-week isFemale
# Asymmetric (causal) Shapley values: pass the pruned edge table itself,
# with one weight per edge. Previously `.columns` was passed, which is only the
# Index ['cause', 'effect'] and carries no edges.
exmpl_of_test = shapFlex_plus(
    explain,
    model,
    predict_function,
    target_features=pd.Series([
        'age', 'hours-per-week', 'hasGraduateDegree', 'inRelationship',
        'isWhite', 'isFemale',
    ]),
    causal=causal_without_objective,
    causal_weights=[1.0] * causal_without_objective.shape[0],
)
result = exmpl_of_test.forward()
import shap
from shap import Explanation  # public alias of the private shap._explanation class

# shapFlex returns a long table with one row per (explained instance, feature).
feature_names = np.asarray(result['feature_name'].unique())
n_features = len(feature_names)

# The reshape assumes rows are grouped instance by instance, with the features
# in the same order each time. Fail loudly instead of plotting scrambled values.
feature_grid = result['feature_name'].to_numpy().reshape(-1, n_features)
if not (feature_grid == feature_names).all():
    raise ValueError("shapFlex result is not ordered instance-by-instance")

values = result['shap_effect'].to_numpy().reshape(-1, n_features)
if values.shape[0] != explain.shape[0]:
    raise ValueError("number of explained rows does not match shapFlex output")

base_values = np.full(explain.shape[0], result['shap_effect intercept'].iloc[0])
# Order the raw feature values like the SHAP columns (the target_features order
# here differs from explain's column order), so colours match their feature.
data = explain[list(feature_names)].to_numpy()

shap_values_shapflex_2 = Explanation(
    values, base_values=base_values, data=data, feature_names=feature_names
)
shap.plots.beeswarm(shap_values=shap_values_shapflex_2)