from __future__ import print_function, division import nsfg import survival import thinkstats2 import thinkplot import pandas import numpy from lifelines import KaplanMeierFitter from collections import defaultdict import matplotlib.pyplot as pyplot %matplotlib inline preg = nsfg.ReadFemPreg() complete = preg.query('outcome in [1, 3, 4]').prglngth cdf = thinkstats2.Cdf(complete, label='cdf') sf = survival.SurvivalFunction(cdf, label='survival') thinkplot.Plot(sf) thinkplot.Config(xlabel='duration (weeks)', ylabel='survival function') #thinkplot.Save(root='survival_talk1', formats=['png']) hf = sf.MakeHazard(label='hazard') thinkplot.Plot(hf) thinkplot.Config(xlabel='duration (weeks)', ylabel='hazard function', ylim=[0, 0.75], loc='upper left') #thinkplot.Save(root='survival_talk2', formats=['png']) rem_life = sf.RemainingLifetime() thinkplot.Plot(rem_life) thinkplot.Config(xlabel='weeks', ylabel='mean remaining weeks', legend=False) #thinkplot.Save(root='survival_talk3', formats=['png']) resp = survival.ReadFemResp2002() len(resp) complete = resp[resp.evrmarry == 1].agemarry ongoing = resp[resp.evrmarry == 0].age nan = complete[numpy.isnan(complete)] len(nan) hf = survival.EstimateHazardFunction(complete, ongoing) sf = hf.MakeSurvival() thinkplot.Plot(hf) thinkplot.Config(xlabel='age (years)', ylabel='hazard function', legend=False) #thinkplot.Save(root='survival_talk4', formats=['png']) thinkplot.Plot(sf) thinkplot.Config(xlabel='age (years)', ylabel='prob unmarried', ylim=[0, 1], legend=False) #thinkplot.Save(root='survival_talk5', formats=['png']) ss = sf.ss end_ss = ss[-1] prob_marry44 = (ss - end_ss) / ss thinkplot.Plot(sf.ts, prob_marry44) thinkplot.Config(xlabel='age (years)', ylabel='prob marry before 44', ylim=[0, 1], legend=False) #thinkplot.Save(root='survival_talk6', formats=['png']) func = lambda pmf: pmf.Percentile(50) rem_life = sf.RemainingLifetime(filler=numpy.inf, func=func) thinkplot.Plot(rem_life) thinkplot.Config(ylim=[0, 15], xlim=[11, 31], xlabel='age (years)', ylabel='median remaining years') #thinkplot.Save(root='survival_talk7', formats=['png']) resp['event_times'] = resp.age resp['event_times'][resp.evrmarry == 1] = resp.agemarry len(resp) cleaned = resp.dropna(subset=['event_times']) len(cleaned) kmf = KaplanMeierFitter() kmf.fit(cleaned.event_times, cleaned.evrmarry) thinkplot.Plot(sf) thinkplot.Config(xlim=[0, 45], legend=False) pyplot.grid() kmf.survival_function_.plot() complete = [1, 2, 3] ongoing = [2.5, 3.5] hf = survival.EstimateHazardFunction(complete, ongoing) hf.series sf = hf.MakeSurvival() sf.ts, sf.ss T = pandas.Series(complete + ongoing) E = [1, 1, 1, 0, 0] kmf = KaplanMeierFitter() kmf.fit(T, E) kmf.survival_function_ resp5 = survival.ReadFemResp1995() resp6 = survival.ReadFemResp2002() resp7 = survival.ReadFemResp2010() resp8 = survival.ReadFemResp2013() def EstimateSurvival(resp): """Estimates the survival curve. resp: DataFrame of respondents returns: pair of HazardFunction, SurvivalFunction """ complete = resp[resp.evrmarry == 1].agemarry ongoing = resp[resp.evrmarry == 0].age hf = survival.EstimateHazardFunction(complete, ongoing) sf = hf.MakeSurvival() return hf, sf def ResampleSurvivalByDecade(resps, iters=101, predict_flag=False, omit=[]): """Makes survival curves for resampled data. resps: list of DataFrames iters: number of resamples to plot predict_flag: whether to also plot predictions returns: map from group name to list of survival functions """ sf_map = defaultdict(list) # iters is the number of resampling runs to make for i in range(iters): # we have to resample the data from each cycles separately samples = [thinkstats2.ResampleRowsWeighted(resp) for resp in resps] # then join the cycles into one big sample sample = pandas.concat(samples, ignore_index=True) for decade in omit: sample = sample[sample.decade != decade] # group by decade grouped = sample.groupby('decade') # and estimate (hf, sf) for each group hf_map = grouped.apply(lambda group: EstimateSurvival(group)) if predict_flag: MakePredictionsByDecade(hf_map) # extract the sf from each pair and acculumulate the results for name, (hf, sf) in hf_map.iteritems(): sf_map[name].append(sf) return sf_map def MakePredictionsByDecade(hf_map, **options): """Extends a set of hazard functions and recomputes survival functions. For each group in hf_map, we extend hf and recompute sf. hf_map: map from group name to (HazardFunction, SurvivalFunction) """ # TODO: this only works if the names and values are in increasing order, # which is true when hf_map is a GroupBy object, but not generally # true for maps. names = hf_map.index.values hfs = [hf for (hf, sf) in hf_map.values] # extend each hazard function using data from the previous cohort, # and update the survival function for i, hf in enumerate(hfs): if i > 0: hf.Extend(hfs[i-1]) sf = hf.MakeSurvival() hf_map[names[i]] = hf, sf def MakeSurvivalCI(sf_seq, percents): # find the union of all ts where the sfs are evaluated ts = set() for sf in sf_seq: ts |= set(sf.ts) ts = list(ts) ts.sort() # evaluate each sf at all times ss_seq = [sf.Probs(ts) for sf in sf_seq] # return the requested percentiles from each column rows = thinkstats2.PercentileRows(ss_seq, percents) return ts, rows resps = [resp5, resp6, resp7, resp8] sf_map = ResampleSurvivalByDecade(resps) resps = [resp5, resp6, resp7, resp8] sf_map_pred = ResampleSurvivalByDecade(resps, predict_flag=True) def PlotSurvivalFunctionByDecade(sf_map, predict_flag=False): thinkplot.PrePlot(len(sf_map)) for name, sf_seq in sorted(sf_map.iteritems(), reverse=True): ts, rows = MakeSurvivalCI(sf_seq, [10, 50, 90]) thinkplot.FillBetween(ts, rows[0], rows[2], color='gray') if predict_flag: thinkplot.Plot(ts, rows[1], color='gray') else: thinkplot.Plot(ts, rows[1], label='%d0s'%name) thinkplot.Config(xlabel='age(years)', ylabel='prob unmarried', xlim=[15, 45], ylim=[0, 1], legend=True, loc='upper right') PlotSurvivalFunctionByDecade(sf_map) #thinkplot.Save(root='survival_talk8', formats=['png']) PlotSurvivalFunctionByDecade(sf_map_pred, predict_flag=True) PlotSurvivalFunctionByDecade(sf_map) #thinkplot.Save(root='survival_talk9', formats=['png'])