The RuleFit algorithm combines tree ensembles and linear models to take advantage of both methods: the accuracy of a tree ensemble and the interpretability of a linear model. The general algorithm fits a tree ensemble to the data, builds a rule ensemble by traversing each tree, evaluates the rules on the data to build a rule feature set, and fits a sparse linear model (LASSO) to the rule feature set joined with the original feature set.
For more information, refer to http://statweb.stanford.edu/~jhf/ftp/RuleFit.pdf by Jerome H. Friedman and Bogdan E. Popescu.
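To illustrate the recipe outside of H2O, here is a minimal sketch using scikit-learn (an assumption for illustration; H2O's implementation differs in detail, e.g. in how rules are extracted and how the LASSO is solved). A boosted tree ensemble is fitted, the leaf each sample lands in is treated as a binary rule feature, and an L1-penalized linear model is fitted over the rules plus the original features:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder

X, y = make_classification(n_samples=500, n_features=5, random_state=42)

# 1. Fit a tree ensemble to the data.
ensemble = GradientBoostingClassifier(n_estimators=20, max_depth=3, random_state=42)
ensemble.fit(X, y)

# 2. The leaf each sample falls into encodes one conjunctive rule;
#    one-hot encoding the leaf indices yields binary rule features.
leaves = ensemble.apply(X).reshape(len(X), -1)
rules = OneHotEncoder(sparse_output=False).fit_transform(leaves)  # sklearn >= 1.2

# 3. Fit a sparse (L1-penalized) linear model to the rule features
#    joined with the original features.
lasso = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
lasso.fit(np.hstack([rules, X]), y)
print("active terms:", int(np.sum(lasso.coef_ != 0)))
```

The L1 penalty drives most rule coefficients to zero, so only a small set of rules survives, which is what makes the final model interpretable.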
We will train a RuleFit model to learn the rules that predict whether or not a passenger survived:
import h2o
from h2o.estimators import H2ORuleFitEstimator

# Start (or connect to) an H2O cluster
h2o.init()
Checking whether there is an H2O instance running at http://localhost:54321 ..... connected.

H2O_cluster_uptime:         4 mins 19 secs
H2O_cluster_timezone:       Europe/Prague
H2O_data_parsing_timezone:  UTC
H2O_cluster_version:        3.34.0.99999
H2O_cluster_version_age:    17 minutes
H2O_cluster_name:           zuzanaolajcova
H2O_cluster_total_nodes:    1
H2O_cluster_free_memory:    3.546 Gb
H2O_cluster_total_cores:    12
H2O_cluster_allowed_cores:  12
H2O_cluster_status:         locked, healthy
H2O_connection_url:         http://localhost:54321
H2O_connection_proxy:       {"http": null, "https": null}
H2O_internal_security:      False
H2O_API_Extensions:         Algos, AutoML, Core V3, TargetEncoder, Core V4
Python_version:             3.8.1 final
df = h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv",
col_types={'pclass': "enum", 'survived': "enum"})
x = ["age", "sibsp", "parch", "sex", "pclass"]
# Split the dataset into train and test
train, test = df.split_frame(ratios=[.8], seed=1234)
Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%
Several parameters control rule generation and the final model:

- Using the `algorithm` parameter, a user can set whether DRF or GBM is used to fit the tree ensemble.
- Using the `min_rule_length` and `max_rule_length` parameters, a user can set the interval of tree ensemble depths to be fitted. The wider this interval, the more tree ensembles are fitted (one per depth) and the larger the rule feature set becomes.
- Using the `max_num_rules` parameter, the maximum number of rules to return can be set.
- Using the `model_type` parameter, the type of base learners in the ensemble can be set (rules and linear terms, rules only, or linear terms only).
- Using the `rule_generation_ntrees` parameter, the number of trees in each tree ensemble can be set.
rfit = H2ORuleFitEstimator(algorithm="drf",
min_rule_length=1,
max_rule_length=10,
max_num_rules=100,
model_type="rules_and_linear",
rule_generation_ntrees=50,
seed=1234)
rfit.train(training_frame=train, x=x, y="survived")
rulefit Model Build progress: |██████████████████████████████████████████████████| (done) 100%

Model Details
=============
H2ORuleFitEstimator : RuleFit
Model Key: RuleFit_model_python_1636562504000_1

Rulefit Model Summary:
| family | link | regularization | number_of_predictors_total | number_of_active_predictors | number_of_iterations | rule_ensemble_size | number_of_trees | number_of_internal_trees | min_depth | max_depth | mean_depth | min_leaves | max_leaves | mean_leaves |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| binomial | logit | Lasso (lambda = 0.01292) | 20784 | 8 | 3 | 20776.0 | 500.0 | 500.0 | 0.0 | 10.0 | 5.5 | 0.0 | 135.0 | 41.552 |
ModelMetricsBinomialGLM: rulefit
** Reported on train data. **

MSE: 0.14668202166384883
RMSE: 0.3829908897922362
LogLoss: 0.4616331658988569
Null degrees of freedom: 1053
Residual degrees of freedom: 1045
Null deviance: 1405.0919048764067
Residual deviance: 973.1227137147903
AIC: 991.1227137147903
AUC: 0.8361042692939246
AUCPR: 0.7904193564939762
Gini: 0.6722085385878491

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.44132286664639514:
| | 0 | 1 | Error | Rate |
|---|---|---|---|---|
| 0 | 526.0 | 122.0 | 0.1883 | (122.0/648.0) |
| 1 | 106.0 | 300.0 | 0.2611 | (106.0/406.0) |
| Total | 632.0 | 422.0 | 0.2163 | (228.0/1054.0) |
Maximum Metrics: Maximum metrics at their respective thresholds
| metric | threshold | value | idx |
|---|---|---|---|
| max f1 | 0.441323 | 0.724638 | 3.0 |
| max f2 | 0.160033 | 0.783832 | 7.0 |
| max f0point5 | 0.809013 | 0.774478 | 1.0 |
| max accuracy | 0.523805 | 0.790323 | 2.0 |
| max precision | 0.809013 | 0.919048 | 1.0 |
| max recall | 0.156308 | 1.000000 | 8.0 |
| max specificity | 0.855041 | 0.973765 | 0.0 |
| max absolute_mcc | 0.523805 | 0.550968 | 2.0 |
| max min_per_class_accuracy | 0.441323 | 0.738916 | 3.0 |
| max mean_per_class_accuracy | 0.441323 | 0.775322 | 3.0 |
| max tns | 0.855041 | 631.000000 | 0.0 |
| max fns | 0.855041 | 217.000000 | 0.0 |
| max fps | 0.156308 | 648.000000 | 8.0 |
| max tps | 0.156308 | 406.000000 | 8.0 |
| max tnr | 0.855041 | 0.973765 | 0.0 |
| max fnr | 0.855041 | 0.534483 | 0.0 |
| max fpr | 0.156308 | 1.000000 | 8.0 |
| max tpr | 0.156308 | 1.000000 | 8.0 |
Gains/Lift Table: Avg response rate: 38.52 %, avg score: 38.52 %
| group | cumulative_data_fraction | lower_threshold | lift | cumulative_lift | response_rate | score | cumulative_response_rate | cumulative_score | capture_rate | cumulative_capture_rate | gain | cumulative_gain | kolmogorov_smirnov |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.195446 | 0.855041 | 2.381821 | 2.381821 | 0.917476 | 0.855041 | 0.917476 | 0.855041 | 0.465517 | 0.465517 | 138.182123 | 138.182123 | 0.439283 |
| 2 | 0.348197 | 0.523805 | 1.402839 | 1.952350 | 0.540373 | 0.530891 | 0.752044 | 0.712839 | 0.214286 | 0.679803 | 40.283940 | 95.234963 | 0.539371 |
| 3 | 0.400380 | 0.414525 | 1.132826 | 1.845540 | 0.436364 | 0.441323 | 0.710900 | 0.677452 | 0.059113 | 0.738916 | 13.282579 | 84.553965 | 0.550645 |
| 4 | 0.528463 | 0.307335 | 0.788433 | 1.589329 | 0.303704 | 0.307335 | 0.612208 | 0.587746 | 0.100985 | 0.839901 | -21.156723 | 58.932883 | 0.506568 |
| 5 | 1.000000 | 0.156308 | 0.339525 | 1.000000 | 0.130785 | 0.158208 | 0.385199 | 0.385203 | 0.160099 | 1.000000 | -66.047517 | 0.000000 | 0.000000 |
The output for the RuleFit model includes:

- model parameters
- rule importances in tabular form
- training and validation metrics of the underlying linear model
from IPython.display import display
display(rfit.rule_importance())
Rule Importance:
| variable | coefficient | rule |
|---|---|---|
| M2T21N13 | 1.298409e+00 | (sex in {female}) & (sibsp < 3.5 or sibsp is NA) & (pclass in {1, ... |
| M2T23N21 | -8.453746e-01 | (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) ... |
| M1T0N7 | 3.809983e-01 | (pclass in {1, 2}) & (sex in {female}) |
| M1T28N10 | -3.448192e-01 | (sex in {male} or sex is NA) & (age >= 13.496771812438965 or age i... |
| M1T23N7 | 3.310857e-01 | (sex in {female}) & (sibsp < 2.5 or sibsp is NA) |
| M1T37N10 | -2.319945e-01 | (sex in {male} or sex is NA) & (age >= 14.977890968322754 or age i... |
| M4T3N45 | -2.797404e-02 | (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) ... |
| M1T1N7 | 2.887806e-14 | (pclass in {1, 2}) & (sex in {female}) |
Several of these rules can be recapped in plain language: for example, being female with at most three siblings/spouses aboard and traveling in first class strongly increases the predicted chance of survival, while being male in second or third class decreases it.
Note: The rules are additive. That means that if a passenger is described by multiple rules, the coefficients of those rules are summed by the underlying linear model to produce the prediction.
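To make the additivity concrete, here is a small illustrative calculation. The two coefficients come from the rule importance table above; the intercept value is an assumption for illustration (it is not part of the output shown). Since this binomial model uses a logit link, coefficients add on the log-odds scale before being mapped to a probability:

```python
import math

# Illustrative only: a first-class female passenger with sibsp = 1 satisfies,
# among others, rules M2T21N13 (+1.298409) and M1T23N7 (+0.331086) from the
# table above; their coefficients are summed on the log-odds scale.
intercept = -0.5  # assumed value for illustration; not shown in the output above
log_odds = intercept + 1.298409 + 0.331086
probability = 1.0 / (1.0 + math.exp(-log_odds))  # inverse of the logit link
print(f"illustrative survival probability: {probability:.3f}")
```

A passenger matching an additional negative-coefficient rule would simply have that coefficient subtracted from the same sum before the logit transform.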