auton-survival
Cross-Validation Survival Regression
auton-survival offers a simple-to-use API for training survival regression models that performs cross-validation model selection by minimizing the integrated Brier score. In this notebook we demonstrate how to use auton-survival
to train survival models on the SUPPORT dataset with cross-validation.
import sys

# Put the parent directory (relative to the current working directory) on the
# import path so a local checkout of auton_survival can be imported.
sys.path.append('../')

from auton_survival import datasets

# The SUPPORT loader returns the outcomes and the raw covariates, in that order.
outcomes, features = datasets.load_support()
from auton_survival.preprocessing import Preprocessor

# Categorical covariates of the SUPPORT dataset.
cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']

# Numerical covariates of the SUPPORT dataset.
num_feats = [
    'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
    'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
    'glucose', 'bun', 'urine', 'adlp', 'adls',
]

# Fit the preprocessor on the raw covariates and replace `features` with its
# output (presumably categorical encoding / numeric scaling -- see Preprocessor docs).
preprocessor = Preprocessor()
features = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats)
import numpy as np

# Evaluation times are the 25th/50th/75th percentiles of the times of rows with
# event == 1 (presumably the observed, i.e. uncensored, events).
horizons = [0.25, 0.5, 0.75]
event_times = outcomes.time[outcomes.event == 1]
times = np.quantile(event_times, horizons).tolist()
from auton_survival.experiments import SurvivalRegressionCV

# Grid searched by cross-validation: 1 x 1 x 2 x 2 = 4 configurations.
param_grid = {
    'k': [3],
    'distribution': ['Weibull'],
    'learning_rate': [1e-4, 1e-3],
    'layers': [[], [100]],
}

# 5-fold CV over the grid for a 'dsm' model; seeded for reproducible folds.
experiment = SurvivalRegressionCV(
    model='dsm',
    cv_folds=5,
    hyperparam_grid=param_grid,
    random_seed=0,
)

# NOTE(review): the progress output suggests one extra fit after the grid search,
# i.e. the returned model is presumably refit on all data with the best config.
model = experiment.fit(features, outcomes)
0%| | 0/4 [00:00<?, ?it/s] 18%|█▊ | 1779/10000 [00:03<00:14, 554.25it/s] 100%|██████████| 10/10 [00:01<00:00, 7.14it/s] 14%|█▍ | 1418/10000 [00:02<00:17, 482.72it/s] 100%|██████████| 10/10 [00:01<00:00, 6.95it/s] 18%|█▊ | 1793/10000 [00:03<00:13, 588.09it/s] 100%|██████████| 10/10 [00:01<00:00, 6.14it/s] 18%|█▊ | 1770/10000 [00:03<00:16, 502.67it/s] 100%|██████████| 10/10 [00:01<00:00, 6.72it/s] 18%|█▊ | 1814/10000 [00:03<00:14, 556.46it/s] 100%|██████████| 10/10 [00:01<00:00, 7.86it/s] 100%|██████████| 5/5 [00:24<00:00, 4.85s/it] 25%|██▌ | 1/4 [00:24<01:14, 25.00s/it] 18%|█▊ | 1779/10000 [00:04<00:19, 428.03it/s] 100%|██████████| 10/10 [00:01<00:00, 5.74it/s] 14%|█▍ | 1418/10000 [00:02<00:13, 618.01it/s] 100%|██████████| 10/10 [00:01<00:00, 7.09it/s] 18%|█▊ | 1793/10000 [00:03<00:14, 570.05it/s] 100%|██████████| 10/10 [00:01<00:00, 6.29it/s] 18%|█▊ | 1770/10000 [00:02<00:13, 596.33it/s] 100%|██████████| 10/10 [00:01<00:00, 7.31it/s] 18%|█▊ | 1814/10000 [00:03<00:15, 518.75it/s] 100%|██████████| 10/10 [00:01<00:00, 6.74it/s] 100%|██████████| 5/5 [00:24<00:00, 4.86s/it] 50%|█████ | 2/4 [00:49<00:49, 25.00s/it] 18%|█▊ | 1779/10000 [00:02<00:13, 619.50it/s] 100%|██████████| 10/10 [00:01<00:00, 6.79it/s] 14%|█▍ | 1418/10000 [00:02<00:15, 540.40it/s] 100%|██████████| 10/10 [00:01<00:00, 5.69it/s] 18%|█▊ | 1793/10000 [00:03<00:14, 552.62it/s] 100%|██████████| 10/10 [00:01<00:00, 6.47it/s] 18%|█▊ | 1770/10000 [00:03<00:15, 520.40it/s] 90%|█████████ | 9/10 [00:01<00:00, 4.68it/s] 18%|█▊ | 1814/10000 [00:02<00:13, 619.36it/s] 100%|██████████| 10/10 [00:01<00:00, 6.73it/s] 100%|██████████| 5/5 [00:23<00:00, 4.74s/it] 75%|███████▌ | 3/4 [01:14<00:24, 24.74s/it] 18%|█▊ | 1779/10000 [00:03<00:15, 528.24it/s] 80%|████████ | 8/10 [00:01<00:00, 5.58it/s] 14%|█▍ | 1418/10000 [00:02<00:14, 586.05it/s] 90%|█████████ | 9/10 [00:01<00:00, 5.60it/s] 18%|█▊ | 1793/10000 [00:02<00:13, 604.69it/s] 100%|██████████| 10/10 [00:01<00:00, 5.98it/s] 18%|█▊ | 1770/10000 [00:02<00:13, 592.69it/s] 
90%|█████████ | 9/10 [00:01<00:00, 5.96it/s] 18%|█▊ | 1814/10000 [00:02<00:13, 624.91it/s] 100%|██████████| 10/10 [00:01<00:00, 6.77it/s] 100%|██████████| 5/5 [00:22<00:00, 4.53s/it] 100%|██████████| 4/4 [01:37<00:00, 24.45s/it] 18%|█▊ | 1815/10000 [00:03<00:14, 575.52it/s] 100%|██████████| 10/10 [00:01<00:00, 5.19it/s]
# Display the per-row CV fold assignment (integers in 0-4 here, since cv_folds=5).
experiment.folds
array([2, 0, 0, ..., 4, 4, 2])
# Predictions for every row at each evaluation time in `times`; out_survival is
# indexed as [row, time_index] further below.
out_risk, out_survival = (
    model.predict_risk(features, times),
    model.predict_survival(features, times),
)
# NOTE(review): these sksurv metrics are not used in the cells shown here.
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc
from auton_survival.metrics import survival_regression_metric

# Print the 'brs' metric (presumably the Brier score) for each CV fold.
# Each printed array has one value per evaluation time in `times`.
# NOTE(review): `model` appears to have been refit on the full dataset after the
# grid search, so these per-fold scores may be in-sample -- confirm whether
# out-of-fold predictions were intended.
for fold in sorted(set(experiment.folds)):  # sorted: deterministic fold order
    print(survival_regression_metric('brs', out_survival, outcomes, times,
                                     experiment.folds, fold))
[0.11593564 0.18645067 0.20140866] [0.12495799 0.18906541 0.20994536] [0.1228231 0.18860809 0.20365317] [0.13017307 0.19169085 0.20218919] [0.12173914 0.18685545 0.20330223]
from auton_survival.metrics import survival_regression_metric

# Print the 'ctd' metric (presumably time-dependent concordance) per CV fold.
# Each fold gets one line, with one value per evaluation time.
for fold in sorted(set(experiment.folds)):  # sorted: deterministic fold order
    for i, time in enumerate(times):
        # Column i of out_survival holds the predictions at times[i].
        print(survival_regression_metric('ctd', out_survival[:, i], outcomes, time,
                                         experiment.folds, fold), end=' ')
    print()  # terminate this fold's line
0.7797837441187897 0.7403985926375387 0.7026575605080277 0.7753023996424281 0.720478809092244 0.6852956869933615 0.7794202898550725 0.7353951125213819 0.6962889282771891 0.7676518744513308 0.7298422836907454 0.6952441350326791 0.7993855420033212 0.7372447791362597 0.6976821314815531
# Print the evaluation times once per CV fold. `times` does not depend on the
# fold, so the same values are repeated for each of the folds.
for _ in set(experiment.folds):
    for time in times:
        print(time)
14.0 58.0 252.0 14.0 58.0 252.0 14.0 58.0 252.0 14.0 58.0 252.0 14.0 58.0 252.0