import sys

# CausalML is not supported on Python 2.7; exit quietly so the example is skipped there.
if sys.version_info < (3, 0):
    sys.exit(0)
from causalml.inference.tree import UpliftRandomForestClassifier
from causalml.inference.tree import uplift_tree_plot
from causalml.metrics import plot_gain, plot_qini, plot_lift
from causalml.metrics import auuc_score
import h2o
from h2o.estimators.uplift_random_forest import H2OUpliftRandomForestEstimator
import pandas as pd
from IPython.display import Image
# Column roles in the uplift datasets.
control_name = "control"
treatment_column = "treatment"
response_column = "outcome"
feature_cols = ["feature_" + str(x) for x in range(1, 13)]  # feature_1 .. feature_12

train_df = pd.read_csv("../../smalldata/uplift/upliftml_train.csv")
test_df = pd.read_csv("../../smalldata/uplift/upliftml_test.csv")

# Turn the 0/1 treatment indicator into the group labels used below.
# Assign the result back rather than calling replace(inplace=True) on a column
# selection: that chained in-place call is not guaranteed to modify the parent
# DataFrame (and has no effect at all under pandas Copy-on-Write).
treatment_labels = {1: "treatment", 0: "control"}
train_df[treatment_column] = train_df[treatment_column].replace(treatment_labels)
test_df[treatment_column] = test_df[treatment_column].replace(treatment_labels)

# A bare expression only displays inside a notebook cell; print it explicitly.
print((train_df.shape, test_df.shape))
((4989, 17), (5011, 17))
# Hyper-parameters shared by the CausalML and H2O forests so the two are comparable.
ntree = 40
max_depth = 10
# Euclidean-distance split criterion, spelled the way each library expects:
# CausalML calls it "ED" ("EU" is not one of its documented evaluation functions),
# H2O calls it "Euclidean".
metric_cml = "ED"
metric_h2o = "Euclidean"

# Train the CausalML uplift random forest on the feature columns only.
causalml_uplift_model = UpliftRandomForestClassifier(
    n_estimators=ntree,
    max_depth=max_depth,
    evaluationFunction=metric_cml,
    control_name=control_name,
    min_samples_leaf=10,
    min_samples_treatment=0,
    normalization=False,
    random_state=42,
)
causalml_uplift_model.fit(
    train_df[feature_cols].values,
    treatment=train_df[treatment_column].values,
    y=train_df[response_column].values,
)

# Predict with exactly the feature columns (and column order) used in fit();
# test_df.values would also hand the treatment/outcome columns to the model.
causalml_preds = causalml_uplift_model.predict(test_df[feature_cols].values)

# Connect to (or start) the local H2O cluster. The strict version check is
# disabled because this example was run against an H2O developer build.
h2o.init(strict_version_check=False)
Checking whether there is an H2O instance running at http://localhost:54321 . connected. Warning: Version mismatch. H2O is version 3.37.0.99999, but the h2o-python package is version 0.0.local. This is a developer build, please contact your developer.
H2O_cluster_uptime: | 6 mins 57 secs |
H2O_cluster_timezone: | Europe/Berlin |
H2O_data_parsing_timezone: | UTC |
H2O_cluster_version: | 3.37.0.99999 |
H2O_cluster_version_age: | 6 days |
H2O_cluster_name: | mori |
H2O_cluster_total_nodes: | 1 |
H2O_cluster_free_memory: | 4.877 Gb |
H2O_cluster_total_cores: | 8 |
H2O_cluster_allowed_cores: | 8 |
H2O_cluster_status: | locked, healthy |
H2O_connection_url: | http://localhost:54321 |
H2O_connection_proxy: | {"http": null, "https": null} |
H2O_internal_security: | False |
Python_version: | 3.7.3 candidate |
# Upload the training data to H2O and convert treatment and outcome to
# categorical (enum) columns for the uplift estimator.
train_h2o = h2o.H2OFrame(train_df)
train_h2o[treatment_column] = train_h2o[treatment_column].asfactor()
train_h2o[response_column] = train_h2o[response_column].asfactor()

h2o_uplift_model = H2OUpliftRandomForestEstimator(
    ntrees=ntree,
    # NOTE(review): one less than CausalML's max_depth, presumably because the
    # two libraries count tree depth differently -- confirm.
    max_depth=max_depth - 1,
    treatment_column=treatment_column,
    uplift_metric=metric_h2o,
    min_rows=10,
    auuc_nbins=100,
    seed=42,
    sample_rate=0.50,
    auuc_type="gain",
    score_each_iteration=False,
)
h2o_uplift_model.train(y=response_column, x=feature_cols, training_frame=train_h2o)

# Print the model summary and training metrics; a bare expression would only
# display them inside a notebook cell.
h2o_uplift_model.show()
Parse progress: |████████████████████████████████████████████████████████████████| (done) 100% upliftdrf Model Build progress: |████████████████████████████████████████████████| (done) 100% Model Details ============= H2OUpliftRandomForestEstimator : Uplift Distributed Random Forest Model Key: UpliftDRF_model_python_1649756637843_62 Model Summary:
number_of_trees | number_of_internal_trees | model_size_in_bytes | min_depth | max_depth | mean_depth | min_leaves | max_leaves | mean_leaves | ||
---|---|---|---|---|---|---|---|---|---|---|
0 | 40.0 | 80.0 | 24305.0 | 9.0 | 9.0 | 9.0 | 11.0 | 29.0 | 19.625 |
ModelMetricsBinomialUplift: upliftdrf ** Reported on train data. ** AUUC: 438.91621787270583 AUUC normalized: 1.0243607039739482 AUUC table (number of bins: 100): All types of AUUC value
uplift_type | qini | lift | gain | |
---|---|---|---|---|
0 | AUUC value | 220.031802 | 0.219001 | 438.916218 |
1 | AUUC normalized | 1.021918 | 0.219001 | 1.024361 |
2 | AUUC random value | 108.711176 | 0.043363 | 216.338275 |
Qini value: 111.3206265315211 AECU values table: All types of AECU value
uplift_type | qini | lift | gain | |
---|---|---|---|---|
0 | AECU value | 111.320627 | 0.175638 | 222.577943 |
Scoring History:
timestamp | duration | number_of_trees | training_auuc_nbins | training_auuc | training_auuc_normalized | training_qini_value | ||
---|---|---|---|---|---|---|---|---|
0 | 2022-04-12 11:50:56 | 0.044 sec | 0.0 | 0 | NaN | NaN | NaN | |
1 | 2022-04-12 11:50:56 | 0.071 sec | 1.0 | 15 | 366.755916 | 0.855950 | 43.835470 | |
2 | 2022-04-12 11:50:56 | 0.086 sec | 2.0 | 40 | 395.012929 | 0.921897 | 80.325497 | |
3 | 2022-04-12 11:50:56 | 0.101 sec | 3.0 | 59 | 408.367429 | 0.953065 | 93.433475 | |
4 | 2022-04-12 11:50:56 | 0.127 sec | 4.0 | 78 | 389.270971 | 0.908497 | 84.215633 | |
5 | 2022-04-12 11:50:56 | 0.146 sec | 5.0 | 94 | 400.061840 | 0.933681 | 90.757261 | |
6 | 2022-04-12 11:50:56 | 0.177 sec | 6.0 | 99 | 418.919837 | 0.977692 | 101.321117 | |
7 | 2022-04-12 11:50:56 | 0.216 sec | 7.0 | 100 | 421.546681 | 0.983823 | 102.566550 | |
8 | 2022-04-12 11:50:56 | 0.267 sec | 8.0 | 100 | 423.997945 | 0.989544 | 103.645982 | |
9 | 2022-04-12 11:50:56 | 0.306 sec | 9.0 | 100 | 427.692317 | 0.998166 | 105.195268 | |
10 | 2022-04-12 11:50:56 | 0.337 sec | 10.0 | 100 | 426.656112 | 0.995748 | 104.029365 | |
11 | 2022-04-12 11:50:56 | 0.362 sec | 11.0 | 100 | 426.574255 | 0.995557 | 104.311611 | |
12 | 2022-04-12 11:50:56 | 0.382 sec | 12.0 | 100 | 435.256716 | 1.015820 | 108.863536 | |
13 | 2022-04-12 11:50:56 | 0.417 sec | 13.0 | 100 | 436.454717 | 1.018616 | 109.233144 | |
14 | 2022-04-12 11:50:56 | 0.440 sec | 14.0 | 100 | 437.222194 | 1.020407 | 109.137147 | |
15 | 2022-04-12 11:50:56 | 0.459 sec | 15.0 | 100 | 435.702131 | 1.016860 | 108.977371 | |
16 | 2022-04-12 11:50:56 | 0.479 sec | 16.0 | 100 | 435.412446 | 1.016183 | 109.004228 | |
17 | 2022-04-12 11:50:56 | 0.501 sec | 17.0 | 100 | 439.075091 | 1.024731 | 111.027491 | |
18 | 2022-04-12 11:50:56 | 0.530 sec | 18.0 | 100 | 439.383070 | 1.025450 | 111.291618 | |
19 | 2022-04-12 11:50:56 | 0.554 sec | 19.0 | 100 | 440.996662 | 1.029216 | 111.675966 |
See the whole table with table.as_data_frame()
# Plot training AUUC from the scoring history against the number of trees
# built at each scoring event. Use the explicit number_of_trees column rather
# than the row index, because scoring events need not happen after every tree.
sh = h2o_uplift_model.scoring_history()
import matplotlib.pyplot as plt

plt.plot(sh["number_of_trees"], sh["training_auuc"])
plt.xlabel("Number of trees")
plt.ylabel("AUUC")
plt.title("Scoring history")
plt.show()
# Upload the held-out test data with the same categorical conversions as training.
test_h2o = h2o.H2OFrame(test_df)
test_h2o[treatment_column] = test_h2o[treatment_column].asfactor()
test_h2o[response_column] = test_h2o[response_column].asfactor()

preds_h2o = h2o_uplift_model.predict(test_h2o)

# These metrics are computed on test_h2o, so they are test (not training) metrics.
perf_h2o = h2o_uplift_model.model_performance(test_h2o)
auuc_h2o = perf_h2o.auuc()
print("H2O test metrics AUUC Gain: " + str(auuc_h2o))
Parse progress: |████████████████████████████████████████████████████████████████| (done) 100% upliftdrf prediction progress: |█████████████████████████████████████████████████| (done) 100% H2O training metrics AUUC Gain: 450.7117786374067
# Put H2O's and CausalML's uplift predictions for the test rows side by side,
# together with their absolute difference and the observed treatment/outcome.
preds_comp = preds_h2o["uplift_predict"]
preds_comp.names = ["h2o"]
preds_comp["causal"] = h2o.H2OFrame(causalml_preds)
preds_comp["diff"] = abs(preds_comp["h2o"] - preds_comp["causal"])
preds_comp[treatment_column] = h2o.H2OFrame(test_df[treatment_column].values)
preds_comp[response_column] = h2o.H2OFrame(test_df[response_column].values)
preds_comp.summary()

# min()/max() return nan if any value is missing, while mean() skips missing
# values by default; drop NA rows first so all three describe the same rows.
valid_diff = preds_comp["diff"].na_omit()
min_diff = valid_diff.min()
max_diff = valid_diff.max()
mean_diff = valid_diff.mean(return_frame=False)[0]
print("min: %f max: %f mean: %f" % (min_diff, max_diff, mean_diff))

# Pandas copy for the CausalML metric/plot helpers, with the treatment labels
# mapped back to a 0/1 indicator.
results = preds_comp.as_data_frame()
results = results[["h2o", "causal", response_column, treatment_column]]
mapping = {"control": 0, "treatment": 1}
results = results.replace({treatment_column: mapping})
Parse progress: |████████████████████████████████████████████████████████████████| (done) 100% Parse progress: |████████████████████████████████████████████████████████████████| (done) 100% Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%
h2o | causal | diff | treatment | outcome | |
---|---|---|---|---|---|
type | real | real | real | enum | int |
mins | -0.6264740526676178 | -0.390655875 | 0.004659703913163948 | 0.0 | |
mean | 0.08199791085882721 | -0.29422592594120467 | 0.3900188090139431 | 0.7579325483935342 | |
maxs | 0.4822891443967819 | -0.03872682500000013 | 0.7961030723154424 | 1.0 | |
sigma | 0.13776629982556876 | 0.0500373229456983 | 0.09240175231909001 | 0.4283776621130897 | |
zeros | 0 | 0 | 0 | 1213 | |
missing | 0 | 708 | 708 | 0 | 0 |
0 | -0.09308872148394576 | -0.3457214499999999 | 0.2526327285160541 | treatment | 0.0 |
1 | 0.056703162193298295 | -0.34273535 | 0.3994385121932983 | treatment | 1.0 |
2 | 0.07776566892862324 | -0.34762824999999997 | 0.4253939189286232 | control | 0.0 |
3 | 0.11391159445047383 | -0.2860978000000002 | 0.400009394450474 | control | 1.0 |
4 | 0.0736839056015014 | -0.30289307499999996 | 0.37657698060150135 | treatment | 0.0 |
5 | 0.04312786720693107 | -0.18695925 | 0.23008711720693106 | control | 1.0 |
6 | 0.09887885749340064 | -0.309560175 | 0.40843903249340063 | control | 1.0 |
7 | 0.11090063825249663 | nan | nan | control | 1.0 |
8 | 0.11205950677394871 | -0.32656569999999974 | 0.43862520677394845 | control | 0.0 |
9 | -0.043879247084260053 | -0.281969575 | 0.23809032791573992 | treatment | 1.0 |
min: nan max: nan mean: 0.390019
# Un-normalized AUUC of each prediction column plus CausalML's random baseline.
auuc = auuc_score(results, outcome_col=response_column, treatment_col=treatment_column, normalize=False)
for label, column in (("H2O", "h2o"), ("CausalML", "causal"), ("Random", "Random")):
    print(label + " AUUC:")
    print(auuc[column])
H2O AUUC: 449.1073405919903 CauslML AUUC: 258.7092568544478 Random AUUC: 234.98047847365325
# Normalized AUUC of each prediction column plus CausalML's random baseline.
auuc = auuc_score(results, outcome_col=response_column, treatment_col=treatment_column, normalize=True)
for label, column in (("H2O", "h2o"), ("CausalML", "causal"), ("Random", "Random")):
    print(label + " AUUC:")
    print(auuc[column])
H2O AUUC: 0.9232925077480227 CauslML AUUC: 0.5318646500498337 Random AUUC: 0.4830820955983254
# Qini curves: CausalML's (raw, then normalized) followed by H2O's own plot.
plot_qini(results, treatment_col=treatment_column, outcome_col=response_column)
plot_qini(results, treatment_col=treatment_column, outcome_col=response_column, normalize=True)
perf_h2o.plot_uplift(metric="qini")

# Fetch H2O's Qini curve as data instead of a figure and print its length,
# x values (n) and y values (uplift).
n, uplift = perf_h2o.plot_uplift(plot=False, metric="qini")
for shown in (len(n), n, uplift):
    print(shown)
98 [51, 101, 151, 201, 251, 301, 351, 401, 451, 502, 552, 602, 652, 702, 752, 802, 852, 902, 952, 1003, 1053, 1103, 1153, 1203, 1253, 1303, 1353, 1403, 1453, 1504, 1554, 1604, 1654, 1704, 1754, 1804, 1854, 1904, 1954, 2006, 2055, 2105, 2155, 2205, 2255, 2305, 2355, 2405, 2455, 2506, 2556, 2606, 2656, 2706, 2756, 2806, 2856, 2906, 2956, 3007, 3057, 3107, 3157, 3207, 3257, 3307, 3357, 3409, 3457, 3508, 3558, 3608, 3658, 3708, 3758, 3808, 3858, 3908, 3958, 4009, 4060, 4109, 4159, 4209, 4259, 4309, 4359, 4409, 4459, 4510, 4560, 4610, 4660, 4710, 4760, 4810, 4860, 5011] [16.307692307692307, 26.789473684210527, 35.45977011494253, 48.77272727272727, 56.68840579710145, 70.86956521739131, 72.77720207253886, 84.94883720930233, 91.28688524590164, 104.62172284644194, 109.89795918367348, 115.4447949526814, 120.83529411764707, 120.48517520215634, 125.9090909090909, 130.49289099526067, 129.16071428571428, 133.04621848739495, 135.82258064516128, 144.30710172744722, 150.06557377049182, 152.3602811950791, 158.6434634974533, 165.3186274509804, 167.42857142857144, 170.54642313546424, 169.84187408491948, 176.21985815602835, 178.05212620027436, 181.61396574440056, 180.70812182741116, 185.99506172839506, 189.28229665071763, 188.89339513325615, 195.12429378531078, 195.45993413830956, 197.171974522293, 198.4107883817428, 199.02623612512616, 205.7362204724409, 210.20977011494256, 216.0298507462686, 223.58416742493182, 228.96000000000004, 236.17636837532575, 242.54344122657574, 245.35225375626044, 248.24468085106378, 251.16000000000008, 253.57411764705887, 257.3533487297922, 259.72452830188683, 258.67256637168146, 260.51198257080614, 263.0085714285715, 265.81320224719104, 267.90041493775936, 264.3532203389831, 262.8164893617021, 261.75651041666674, 264.48560460652595, 265.80543272267846, 268.1232876712329, 268.92944785276086, 270.6269649334945, 270.71888028588455, 277.3763945977687, 277.72090330052106, 282.36285714285714, 285.97631133671734, 291.1740576496675, 293.5759562841531, 
302.7684835402051, 305.27854855923147, 307.79033105622693, 309.35966735966736, 314.9835728952771, 317.02027369488087, 317.66084788029934, 316.14018691588785, 311.0737148399612, 310.05454545454563, 312.4376181474481, 312.8387096774193, 312.9843101061374, 305.3433789954338, 307.3313769751692, 304.5532768613464, 304.49624724061823, 303.58424031345226, 300.6408268733851, 297.7319982956967, 291.8212478920743, 294.4339152119701, 288.52873563218395, 279.89426596177304, 266.553319919517, 238.01641266119577]
# H2O's normalized Qini curve.
perf_h2o.plot_uplift(normalize=True, metric="qini")

# Lift curves: CausalML first, then H2O.
plot_lift(results, treatment_col=treatment_column, outcome_col=response_column)
perf_h2o.plot_uplift(metric="lift")

# H2O's lift curve as data: x values (n) and y values (uplift).
n, uplift = perf_h2o.plot_uplift(plot=False, metric="lift")
for shown in (n, uplift):
    print(shown)
[51, 101, 151, 201, 251, 301, 351, 401, 451, 502, 552, 602, 652, 702, 752, 802, 852, 902, 952, 1003, 1053, 1103, 1153, 1203, 1253, 1303, 1353, 1403, 1453, 1504, 1554, 1604, 1654, 1704, 1754, 1804, 1854, 1904, 1954, 2006, 2055, 2105, 2155, 2205, 2255, 2305, 2355, 2405, 2455, 2506, 2556, 2606, 2656, 2706, 2756, 2806, 2856, 2906, 2956, 3007, 3057, 3107, 3157, 3207, 3257, 3307, 3357, 3409, 3457, 3508, 3558, 3608, 3658, 3708, 3758, 3808, 3858, 3908, 3958, 4009, 4060, 4109, 4159, 4209, 4259, 4309, 4359, 4409, 4459, 4510, 4560, 4610, 4660, 4710, 4760, 4810, 4860, 5011] [0.6523076923076923, 0.6088516746411483, 0.554058908045977, 0.5359640359640359, 0.5016673079389509, 0.5062111801242236, 0.46061520299075226, 0.45671417854463614, 0.4409994456323751, 0.4451988206231572, 0.4259610821072615, 0.40506945597432065, 0.38729260935143284, 0.3640035504596868, 0.3536772216547498, 0.34340234472437015, 0.3197047383309759, 0.3123150668718191, 0.29785653650254673, 0.2993923272353677, 0.29774915430653137, 0.2853188786424702, 0.2812827366976123, 0.2797269500016588, 0.2717996289424861, 0.2640037509836908, 0.25349533445510364, 0.25246398016622973, 0.24592835110535138, 0.2437771352273832, 0.23591138619766472, 0.23425070746649257, 0.231396450673249, 0.22460570170422833, 0.22453888813039202, 0.21888010541803982, 0.2161973404849704, 0.21107530678908792, 0.20667314239369272, 0.2078143641135768, 0.20792262128085315, 0.20912860672436462, 0.2117274312736097, 0.21199999999999997, 0.21392786990518642, 0.21445043432942157, 0.21205899201059675, 0.2098433481412204, 0.20843153526970948, 0.20599034739809818, 0.2047361565073923, 0.202751388213807, 0.19897889720898565, 0.19602105535801806, 0.19395912347239785, 0.1923395095855217, 0.19000029428209886, 0.18473320778405522, 0.18100309184690222, 0.17794460259460676, 0.17703186385978975, 0.174413013597558, 0.17287123640956337, 0.1705323068184913, 0.1688253056353678, 0.166289238504843, 0.1677003594907912, 0.16511349780054763, 0.1654146790526404, 0.1648278451508458, 
0.16600573412181718, 0.16511583593034473, 0.16773877204443488, 0.16645504283491352, 0.1659247067688555, 0.1642036450953649, 0.16491286539019756, 0.1638347667673803, 0.16265276389160221, 0.15998997313557073, 0.15569254996995052, 0.1535683731820433, 0.15293079693952416, 0.1511298114383668, 0.14961009087291466, 0.14409786644428213, 0.14334485866379165, 0.14060631434041848, 0.13878589208779313, 0.13718221433052513, 0.13433459645817025, 0.13156517821285763, 0.1275442517010814, 0.12779249792186198, 0.1241517795319208, 0.11905328199139653, 0.11223297680821764, 0.09707031511468012]
# H2O's normalized lift curve.
perf_h2o.plot_uplift(normalize=True, metric="lift")

# Gain curves: CausalML's (raw, then normalized) followed by H2O's own plot.
plot_gain(results, treatment_col=treatment_column, outcome_col=response_column)
plot_gain(results, treatment_col=treatment_column, outcome_col=response_column, normalize=True)
perf_h2o.plot_uplift(metric="gain")

# H2O's gain curve as data: x values (n) and y values (uplift).
n, uplift = perf_h2o.plot_uplift(plot=False, metric="gain")
for shown in (n, uplift):
    print(shown)
[51, 101, 151, 201, 251, 301, 351, 401, 451, 502, 552, 602, 652, 702, 752, 802, 852, 902, 952, 1003, 1053, 1103, 1153, 1203, 1253, 1303, 1353, 1403, 1453, 1504, 1554, 1604, 1654, 1704, 1754, 1804, 1854, 1904, 1954, 2006, 2055, 2105, 2155, 2205, 2255, 2305, 2355, 2405, 2455, 2506, 2556, 2606, 2656, 2706, 2756, 2806, 2856, 2906, 2956, 3007, 3057, 3107, 3157, 3207, 3257, 3307, 3357, 3409, 3457, 3508, 3558, 3608, 3658, 3708, 3758, 3808, 3858, 3908, 3958, 4009, 4060, 4109, 4159, 4209, 4259, 4309, 4359, 4409, 4459, 4510, 4560, 4610, 4660, 4710, 4760, 4810, 4860, 5011] [33.26769230769231, 61.49401913875598, 83.66289511494253, 107.72877122877121, 125.91849429267667, 152.3695652173913, 161.67593624975405, 183.1423855963991, 198.89074998020115, 223.48980795282492, 235.13051732320835, 243.85181249654102, 252.51478129713422, 255.53049242270012, 265.96527068437183, 275.40868046894485, 272.3884370579915, 281.70819031838084, 283.55942275042446, 300.29050421707376, 313.52985948477755, 314.7067231426446, 324.31899541234696, 336.51152085199556, 340.56493506493507, 343.9968875317491, 342.97918751775524, 354.2069641732203, 357.3338941560756, 366.64081138198435, 366.60629415117097, 375.73813477625407, 382.7297294135538, 382.72811570400506, 393.84120978070763, 394.85971017414386, 400.82986925913514, 401.8873841264234, 403.83932023727556, 416.87561441183504, 427.2809867321532, 440.2157171547875, 456.27261439462893, 467.4599999999999, 482.40734663619537, 494.3082511293167, 499.3989261849553, 504.6732522796351, 511.6994190871368, 516.2118105796341, 523.3056160328947, 528.3701176851811, 528.4879509870659, 530.4329757987969, 534.5513442899285, 539.7046638969739, 542.6408404696743, 536.8347018204645, 535.0451394994429, 535.0794200019825, 541.1864078193772, 541.9012332476127, 545.7544933449916, 546.8971079669016, 549.8640204543929, 549.9185117355158, 562.9701068105861, 562.8719140020669, 571.8385454849779, 578.2160807891671, 590.6484020054255, 595.7379360366838, 613.5884281385428, 
617.2152988318594, 623.545048037359, 625.2874805231495, 636.2338346753822, 640.2662685269222, 643.7796394829616, 641.3998023005031, 632.1117528779992, 631.0124454050159, 636.039184471481, 636.1053763440858, 637.1893770277435, 620.9177065084117, 624.8402389154678, 619.9332399269051, 618.8462928194696, 618.6917866306684, 612.5657598492563, 606.5154715612737, 594.3562129270393, 601.90266521197, 590.9624705719431, 572.6462863786173, 545.4522672879377, 486.41934903966205]
# H2O's normalized gain curve.
perf_h2o.plot_uplift(normalize=True, metric="gain")