Author: *Mononito Goswami* <mgoswami@cs.cmu.edu>
Cox Mixture with Heterogeneous Effects (CMHE) is a flexible approach to recover counterfactual phenotypes of individuals that demonstrate heterogeneous effects to an intervention in terms of censored Time-to-Event outcomes. CMHE is not restricted by the strong Cox Proportional Hazards assumption or any parametric assumption on the time to event distributions. CMHE achieves this by describing each individual as belonging to two different latent groups, $Z$ that mediates the base survival rate and $\phi$ the effect of the treatment. CMHE can also be employed to model individual level counterfactuals or for standard factual survival regression.
Figure B (Right): CMHE in Plate Notation. $\mathbf{x}$ confounds treatment assignment $A$ and outcome $T$ (Model parameters and censoring distribution have been abstracted out).
For full details on Cox Mixtures with Heterogeneous Effects, please refer to our preprint:
import pandas as pd
import torch
from tqdm import tqdm
import sys
sys.path.append('../')
from auton_survival.datasets import load_dataset
from cmhe_demo_utils import *
# Load the synthetic demonstration dataset bundled with auton_survival:
# returns outcomes (time/event), covariates, and treatment assignments.
outcomes, features, interventions = load_dataset(dataset='SYNTHETIC')
# Let's take a look at the dataset
features.head(5)
# Visualize the synthetic data (helper from cmhe_demo_utils).
plot_synthetic_data(outcomes, features, interventions)
# Hyper-parameters
random_seed = 0
test_size = 0.25  # fraction of the data held out for testing

# Split the synthetic data into training and testing data
import numpy as np

np.random.seed(random_seed)

n = features.shape[0]

# Sample test indices WITHOUT replacement so the test split contains exactly
# int(n * test_size) unique rows. (The previous np.random.randint draw was
# with replacement: duplicate indices silently produced a test set smaller
# than the requested fraction.)
test_idx = np.zeros(n).astype('bool')
test_idx[np.random.choice(n, size=int(n * test_size), replace=False)] = True

features_tr = features.iloc[~test_idx]
outcomes_tr = outcomes.iloc[~test_idx]
interventions_tr = interventions[~test_idx]
print(f'Number of training data points: {len(features_tr)}')

features_te = features.iloc[test_idx]
outcomes_te = outcomes.iloc[test_idx]
interventions_te = interventions[test_idx]
print(f'Number of test data points: {len(features_te)}')

# Convert the pandas objects to float32 numpy arrays, the format the
# CMHE model consumes: x = covariates, t = times, e = event indicators,
# a = treatment assignments.
x_tr = features_tr.values.astype('float32')
t_tr = outcomes_tr['time'].values.astype('float32')
e_tr = outcomes_tr['event'].values.astype('float32')
a_tr = interventions_tr.values.astype('float32')

x_te = features_te.values.astype('float32')
t_te = outcomes_te['time'].values.astype('float32')
e_te = outcomes_te['event'].values.astype('float32')
a_te = interventions_te.values.astype('float32')

print('Training Data Statistics:')
print(f'Shape of covariates: {x_tr.shape} | times: {t_tr.shape} | events: {e_tr.shape} | interventions: {a_tr.shape}')
def find_max_treatment_effect_phenotype(g, zeta_probs, factual_outcomes):
    """Return the index of the treatment phenotype group with the largest
    mean differential survival.

    Parameters
    ----------
    g : int
        Number of treatment phenotype groups.
    zeta_probs : np.ndarray
        (n_samples, g) matrix of phenotype membership probabilities.
    factual_outcomes : tuple
        Pair of (outcomes, interventions) pandas objects for the
        factual (observed) data.
    """
    outcomes_train, interventions_train = factual_outcomes

    # Area under the differential survival curve for each phenotype group.
    differential_survival = []
    for group in range(g):
        probs = zeta_probs[:, group]
        # Restrict to individuals in the top quartile of membership
        # probability for this phenotype.
        top_mask = probs > np.quantile(probs, 0.75)
        differential_survival.append(find_mean_differential_survival(
            outcomes_train.loc[top_mask], interventions_train.loc[top_mask]))

    # NaN-safe argmax: groups whose differential survival could not be
    # computed are skipped.
    return np.nanargmax(np.asarray(differential_survival))
# Hyper-parameters to train model
k = 1  # number of underlying base survival phenotypes
g = 2  # number of underlying treatment effect phenotypes.
layers = [50, 50]  # number of neurons in each hidden layer.

random_seed = 10
iters = 100  # number of training epochs
learning_rate = 0.01
batch_size = 256
vsize = 0.15  # size of the validation split
patience = 3
optimizer = "Adam"

from auton_survival.models.cmhe import DeepCoxMixturesHeterogenousEffects

# Seed both torch and numpy so training is reproducible.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Instantiate the CMHE model
model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers)

model = model.fit(x_tr, t_tr, e_tr, a_tr, vsize=vsize, val_data=None, iters=iters,
                  learning_rate=learning_rate, batch_size=batch_size,
                  optimizer=optimizer, patience=patience)

print(f'Treatment Effect for the {g} groups: {model.torch_model[0].omega.detach()}')

# Infer the latent treatment-effect phenotype probabilities for the
# training data and the hard phenotype assignment per individual.
zeta_probs_train = model.predict_latent_phi(x_tr)
zeta_train = np.argmax(zeta_probs_train, axis=1)
print(f'Distribution of individuals in each treatement phenotype in the training data: \
{np.unique(zeta_train, return_counts=True)[1]}')

# Pass the hyper-parameter `g` (rather than a hard-coded group count) so
# this call stays correct if the number of treatment phenotypes is changed.
max_treat_idx_CMHE = find_max_treatment_effect_phenotype(
    g=g, zeta_probs=zeta_probs_train, factual_outcomes=(outcomes_tr, interventions_tr))
print(f'\nGroup {max_treat_idx_CMHE} has the maximum restricted mean survival time on the training data!')

# Now for each individual in the test data, let's find the probability that
# they belong to the max treatment effect group
zeta_probs_test_CMHE = model.predict_latent_phi(x_te)
zeta_test = np.argmax(zeta_probs_test_CMHE, axis=1)
print(f'Distribution of individuals in each treatement phenotype in the test data: \
{np.unique(zeta_test, return_counts=True)[1]}')

# Now let us evaluate our performance
plot_phenotypes_roc(outcomes_te, zeta_probs_test_CMHE[:, max_treat_idx_CMHE])
We compare the ability of CMHE against dimensionality reduction followed by clustering for counterfactual phenotyping. Specifically, we first perform dimensionality reduction of the input confounders, $\mathbf{x}$, followed by clustering. Due to a small number of confounders in the synthetic data, in the following experiment, we directly perform clustering using a Gaussian Mixture Model (GMM) with 2 components and diagonal covariance matrices.
from phenotyping import ClusteringPhenotyper
from sklearn.metrics import auc

# Hyper-parameters for the clustering baseline
clustering_method = 'gmm'
dim_red_method = None  # We would not perform dimensionality reduction for the synthetic dataset
n_components = None
n_clusters = 2  # Number of underlying treatment effect phenotypes

# Running the phenotyper
phenotyper = ClusteringPhenotyper(clustering_method=clustering_method,
                                  dim_red_method=dim_red_method,
                                  n_components=n_components,
                                  n_clusters=n_clusters,
                                  random_seed=36)

# Fit the phenotyper ONCE, then derive both the soft assignments and the
# hard cluster labels from the same fitted model. (Calling `fit_predict`
# after `fit` needlessly re-fits the GMM a second time, and the two fits
# are not guaranteed to yield consistent assignments.)
phenotyper = phenotyper.fit(features_tr.values)
zeta_probs_train = phenotyper.predict_proba(features_tr.values)
zeta_train = np.argmax(zeta_probs_train, axis=1)
print(f'Distribution of individuals in each treatement phenotype in the training data: \
{np.unique(zeta_train, return_counts=True)[1]}')

# Pass `n_clusters` (rather than a hard-coded 2) so the call stays correct
# if the number of phenotype clusters is changed.
max_treat_idx_CP = find_max_treatment_effect_phenotype(
    g=n_clusters, zeta_probs=zeta_probs_train, factual_outcomes=(outcomes_tr, interventions_tr))
print(f'\nGroup {max_treat_idx_CP} has the maximum restricted mean survival time on the training data!')

# Now for each individual in the test data, let's find the probability that
# they belong to the max treatment effect group
# Use the phenotyper trained on training data to phenotype on testing data
zeta_probs_test_CP = phenotyper.predict_proba(x_te)
zeta_test_CP = np.argmax(zeta_probs_test_CP, axis=1)
print(f'Distribution of individuals in each treatement phenotype in the test data: \
{np.unique(zeta_test_CP, return_counts=True)[1]}')

# Now let us evaluate our performance
plot_phenotypes_roc(outcomes_te, zeta_probs_test_CP[:, max_treat_idx_CP])
For completeness, we further evaluate the performance of CMHE in estimating factual risk over multiple time horizons using the standard survival analysis metrics, including:
We compute the censoring adjusted estimates of the Time Dependent Concordance Index (Antolini et al., 2005; Gerds et al., 2013) and the Integrated Brier Score (i.e. the Brier Score integrated over the evaluation horizon, $\text{IBS} = \frac{1}{t_\text{max}} \int_0^{t_\text{max}} \text{BS}(t)\, dt$, evaluated over 1, 3 and 5 years) (Gerds and Schumacher, 2006; Graf et al., 1999) to assess both discriminative performance and model calibration at multiple time horizons.
We find that CMHE had similar or better discriminative performance than a simple Cox Model with MLP hazard functions. CMHE was also better calibrated as evidenced by overall lower Integrated Brier Score, suggesting utility for factual risk estimation.
# Time horizons (in years) at which factual risk is evaluated.
horizons = [1, 3, 5]
# Now let us predict survival using CMHE
predictions_test_CMHE = model.predict_survival(x_te, a_te, t=horizons)
# Concordance indices at the 1/3/5-year horizons plus the Integrated Brier Score.
CI1, CI3, CI5, IBS = factual_evaluate((x_tr, t_tr, e_tr, a_tr), (x_te, t_te, e_te, a_te),
horizons, predictions_test_CMHE)
print(f'Concordance Index (1 Year): {np.around(CI1, 4)} (3 Year) {np.around(CI3, 4)}: (5 Year): {np.around(CI5, 4)}')
print(f'Integrated Brier Score: {np.around(IBS, 4)}')
from auton_survival.estimators import SurvivalModel
# Now let us train a Deep Cox-proportional Hazard model with two linear layers and tanh activations
dcph_model = SurvivalModel('dcph', random_seed=0, bs=100, learning_rate=1e-3, layers=[50, 50])
# DCPH receives the intervention as an extra covariate; naming the Series
# makes it a properly-labelled column after concat.
interventions_tr.name, interventions_te.name = 'treat', 'treat'
features_tr_dcph = pd.concat([features_tr, interventions_tr.astype('float64')], axis=1)
features_te_dcph = pd.concat([features_te, interventions_te.astype('float64')], axis=1)
outcomes_tr_dcph = pd.DataFrame(outcomes_tr, columns=['event', 'time']).astype('float64')
# Train the DCPH model
dcph_model = dcph_model.fit(features_tr_dcph, outcomes_tr_dcph)
# Find survival scores in the test data
predictions_test_DCPH = dcph_model.predict_survival(features_te_dcph, horizons)
# Evaluate DCPH with the same metrics for a head-to-head comparison with CMHE.
CI1, CI3, CI5, IBS = factual_evaluate((x_tr, t_tr, e_tr, a_tr), (x_te, t_te, e_te, a_te),
horizons, predictions_test_DCPH)
print(f'Concordance Index (1 Year): {np.around(CI1, 4)} (3 Year) {np.around(CI3, 4)}: (5 Year): {np.around(CI5, 4)}')
print(f'Integrated Brier Score: {np.around(IBS, 4)}')
# Display the covariate table (notebook cell output).
features