Recently during an interview to promote the upcoming PyMCon Web Series (for which I am mentoring one of the first round of presenters), my friend Ravin Kumar kindly mentioned that a 2015 post of mine, Bayesian Survival Analysis in Python with PyMC3, drove a lot of his early interest in PyMC. Ravin's comment caused me to revisit the post and realize how out of date it is. In the six and a half years since it was published, there have been advances along many fronts.
For all of these reasons, I thought it would be fun for me, and perhaps helpful to others, to revisit this post and rework the example with updated techniques.
First we make the necessary Python imports and do some light configuration.
%matplotlib inline
from multipledispatch.dispatcher import AmbiguityWarning
from warnings import filterwarnings
import arviz as az
import lifelines as ll
from matplotlib import pyplot as plt
from matplotlib.ticker import NullLocator
import numpy as np
import nutpie
import pandas as pd
import pymc as pm
from pytensor import tensor as pt
import seaborn as sns
from seaborn import objects as so
from statsmodels.datasets import get_rdataset
filterwarnings("ignore", category=AmbiguityWarning)
filterwarnings("ignore", module="pymc", category=FutureWarning)
plt.rc("figure", figsize=(8, 6))
sns.set(color_codes=True)
We begin by loading the mastectomy dataset from the HSAUR R package.
df = get_rdataset("mastectomy", package="HSAUR", cache=True).data
df["metastized"] = df["metastized"] == "yes"
df.head()
| | time | event | metastized |
|---|---|---|---|
| 0 | 23 | True | False |
| 1 | 47 | True | False |
| 2 | 69 | True | False |
| 3 | 70 | False | False |
| 4 | 100 | False | False |
df.tail()
| | time | event | metastized |
|---|---|---|---|
| 39 | 162 | False | True |
| 40 | 188 | False | True |
| 41 | 212 | False | True |
| 42 | 217 | False | True |
| 43 | 225 | False | True |
From the HSAUR documentation, this dataset represents

> [s]urvival times in months after mastectomy of women with breast cancer. The cancers are classified as having metastized or not based on a histochemical marker.
In this data frame,

- `time` is the number of months since mastectomy,
- `event` indicates whether the woman died at the corresponding time (if `True`) or the observation was censored (if `False`), and
- `metastized` indicates whether the woman's cancer had metastized.

In the context of survival analysis, censoring means that the woman survived past the corresponding time, but that her death was not observed. Censoring (and its counterpart, truncation) represents a fundamental challenge in survival analysis.

The following plot shows the time-to-event for each patient, along with whether or not the cancer had metastized and whether or not the patient's death was observed.
MONTHS_LABEL = "Months since mastectomy"
sorted_df = df.sort_values("time")
(so.Plot(sorted_df, x=0, y=np.arange(sorted_df.shape[0]),
color="event", linestyle="metastized")
.add(so.Dash(), so.Shift(x=sorted_df["time"] / 2), width="time")
.scale(y=so.Continuous().tick(locator=NullLocator()),
color=so.Nominal(), linestyle=so.Nominal())
.limit(x=(0, None))
.label(x=MONTHS_LABEL, y=None,
color=str.capitalize, linestyle=str.capitalize))
We see that metastasis is strongly associated with both shorter post-mastectomy survival and observed death. The following contingency table confirms this observation.
df.pivot_table(values="time", index="metastized", columns="event",
aggfunc=np.size, margins=True)
| metastized / event | False | True | All |
|---|---|---|---|
| False | 7 | 5 | 12 |
| True | 11 | 21 | 32 |
| All | 18 | 26 | 44 |
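As a rough numerical reading of this table, here is a minimal sketch that computes the crude fraction of observed deaths in each group from the counts above (the column labels "censored" and "died" are our own relabeling of `event`). It deliberately ignores when censoring happened, which is exactly the shortcoming that survival analysis addresses.

```python
import pandas as pd

# Counts copied from the contingency table above; columns relabeled for readability
counts = pd.DataFrame(
    {"censored": [7, 11], "died": [5, 21]},
    index=pd.Index([False, True], name="metastized"),
)

# Crude fraction of patients with an observed death in each group.
# This ignores *when* censoring happened, which survival analysis accounts for.
death_fraction = counts["died"] / counts.sum(axis=1)
print(death_fraction)
```

Roughly 42% of the non-metastized patients and 66% of the metastized patients have an observed death.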
When studying time-to-event data, especially in the presence of censoring (and/or truncation), survival analysis is the appropriate modeling framework. As indicated by the name, survival analysis focuses on estimating the survival function. If $T$ is the time-to-event in question, the survival function is
$$S(t) = \mathbb{P}(T \geq t).$$
This focus on the survival function is important because for censored observations we only know that the time-to-event exceeds the recorded time. The naive approach of treating censored observations as if the event occurred at the time of censoring risks systematically underestimating the true average survival time, as we will illustrate later in this post.
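The underestimation can be previewed with a purely synthetic simulation. The following is a minimal sketch, assuming exponential event times with mean 10 and uniform censoring times on $[0, 20]$ (both distributions are arbitrary choices for illustration, unrelated to the mastectomy data).

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed for reproducibility

n = 100_000
true_mean = 10.0
# Hypothetical event times: exponential with mean 10 (purely synthetic)
event_time = rng.exponential(true_mean, size=n)
# Hypothetical censoring times: uniform on [0, 20]
censor_time = rng.uniform(0, 20, size=n)

# What we would actually record: the earlier of the two times, plus an event flag
observed_time = np.minimum(event_time, censor_time)
observed_event = event_time <= censor_time

# Naive estimate: treat every recorded time as if the event happened then
naive_mean = observed_time.mean()
censored_fraction = 1 - observed_event.mean()
print(f"true mean {true_mean:.2f}, naive mean {naive_mean:.2f}")
print(f"fraction censored {censored_fraction:.2f}")
```

Under these assumptions about 43% of observations are censored and the naive mean comes out near 5.7, well below the true mean of 10.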
Instead of working directly with the survival function, it is convenient to phrase our models in terms of the cumulative hazard function.
$$\Lambda(t) = -\log S(t).$$
From this definition we see that $\Lambda(t) \geq 0$, $\Lambda(0) = 0$, and $\Lambda$ is nondecreasing. Since we are working on a discrete timescale (we only know how many months the patient survived, not exactly when they died or were censored), it is further convenient to decompose this cumulative hazard function into a sum of per-period hazards,
$$\Lambda(t) = \sum_{s \leq t} \lambda(s).$$
In the continuous-time case, this sum is replaced with an appropriate integral.
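To make the relationship between per-period hazards, the cumulative hazard, and the survival function concrete, here is a small sketch with made-up hazard values (purely illustrative): the cumulative hazard is a running sum, and the survival function is its negative exponential.

```python
import numpy as np

# Hypothetical per-period hazards λ(s) for s = 0, ..., 5 (illustrative values only)
lam = np.array([0.05, 0.10, 0.10, 0.20, 0.15, 0.30])

# Λ(t) = Σ_{s ≤ t} λ(s): a running sum of the per-period hazards
cum_hazard = np.cumsum(lam)
# S(t) = exp(-Λ(t)), inverting the definition Λ(t) = -log S(t)
survival = np.exp(-cum_hazard)

for t, (Lam, S) in enumerate(zip(cum_hazard, survival)):
    print(f"t={t}: Λ={Lam:.2f}, S={S:.3f}")
```

Because every $\lambda(s)$ is nonnegative, $\Lambda$ can only grow and $S$ can only shrink over time.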
With the cumulative hazard function decomposed as above, we can introduce the Cox proportional hazards model. Given predictors $\mathbf{x}$, this model treats $\lambda(t)$ as a log-linear model,
$$\log \lambda(t\ |\ \mathbf{x}) = \alpha(t) + \beta \cdot \mathbf{x}.$$
This model is often expressed as
$$\lambda(t\ |\ \mathbf{x}) = \lambda_0(t) \cdot \exp(\beta \cdot \mathbf{x}),$$
where $\lambda_0(t) = \exp(\alpha(t))$. This model has "proportional hazards" because, if $\mathbf{y}$ denotes the predictors for another patient,
$$\frac{\lambda(t\ |\ \mathbf{x})}{\lambda(t\ |\ \mathbf{y})} = \frac{\lambda_0(t) \cdot \exp(\beta \cdot \mathbf{x})}{\lambda_0(t) \cdot \exp(\beta \cdot \mathbf{y})} = \exp(\beta \cdot (\mathbf{x} - \mathbf{y}))$$
is independent of $t$.
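A quick numerical check of the proportional hazards property, as a sketch with hypothetical values for $\alpha(t)$ and $\beta$: however irregular the baseline log-hazard is, the ratio of two patients' hazards is the same at every time.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical baseline log-hazards α(t) and coefficient β (illustrative values only)
alpha = rng.normal(-3.0, 0.5, size=10)
beta = 1.2

x, y = 1.0, 0.0  # predictors for two patients (e.g. metastized vs. not)

hazard_x = np.exp(alpha + beta * x)  # λ(t | x) = λ0(t) exp(β x)
hazard_y = np.exp(alpha + beta * y)  # λ(t | y) = λ0(t) exp(β y)

ratio = hazard_x / hazard_y
print(ratio)                     # the same value at every t
print(np.exp(beta * (x - y)))    # exp(β (x - y))
```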
With the mathematical form of our model specified, we can begin to implement it in Python using PyMC. First we extract the relevant columns from the data frame. In our case,
$$x_i = \begin{cases} 0 & \text{if the }i\text{-th patient's cancer has not metastized} \\ 1 & \text{if the }i\text{-th patient's cancer has metastized} \end{cases}.$$

t = df["time"].values
event = df["event"].values
x = df["metastized"].values
We place a hierarchical normal prior (with a non-centered parameterization) on $\alpha(t)$ and a normal prior on $\beta$.
# the scale necessary to make a halfnormal distribution have unit variance
HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)
def noncentered_normal(name, *, dims, μ=None):
    # non-centered parameterization: μ + Δ σ with Δ ~ N(0, 1) is distributed as N(μ, σ)
    if μ is None:
        μ = pm.Normal(f"μ_{name}", 0, 2.5)
    Δ = pm.Normal(f"Δ_{name}", 0, 1, dims=dims)
    σ = pm.HalfNormal(f"σ_{name}", 2.5 * HALFNORMAL_SCALE)
    return pm.Deterministic(name, μ + Δ * σ, dims=dims)
coords = {"metastized": np.array([False, True]), "time": np.arange(t.max() + 1)}
with pm.Model(coords=coords) as model:
    α = noncentered_normal("α", dims="time")
    β = pm.Normal("β", 0, 2.5)
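The constant `HALFNORMAL_SCALE` and the non-centered construction can both be sanity-checked without PyMC. This is a minimal NumPy Monte Carlo sketch (the values `mu = 0.5` and `sigma = 2.0` are hypothetical): a half-normal with scale $s$ has variance $s^2 (1 - 2 / \pi)$, so this scale yields unit variance, and $\mu + \Delta \sigma$ with $\Delta \sim N(0, 1)$ has the mean and standard deviation of $N(\mu, \sigma)$.

```python
import numpy as np

HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)

# Analytic check: a half-normal with scale s has variance s**2 * (1 - 2 / π)
analytic_var = HALFNORMAL_SCALE**2 * (1 - 2 / np.pi)
print(analytic_var)

rng = np.random.default_rng(1)
n = 200_000

# |Z| * s is half-normal with scale s; check its variance by Monte Carlo
half = np.abs(rng.standard_normal(n)) * HALFNORMAL_SCALE
print(half.var())

# Non-centered draws μ + Δ σ with Δ ~ N(0, 1) behave like draws from N(μ, σ)
mu, sigma = 0.5, 2.0  # hypothetical values for illustration
noncentered = mu + rng.standard_normal(n) * sigma
print(noncentered.mean(), noncentered.std())
```

The non-centered form lets the sampler explore $\Delta$, which has a fixed unit scale, instead of a variable whose scale depends on $\sigma$; this is the usual motivation for the reparameterization in hierarchical models.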
We then define $\lambda(t) = \exp(\alpha(t) + \beta \cdot x).$
with model:
    λ = pt.exp(α[np.newaxis] + β * x[:, np.newaxis])
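The expression for `λ` relies on broadcasting to produce one hazard per patient and period. Here is a NumPy-only sketch of the same shape logic, using small hypothetical stand-ins for `α`, `β`, and `x`.

```python
import numpy as np

n_patients, n_times = 4, 6  # hypothetical sizes for illustration

alpha = np.linspace(-3.0, -2.0, n_times)   # stand-in for α(t), shape (n_times,)
beta = 0.8                                 # stand-in for β
x = np.array([False, True, True, False])   # metastized indicator, shape (n_patients,)

# alpha[np.newaxis] has shape (1, n_times); x[:, np.newaxis] has shape (n_patients, 1)
lam = np.exp(alpha[np.newaxis] + beta * x[:, np.newaxis])
print(lam.shape)  # (4, 6): one row per patient, one column per period
```

Rows for non-metastized patients equal $\exp(\alpha(t))$, and rows for metastized patients are that baseline multiplied by $\exp(\beta)$.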
We could proceed by building
$$S(t\ |\ x) = \exp\left(-\sum_{s \leq t} \lambda(s\ |\ x)\right),$$
and using the fact that
$$\mathbb{P}(\text{died at }t\ |\ x) = S(t\ |\ x) - S(t + 1\ |\ x)$$
to tie the hazard function we have specified to our observations. A simpler approach uses a classic result from the early 1980s (reference, see §7.4.3): this model is equivalent to the following Poisson regression model.
Let
$$d_{i, t} = \begin{cases} 1 & \text{if the }i\text{-th patient died in period }t \\ 0 & \text{otherwise} \end{cases}$$
indicate whether the patient in question died in the $t$-th period and
$$e_{i, t} = \begin{cases} 1 & \text{if the }i\text{-th patient was alive at the beginning of period }t \\ 0 & \text{otherwise} \end{cases}$$
indicate whether the patient was alive (exposed) at the beginning of the $t$-th period. The Cox model is then equivalent to a Poisson model for $d_{i, t}$ with mean $e_{i, t} \cdot \lambda(t\ |\ x_i).$
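The equivalence can be checked numerically for a single patient. This is a sketch under an explicit assumption: we compare against a piecewise-exponential likelihood (constant hazard within each unit-length period), with hypothetical hazard values and helper functions `survival_loglik` and `poisson_loglik` that are our own names. With unit exposure and $d_{i, t} \in \{0, 1\}$, the two log-likelihoods agree exactly, so they lead to the same inference about $\lambda$.

```python
from math import lgamma

import numpy as np

rng = np.random.default_rng(2)

n_times = 8
lam = rng.uniform(0.05, 0.4, size=n_times)  # hypothetical per-period hazards λ(s)


def survival_loglik(t_i, died, lam):
    """Piecewise-exponential log-likelihood: at risk in periods 0..t_i-1, death (if any) in period t_i-1."""
    log_lik = -lam[:t_i].sum()           # log-probability of surviving the exposed periods
    if died:
        log_lik += np.log(lam[t_i - 1])  # hazard of the death in its final period
    return log_lik


def poisson_loglik(t_i, died, lam):
    """Sum of Poisson log-pmfs for d_{i,s} with mean e_{i,s} * λ(s), skipping unexposed periods."""
    exposed = np.arange(len(lam)) < t_i
    d = np.zeros(len(lam))
    if died:
        d[t_i - 1] = 1.0
    mu, d = lam[exposed], d[exposed]
    # log pmf = d log μ - μ - log(d!)
    return np.sum(d * np.log(mu) - mu - np.array([lgamma(k + 1) for k in d]))


cases = [(3, True), (5, False), (8, True)]
for t_i, died in cases:
    print(t_i, died, survival_loglik(t_i, died, lam), poisson_loglik(t_i, died, lam))
```

Unexposed periods contribute nothing to the Poisson likelihood because both $e_{i, t}$ and $d_{i, t}$ are zero there.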
We now construct the arrays measuring exposure and indicating death.
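# exposed[i, s] is True while patient i is still at risk, i.e. for s < t_i
exposed = np.full((df.shape[0], t.max() + 1), True, dtype=np.bool_)
np.put_along_axis(exposed, t[:, np.newaxis], False, axis=1)
exposed = np.minimum.accumulate(exposed, axis=1)
# event_[i, s] is True only if patient i died, in her last exposed period s = t_i - 1
event_ = np.full_like(exposed, False, dtype=np.bool_)
np.put_along_axis(event_, t[:, np.newaxis] - 1, event[:, np.newaxis], axis=1)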
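To see what these arrays look like, here is the same construction wrapped in a hypothetical helper, `build_person_period_arrays` (our own name, not part of any library), and applied to two made-up patients.

```python
import numpy as np


def build_person_period_arrays(t, event):
    """Hypothetical helper mirroring the construction above for arbitrary t and event arrays."""
    n_periods = t.max() + 1
    # True for periods 0, ..., t_i - 1 (patient i is still at risk), False afterwards
    exposed = np.full((t.size, n_periods), True, dtype=np.bool_)
    np.put_along_axis(exposed, t[:, np.newaxis], False, axis=1)
    exposed = np.minimum.accumulate(exposed, axis=1)
    # Death (if observed) is recorded in the last exposed period, t_i - 1
    died = np.full_like(exposed, False)
    np.put_along_axis(died, t[:, np.newaxis] - 1, event[:, np.newaxis], axis=1)
    return exposed, died


t_toy = np.array([2, 4])             # hypothetical times for two patients
event_toy = np.array([True, False])  # first died, second was censored
exposed_toy, died_toy = build_person_period_arrays(t_toy, event_toy)
print(exposed_toy.astype(int))
print(died_toy.astype(int))
```

The first patient is exposed in periods 0 and 1 and has her death recorded in period 1; the censored second patient is exposed in periods 0 through 3 with no death recorded.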
We now add the observed Poisson-distributed event indicator to the model.
with model:
    pm.Poisson("event", exposed * λ, observed=event_)
Before we sample from the model's posterior distribution, we define a series of quantities that will allow us to easily visualize the posterior predictive survival functions for patients whose cancer had or had not metastized. We do not use PyMC's built-in posterior predictive sampling method here due to the way we constructed the observed variable using the equivalent Poisson model. Adding these auxiliary quantities is more straightforward in this case.
with model:
    λ_pred = pt.exp(α[np.newaxis] + β * np.array([[0, 1]]).T)
    Λ_pred = λ_pred.cumsum(axis=1)
    sf_pred = pm.Deterministic("sf_pred", pt.exp(-Λ_pred), dims=("metastized", "time"))
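For intuition about what `sf_pred` holds, here is a NumPy sketch of the same computation applied draw by draw. The "posterior draws" below are synthetic stand-ins generated at random, not output of the model above. Each draw yields one survival curve per metastization state, and we summarize with the mean and a 95% equal-tailed percentile interval, which is the kind of summary plotted below.

```python
import numpy as np

rng = np.random.default_rng(3)
n_draws, n_times = 1_000, 12  # hypothetical sizes

# Synthetic stand-ins for posterior draws of α(t) and β (not real model output)
alpha_draws = rng.normal(-3.0, 0.2, size=(n_draws, n_times))
beta_draws = rng.normal(1.0, 0.3, size=n_draws)

x_pred = np.array([0, 1])  # not metastized, metastized

# log-hazard per (draw, metastized, time), mirroring λ_pred in the model for each draw
log_lam = (alpha_draws[:, np.newaxis, :]
           + beta_draws[:, np.newaxis, np.newaxis] * x_pred[np.newaxis, :, np.newaxis])
sf_pred = np.exp(-np.exp(log_lam).cumsum(axis=-1))

ALPHA = 0.05
sf_mean = sf_pred.mean(axis=0)
sf_low, sf_high = np.percentile(sf_pred, [100 * ALPHA / 2, 100 * (1 - ALPHA / 2)], axis=0)
print(sf_mean.shape)  # (2, 12)
```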
Now we are ready to use nutpie to sample from the model's posterior distribution.
SEED = 123456789
trace = nutpie.sample(
nutpie.compile_pymc_model(model),
seed=SEED
)
Sampling produced only a few divergences, so we next check the Gelman-Rubin $\hat{R}$ statistics for any other obvious issues with sampling.
az.rhat(trace).max()
<xarray.Dataset>
Dimensions:    ()
Data variables:
    μ_α        float64 1.003
    Δ_α        float64 1.005
    σ_α_log__  float64 1.001
    β          float64 1.002
    σ_α        float64 1.001
    α          float64 1.003
    sf_pred    float64 1.003
All of the $\hat{R}$ statistics are below 1.01, so we see no obvious issues with sampling.
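For readers curious about what $\hat{R}$ measures, here is a sketch of the classic split-$\hat{R}$ (in the style of Gelman et al.'s *Bayesian Data Analysis*) for a single scalar quantity. Note that ArviZ's default `az.rhat` uses a rank-normalized variant, so its values can differ slightly; the chains below are synthetic.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for an array of shape (n_chains, n_draws).

    ArviZ's default az.rhat uses a rank-normalized variant, so values may differ slightly.
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split each chain in half so that drift within a chain also inflates R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()   # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)       # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(4)
good = rng.standard_normal((4, 1_000))               # well-mixed synthetic chains
bad = good + np.array([[0.0], [0.0], [0.0], [2.0]])  # one chain stuck elsewhere
print(split_rhat(good), split_rhat(bad))
```

Well-mixed chains give a value very close to one, while a single chain exploring a different region pushes $\hat{R}$ far above the 1.01 threshold used above.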
We now use seaborn to visualize the posterior predictive survival functions for patients whose cancer had or had not metastized.
ALPHA = 0.05
so_ci = so.Perc([100 * ALPHA / 2, 100 * (1 - ALPHA / 2)])
PP_SF_LABEL = "Posterior predictive\nsurvival function"
(so.Plot(trace.posterior["sf_pred"].to_dataframe(),
x="time", y="sf_pred", color="metastized")
.add(so.Line(), so.Agg())
.add(so.Band(), so_ci)
.scale(color=so.Nominal())
.limit(x=(0, t.max()), y=(0, 1))
.label(x=MONTHS_LABEL, y=PP_SF_LABEL, color=str.capitalize))
These posterior predictive survival functions look plausible; to check them, we compare them to the predictions of a Cox model fit using the lifelines package.
ll_model = ll.CoxPHFitter().fit(df, duration_col="time", event_col="event")
pred_df = pd.merge(
trace.posterior["sf_pred"].to_dataframe(),
ll_model.predict_survival_function(np.array([[False, True]]).T, times=np.arange(t.max() + 1))
.rename(columns=bool)
.rename_axis("time", axis=0)
.rename_axis("metastized", axis=1)
.stack()
.rename("ll_pred"),
left_index=True, right_index=True
)
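The chain of `rename`, `rename_axis`, and `stack` calls above converts a wide table of survival probabilities (rows are times, columns are covariate rows 0 and 1) into a long Series indexed by `(time, metastized)`. Here is a sketch of that reshaping on a tiny hypothetical wide frame, assuming it has the same layout as the lifelines output.

```python
import pandas as pd

# Hypothetical wide frame: rows are times, columns are covariate rows 0 and 1
wide = pd.DataFrame(
    {0: [1.0, 0.95, 0.90], 1: [1.0, 0.80, 0.60]},
    index=[0, 1, 2],
)

long = (
    wide.rename(columns=bool)                # 0 -> False, 1 -> True
        .rename_axis("time", axis=0)
        .rename_axis("metastized", axis=1)
        .stack()                             # Series indexed by (time, metastized)
        .rename("ll_pred")
)
print(long)
```

Naming both index levels is what allows the result to be aligned with the `metastized` and `time` levels of the posterior data frame.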
The following plot adds the predictions from lifelines (dashed lines) to the posterior predictions from our model.
(so.Plot(pred_df,
x="time", y="sf_pred", color="metastized")
.add(so.Line(), so.Agg())
.add(so.Band(), so_ci)
.add(so.Line(linestyle="--"), so.Agg(func=lambda x: x.iloc[0]),
y="ll_pred")
.scale(color=so.Nominal())
.limit(x=(0, t.max()), y=(0, 1))
.label(x=MONTHS_LABEL, y=PP_SF_LABEL, color=str.capitalize))