#!/usr/bin/env python
# coding: utf-8

# In[1]:

# Notebook set-up: auto-reload edited modules and render plots inline.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')
get_ipython().run_line_magic('matplotlib', 'inline')

import random
random.seed(1100038344)

import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
import matplotlib.pyplot as plt


# ## Simulate survival data
# To demonstrate this model, we first simulate some survival data with
# `survivalstan.sim.sim_data_exp_correlated`. As its name suggests, this function
# assumes a constant hazard over the whole follow-up period, which corresponds to
# an Exponential survival function.
#
# By default the simulation includes two covariates (`age` and `sex`). Here we
# also make the hazard depend on the simulated value of `sex`.
#
# We center the `age` variable so that estimates of the baseline hazard are
# easier to interpret.

# In[2]:

d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
# Center age around its sample mean (Series.sub is the method form of `-`).
d['age_centered'] = d['age'].sub(d['age'].mean())


# *Aside: To make this example more reproducible, the call to
# `survivalstan.sim.sim_data_exp_correlated` is wrapped in the file-caching
# function `stancache.cached`.*

# ## Explore simulated data
# Here is what these data look like. They are in per-subject (time-to-event) form:

# In[3]:

d.head()


# *The field names don't make it obvious, but in this example "subjects" are
# indexed by the field `index`.*
# We can plot these data using `lifelines`, or with the basic plotting functions
# provided by `survivalstan`.
# In[4]:

# Observed (Kaplan-Meier style) survival curves, one per sex.
survivalstan.utils.plot_observed_survival(
    df=d[d['sex'] == 'female'], event_col='event', time_col='t', label='female')
survivalstan.utils.plot_observed_survival(
    df=d[d['sex'] == 'male'], event_col='event', time_col='t', label='male')
plt.legend()


# ## model1: original spec

# In[5]:

# Gamma survival model written with separate log-survival and log-hazard helpers.
# An event contributes log h(t) + log S(t) = log f(t) to the log-likelihood;
# a censored observation contributes log S(t) only.
# `alpha` (the Gamma shape) must be positive, so it is declared with <lower=0>.
# Without the constraint the sampler proposes alpha <= 0 and gamma_lpdf rejects
# those proposals.
model_code = '''
functions {
  // Defines the log survival: log S(t) = gamma_lccdf(t | shape, rate)
  vector log_S (vector t, real shape, vector rate) {
    vector[num_elements(t)] log_S;
    for (i in 1:num_elements(t)) {
      log_S[i] = gamma_lccdf(t[i]|shape,rate[i]);
    }
    return log_S;
  }
  // Defines the log hazard: log h(t) = log f(t) - log S(t)
  vector log_h (vector t, real shape, vector rate) {
    vector[num_elements(t)] log_h;
    vector[num_elements(t)] ls;
    ls = log_S(t,shape,rate);
    for (i in 1:num_elements(t)) {
      log_h[i] = gamma_lpdf(t[i]|shape,rate[i]) - ls[i];
    }
    return log_h;
  }
  // Defines the sampling distribution (censored log-likelihood):
  // d[i] * log h(t[i]) + log S(t[i]) for each observation
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    log_lik = d .* log_h(t,shape,rate) + log_S(t,shape,rate);
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                   // number of observations
  vector[N] y;             // observed times
  vector[N] event;         // censoring indicator (1=observed, 0=censored)
  int M;                   // number of covariates
  matrix[N, M] x;          // matrix of covariates (with N rows and M columns)
}
parameters {
  vector[M] beta;          // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;     // shape parameter (must be positive)
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  for (i in 1:N) {
    mu[i] = exp(linpred[i]);
  }
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''


# Now we are ready to fit our model using `survivalstan.fit_stan_survival_model`.
#
# We pass several parameters to the fit function, many of which are required.
# See `?survivalstan.fit_stan_survival_model` for details.
#
# As we did above, we ask `survivalstan` to cache this model fit object.
# See [stancache](http://github.com/jburos/stancache) for more details on how this
# works. If you don't want to use the cache, omit the `FIT_FUN` parameter and
# `survivalstan` will use the standard pystan functionality.
#

# In[6]:

testfit = survivalstan.fit_stan_survival_model(
    model_cohort='model 1',
    model_code=model_code,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)


# In[7]:

# 0:01:33.270480 elapsed


# In[8]:

survivalstan.utils.print_stan_summary([testfit], pars=['lp__', 'alpha', 'beta'])


# ## model2: alternate version of surv_gamma_lpdf

# In[9]:

# Same Gamma survival model, with the censored log-likelihood computed in one
# elementwise loop inside surv_gamma_lpdf. `linpred` is kept as a local in a
# nested block so it is not saved in the output. `alpha` is constrained to be
# positive because it is the Gamma shape parameter.
model_code2 = '''
functions {
  // Defines the sampling distribution (censored log-likelihood):
  // d[i] * (log f(t[i]) - log S(t[i])) + log S(t[i]), i.e. log f for events
  // and log S for censored observations
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    for (i in 1:num_elements(t)) {
      log_lik[i] = d[i] * (gamma_lpdf(t[i]|shape,rate[i]) - gamma_lccdf(t[i]|shape,rate[i]))
                   + gamma_lccdf(t[i]|shape,rate[i]);
    }
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                   // number of observations
  vector[N] y;             // observed times
  vector[N] event;         // censoring indicator (1=observed, 0=censored)
  int M;                   // number of covariates
  matrix[N, M] x;          // matrix of covariates (with N rows and M columns)
}
parameters {
  vector[M] beta;          // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;     // shape parameter (must be positive)
}
transformed parameters {
  vector[N] mu;
  {
    vector[N] linpred;
    linpred = x*beta;
    mu = exp(linpred);
  }
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''


# In[10]:

testfit2 = survivalstan.fit_stan_survival_model(
    model_cohort='model 2',
    model_code=model_code2,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)


# In[11]:

# 0:01:20.742172 elapsed


# In[12]:
survivalstan.utils.print_stan_summary([testfit2], pars=['lp__', 'alpha', 'beta'])


# ## model3: use `log_mix` inside surv_gamma_lpdf

# In[13]:

# Because the censoring indicator d[i] is 0 or 1, log_mix(d[i], lpdf, lccdf)
# reduces to gamma_lpdf for events and gamma_lccdf for censored observations.
# `alpha` is constrained to be positive because it is the Gamma shape parameter.
model_code3 = '''
functions {
  // Defines the sampling distribution (censored log-likelihood) using log_mix:
  // with d[i] in {0, 1} this selects log f(t[i]) for events and
  // log S(t[i]) for censored observations
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    for (i in 1:num_elements(t)) {
      log_lik[i] = log_mix(d[i], gamma_lpdf(t[i]|shape,rate[i]), gamma_lccdf(t[i]|shape,rate[i]));
    }
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                   // number of observations
  vector[N] y;             // observed times
  vector[N] event;         // censoring indicator (1=observed, 0=censored)
  int M;                   // number of covariates
  matrix[N, M] x;          // matrix of covariates (with N rows and M columns)
}
parameters {
  vector[M] beta;          // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;     // shape parameter (must be positive)
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu);
}
'''


# In[14]:

testfit3 = survivalstan.fit_stan_survival_model(
    model_cohort='model 3',
    model_code=model_code3,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)


# In[15]:

# 0:00:42.036498 elapsed


# In[16]:

survivalstan.utils.print_stan_summary([testfit3], pars=['lp__', 'alpha', 'beta'])


# ## model4: vectorize surv_gamma_lpdf

# In[17]:

# Split the observations into events and censored times, then make one
# vectorized gamma_lpdf call and one gamma_lccdf call. The group sizes are
# computed once in `transformed data`. No print() is executed inside
# surv_gamma_lpdf: that function runs on every log-density evaluation, so
# printing there would flood the output and slow sampling.
# `alpha` is constrained to be positive because it is the Gamma shape parameter.
model_code4 = '''
functions {
  // Counts the elements of a that are equal to val
  int count_value(vector a, real val) {
    int s;
    s = 0;
    for (i in 1:num_elements(a))
      if (a[i] == val)
        s = s + 1;
    return s;
  }
  // Defines the sampling distribution (censored log-likelihood), vectorized:
  // gamma_lpdf over the event times plus gamma_lccdf over the censored times.
  // num_obs / num_cens must equal the number of d[i] == 1 / d[i] != 1.
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate, int num_cens, int num_obs) {
    vector[2] log_lik;
    int idx_obs[num_obs];
    int idx_cens[num_cens];
    real prob;
    int i_cens;
    int i_obs;
    i_cens = 1;
    i_obs = 1;
    for (i in 1:num_elements(t)) {
      if (d[i] == 1) {
        idx_obs[i_obs] = i;
        i_obs = i_obs+1;
      }
      else {
        idx_cens[i_cens] = i;
        i_cens = i_cens+1;
      }
    }
    log_lik[1] = gamma_lpdf(t[idx_obs] | shape, rate[idx_obs]);
    log_lik[2] = gamma_lccdf(t[idx_cens] | shape, rate[idx_cens]);
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N;                   // number of observations
  vector[N] y;             // observed times
  vector[N] event;         // censoring indicator (1=observed, 0=censored)
  int M;                   // number of covariates
  matrix[N, M] x;          // matrix of covariates (with N rows and M columns)
}
transformed data {
  int num_cens;
  int num_obs;
  num_obs = count_value(event, 1);
  num_cens = N - num_obs;
}
parameters {
  vector[M] beta;          // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;     // shape parameter (must be positive)
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu, num_cens, num_obs);
}
'''


# In[18]:

testfit4 = survivalstan.fit_stan_survival_model(
    model_cohort='model 4',
    model_code=model_code4,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)


# In[19]:

# 0:00:16.703755 elapsed


# In[20]:

survivalstan.utils.print_stan_summary([testfit4], pars=['lp__', 'alpha', 'beta'])


# ## compare coefficient estimates for each model spec

# In[21]:

survivalstan.utils.plot_coefs([testfit, testfit2, testfit3, testfit4])


# In[ ]:




# In[ ]: