#!/usr/bin/env python
# coding: utf-8

# # `auton-survival` Cross Validation Survival Regression
#
# `auton-survival` offers a simple-to-use API for training Survival
# Regression Models that performs cross-validation model selection by
# minimizing the integrated Brier score. In this notebook we demonstrate the
# use of `auton-survival` to train survival models on the *SUPPORT* dataset
# in a cross-validation fashion.

# In[ ]:


import sys

# Make the package one directory up importable before importing from it.
sys.path.append('../')

from auton_survival import datasets

outcomes, features = datasets.load_support()


# In[ ]:


from auton_survival.preprocessing import Preprocessor

cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = [
    'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
    'temp', 'pafi', 'alb', 'bili', 'crea', 'sod',
    'ph', 'glucose', 'bun', 'urine', 'adlp', 'adls',
]

# Data should be processed in a fold-independent manner when performing
# cross-validation. For simplicity in this demo, we process the dataset in a
# non-independent manner.
preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat='mean')
x = preprocessor.fit_transform(
    features,
    cat_feats=cat_feats,
    num_feats=num_feats,
    one_hot=True,
    fill_value=-1,
)


# In[ ]:


import numpy as np

# Evaluation horizons: the 25th, 50th and 75th percentiles of the times of
# the rows whose event indicator equals 1.
event_times = outcomes.time[outcomes.event == 1]
times = np.quantile(event_times, [0.25, 0.5, 0.75]).tolist()


# In[ ]:


from auton_survival.experiments import SurvivalRegressionCV

# Hyperparameter grid searched during cross validation; only the learning
# rate takes more than one value here.
param_grid = {
    'k': [3],
    'distribution': ['Weibull'],
    'learning_rate': [1e-4, 1e-3],
    'layers': [[100]],
}

experiment = SurvivalRegressionCV(
    model='dsm',
    num_folds=3,
    hyperparam_grid=param_grid,
    random_seed=0,
)
model = experiment.fit(x, outcomes, times, metric='brs')


# In[ ]:


print(experiment.folds)
# Bare expression kept from the notebook cell (it displays the model there).
model


# In[ ]:


out_risk = model.predict_risk(x, times)
out_survival = model.predict_survival(x, times)


# In[ ]:


from auton_survival.metrics import survival_regression_metric

# Report the 'brs' metric separately on the rows assigned to each fold.
for fold in set(experiment.folds):
    in_fold = experiment.folds == fold
    fold_score = survival_regression_metric(
        'brs',
        outcomes[in_fold],
        out_survival[in_fold],
        times=times,
    )
    print(fold_score)


# In[ ]:


from auton_survival.metrics import survival_regression_metric

# Report the 'ctd' metric separately on the rows assigned to each fold.
# NOTE(review): 'ctd' is presumably the time-dependent concordance index --
# confirm against the auton_survival.metrics documentation.
for fold in set(experiment.folds):
    in_fold = experiment.folds == fold
    fold_score = survival_regression_metric(
        'ctd',
        outcomes[in_fold],
        out_survival[in_fold],
        times=times,
    )
    print(fold_score)


# In[ ]:


# Print every evaluation horizon once per distinct fold label.
for fold in set(experiment.folds):
    for time in times:
        print(time)


# In[ ]: