Cox-CC

In this notebook we will train the Cox-CC method. We will use the METABRIC dataset as an example.

A more detailed introduction to the pycox package can be found in this notebook about the LogisticHazard method.

The main benefit Cox-CC (and the other Cox methods) has over Logistic-Hazard is that it is a continuous-time method, meaning we do not need to discretize the time scale.
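
To make the objective concrete, here is a minimal sketch (our illustration, not pycox's internal implementation) of the case-control idea behind Cox-CC: for each observed event we sample one control from the risk set and compare the two risk scores, which approximates the negative Cox partial log-likelihood.

import torch

def cox_cc_loss_sketch(g_case, g_control):
    # g_case: risk scores g(x_i) for individuals with an observed event
    # g_control: risk scores g(x_j) for one control sampled from each risk set
    # Each pair contributes log(1 + exp(g(x_j) - g(x_i))); average over pairs.
    return torch.log1p(torch.exp(g_control - g_case)).mean()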

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper

import torch
import torchtuples as tt

from pycox.datasets import metabric
from pycox.models import CoxCC
from pycox.evaluation import EvalSurv
In [2]:
## Uncomment to install `sklearn-pandas`
# ! pip install sklearn-pandas
In [3]:
np.random.seed(1234)
_ = torch.manual_seed(123)

Dataset

We load the METABRIC dataset and split it into train, test, and validation sets.

In [4]:
df_train = metabric.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)
In [5]:
df_train.head()
Out[5]:
x0 x1 x2 x3 x4 x5 x6 x7 x8 duration event
0 5.603834 7.811392 10.797988 5.967607 1.0 1.0 0.0 1.0 56.840000 99.333336 0
1 5.284882 9.581043 10.204620 5.664970 1.0 0.0 0.0 1.0 85.940002 95.733330 1
3 6.654017 5.341846 8.646379 5.655888 0.0 0.0 0.0 0.0 66.910004 239.300003 0
4 5.456747 5.339741 10.555724 6.008429 1.0 0.0 0.0 1.0 67.849998 56.933334 1
5 5.425826 6.331182 10.455145 5.749053 1.0 1.0 0.0 1.0 70.519997 123.533333 0

Feature transforms

We have 9 covariates, in addition to the durations and event indicators.

We will standardize the 5 numerical covariates, and leave the binary variables as they are. All variables need to be of type 'float32', as this is required by PyTorch.

In [6]:
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)
In [7]:
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

We need no label transforms.

In [8]:
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train = get_target(df_train)
y_val = get_target(df_val)
durations_test, events_test = get_target(df_test)
val = tt.tuplefy(x_val, y_val)
In [9]:
val.shapes()
Out[9]:
((305, 9), ((305,), (305,)))

With a TupleTree (the result of tt.tuplefy) we can easily repeat the validation dataset multiple times. This is useful for reducing the variance of the validation loss, as the validation loss of CoxCC is not deterministic (it depends on the sampled controls).

In [10]:
val.repeat(2).cat().shapes()
Out[10]:
((610, 9), ((610,), (610,)))

Neural net

We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. Here, we just use the torchtuples.practical.MLPVanilla net to do this.

Note that we set out_features to 1 and that we have no output_bias; the Cox model needs only a single risk score g(x), and a bias term would cancel out of the partial likelihood.

In [11]:
in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

Training the model

To train the model we need to define an optimizer. You can choose any torch.optim optimizer, but here we instead use one from tt.optim, as it has some added functionality. We use the Adam optimizer, but instead of choosing a learning rate, we will use the scheme proposed by Smith (2017) to find a suitable learning rate with model.lr_finder. See this post for an explanation.

In [12]:
model = CoxCC(net, tt.optim.Adam)
In [13]:
batch_size = 256
lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=2)
_ = lrfinder.plot()
In [14]:
lrfinder.get_best_lr()
Out[14]:
0.1873817422860396

Often, this learning rate is a little high, so we instead set it manually to 0.01.

In [15]:
model.optimizer.set_lr(0.01)

We include the EarlyStopping callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss.

In [16]:
epochs = 512
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True
In [17]:
%%time
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val.repeat(10).cat())
0:	[0s / 0s],		train_loss: 0.7162,	val_loss: 0.6546
1:	[0s / 0s],		train_loss: 0.6317,	val_loss: 0.6570
2:	[0s / 0s],		train_loss: 0.6778,	val_loss: 0.6638
3:	[0s / 0s],		train_loss: 0.6556,	val_loss: 0.6440
4:	[0s / 0s],		train_loss: 0.6288,	val_loss: 0.6493
5:	[0s / 0s],		train_loss: 0.6078,	val_loss: 0.6377
6:	[0s / 0s],		train_loss: 0.6308,	val_loss: 0.6464
7:	[0s / 0s],		train_loss: 0.6238,	val_loss: 0.6464
8:	[0s / 0s],		train_loss: 0.6239,	val_loss: 0.6481
9:	[0s / 0s],		train_loss: 0.5940,	val_loss: 0.6544
10:	[0s / 0s],		train_loss: 0.6091,	val_loss: 0.6639
11:	[0s / 0s],		train_loss: 0.6085,	val_loss: 0.6455
12:	[0s / 0s],		train_loss: 0.5987,	val_loss: 0.6498
13:	[0s / 0s],		train_loss: 0.5791,	val_loss: 0.6623
14:	[0s / 0s],		train_loss: 0.5931,	val_loss: 0.6507
15:	[0s / 0s],		train_loss: 0.5994,	val_loss: 0.6626
CPU times: user 1.68 s, sys: 69.2 ms, total: 1.75 s
Wall time: 813 ms
In [18]:
_ = log.plot()

We can get the mean partial log-likelihood on the validation set.

In [19]:
model.partial_log_likelihood(*val).mean()
Out[19]:
-4.979775905609131
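
For reference, the contribution of an individual i with an observed event to the partial log-likelihood is

    g(x_i) - log( sum_{j in R(T_i)} exp(g(x_j)) ),

where R(T_i) is the set of individuals still at risk at time T_i; the value above is the mean of these contributions over the validation set.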

Prediction

For evaluation we first need to obtain survival estimates for the test set. This can be done with model.predict_surv, which returns an array of survival estimates, or with model.predict_surv_df, which returns the survival estimates as a dataframe.

However, as CoxCC is semi-parametric, we first need to get the non-parametric baseline hazard estimates with compute_baseline_hazards.
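
Recall that a Cox model factorizes the hazard as

    h(t | x) = h_0(t) exp(g(x)),

so the network only estimates the relative risk g(x). To obtain survival estimates S(t | x) = exp(-H_0(t) exp(g(x))), we also need the cumulative baseline hazard H_0(t), which is estimated non-parametrically from the training data.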

Note that for large datasets the sample argument can be used to estimate the baseline hazard on a subset.
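
For instance, one could estimate the baseline hazards on a random 50% of the training data (the value here is purely illustrative):

_ = model.compute_baseline_hazards(sample=0.5)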

In [20]:
_ = model.compute_baseline_hazards()
In [21]:
surv = model.predict_surv_df(x_test)
In [22]:
surv.iloc[:, :5].plot()
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

Evaluation

We can use the EvalSurv class to evaluate the concordance, Brier score, and binomial log-likelihood. Setting censor_surv='km' means that we estimate the censoring distribution by Kaplan-Meier on the test set.

In [23]:
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
In [24]:
ev.concordance_td()
Out[24]:
0.6558738769282929
In [25]:
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
_ = ev.brier_score(time_grid).plot()
In [26]:
ev.integrated_brier_score(time_grid)
Out[26]:
0.1667853479188019
In [27]:
ev.integrated_nbll(time_grid)
Out[27]:
0.4971520413959927