In this notebook we will train the Cox-CC method. We will use the METABRIC dataset as an example.
A more detailed introduction to the pycox package can be found in the introductory notebook about the LogisticHazard method.
The main benefit Cox-CC (and the other Cox methods) has over Logistic-Hazard is that it is a continuous-time method, meaning we do not need to discretize the time scale.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

import torch
import torchtuples as tt

from pycox.datasets import metabric
from pycox.models import CoxCC
from pycox.evaluation import EvalSurv
## Uncomment to install `sklearn-pandas`
# ! pip install sklearn-pandas
np.random.seed(1234)
_ = torch.manual_seed(123)
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)
We have 9 covariates, in addition to the durations and event indicators.
We will standardize the 5 numerical covariates and leave the binary variables as they are. All variables need to be of type 'float32', as this is required by PyTorch.
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')
As Cox-CC works directly on the durations, we need no label transforms.
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = get_target(df_train)
y_val = get_target(df_val)
durations_test, events_test = get_target(df_test)
val = tt.tuplefy(x_val, y_val)
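We can inspect the shapes of the validation tuple. The output shown below is consistent with calling the shapes() method of the tuplefied data (an assumption; the producing cell is not shown in the original):

val.shapes()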
((305, 9), ((305,), (305,)))
With a TupleTree (the result of tt.tuplefy) we can easily repeat the validation dataset multiple times. This is useful for reducing the variance of the validation loss, as the validation loss of CoxCC is not deterministic.
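For example (a short sketch using the TupleTree methods repeat and cat), doubling the validation set should double the number of samples, matching the shapes shown below:

val.repeat(2).cat().shapes()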
((610, 9), ((610,), (610,)))
We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout.
Here, we just use the torchtuples.practical.MLPVanilla net to do this. Note that we set out_features to 1, and that we have no output_bias.
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)
To train the model we need to define an optimizer. You can choose any torch.optim optimizer, but here we instead use one from tt.optim, as it has some added functionality.
We use the Adam optimizer, but instead of choosing a learning rate, we will use the scheme proposed by Smith 2017 to find a suitable learning rate with model.lr_finder. See this post for an explanation.
model = CoxCC(net, tt.optim.Adam)
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=2)
_ = lrfinder.plot()
Often, this learning rate is a little high, so we instead set it manually to 0.01.
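A minimal sketch of how to apply this, assuming the set_lr method provided by the tt.optim optimizer wrappers:

model.optimizer.set_lr(0.01)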
We include the EarlyStopping callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.
epochs = 512
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True
%%time
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val.repeat(10).cat())
0:	[0s / 0s],	train_loss: 0.7162,	val_loss: 0.6546
1:	[0s / 0s],	train_loss: 0.6317,	val_loss: 0.6570
2:	[0s / 0s],	train_loss: 0.6778,	val_loss: 0.6638
3:	[0s / 0s],	train_loss: 0.6556,	val_loss: 0.6440
4:	[0s / 0s],	train_loss: 0.6288,	val_loss: 0.6493
5:	[0s / 0s],	train_loss: 0.6078,	val_loss: 0.6377
6:	[0s / 0s],	train_loss: 0.6308,	val_loss: 0.6464
7:	[0s / 0s],	train_loss: 0.6238,	val_loss: 0.6464
8:	[0s / 0s],	train_loss: 0.6239,	val_loss: 0.6481
9:	[0s / 0s],	train_loss: 0.5940,	val_loss: 0.6544
10:	[0s / 0s],	train_loss: 0.6091,	val_loss: 0.6639
11:	[0s / 0s],	train_loss: 0.6085,	val_loss: 0.6455
12:	[0s / 0s],	train_loss: 0.5987,	val_loss: 0.6498
13:	[0s / 0s],	train_loss: 0.5791,	val_loss: 0.6623
14:	[0s / 0s],	train_loss: 0.5931,	val_loss: 0.6507
15:	[0s / 0s],	train_loss: 0.5994,	val_loss: 0.6626
CPU times: user 1.68 s, sys: 69.2 ms, total: 1.75 s
Wall time: 813 ms
_ = log.plot()
We can also compute the partial log-likelihood of the fitted model.
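A short sketch of how this can be done on the validation set (assuming the model's partial_log_likelihood method accepts the tuplefied input and target):

model.partial_log_likelihood(*val).mean()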
For evaluation we first need to obtain survival estimates for the test set. This can be done with model.predict_surv, which returns an array of survival estimates, or with model.predict_surv_df, which returns the survival estimates as a dataframe.
As CoxCC is semi-parametric, we first need to compute the non-parametric baseline hazard estimates with compute_baseline_hazards. Note that for large datasets the sample argument can be used to estimate the baseline hazard on a subset.
_ = model.compute_baseline_hazards()
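As an illustration of the subset variant (the sample size of 1000 is an arbitrary choice for this sketch, not taken from the original notebook):

## Uncomment to estimate the baseline hazards on a random subset of 1000 individuals
# _ = model.compute_baseline_hazards(sample=1000)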
surv = model.predict_surv_df(x_test)
surv.iloc[:, :5].plot()
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')
We can use the EvalSurv class for evaluating the concordance, Brier score, and binomial log-likelihood. Setting censor_surv='km' means that we estimate the censoring distribution by Kaplan-Meier on the test set.
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
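For example, the time-dependent concordance can be computed with EvalSurv's concordance_td method:

ev.concordance_td()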
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
_ = ev.brier_score(time_grid).plot()
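Similarly, integrated versions of the Brier score and the binomial log-likelihood over the same time grid can be obtained with EvalSurv's integrated_brier_score and integrated_nbll (a short usage sketch):

ev.integrated_brier_score(time_grid)
ev.integrated_nbll(time_grid)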