#!/usr/bin/env python # coding: utf-8 # # Suppose 2 R0 # In[ ]: from pyro.infer.mcmc import MCMC, NUTS import matplotlib.pyplot as plt import numpy as np import torch import pyro import pyro.infer import pyro.optim import pyro.distributions as dist from torch.distributions import constraints import seaborn # In[197]: # Defining parameters as stochastic function # deaths in hospitals / total deaths def frac_dh(): return torch.tensor(3470. / 7594.) # fraction of hospitalized def hh(): return torch.tensor(0.05) # inverse recovery time def gamma(): return torch.tensor(1. / 12.4) # inverse incubation time def epsilon(): return torch.tensor(1. / 5.2) # fatality rate in icu def dea(): return torch.tensor(.5) # population size def n0(): return torch.tensor(11000000.) # population en MR/MRS + personnel soignant def n0_MRS(): return torch.tensor(400000.) # e0 = i0 * factor def e0_factor(): return torch.tensor(37.) # e0_MRS = i0_MRS * factor def e0_MRS_factor(): return torch.tensor(20.) # size of the window for fitting Re's def window(): return torch.tensor(6.) def i0(): #return pyro.sample("i0", dist.Poisson(10)) return torch.tensor(3.) def r0_model(r0_mean): return r0 def drea(): return dea()/ 5. def rrea(): return (1 - dea())/20. def hospi(): return torch.tensor(0.0) def gg(): return torch.tensor(.75) # In[198]: def SEIR(r0): # Initial conditions n = [n0()] # Population totale i = [i0()] # Symptomatiques e = [i[0] * e0_factor()] # Asymptomatiques h = [torch.tensor(.0)] # Lits occupés hopital l = [torch.tensor(.0)] # Lits occupés USI r = [torch.tensor(.0)] # Immunisés m = [torch.tensor(.0)] # Morts (totaux) s = [n[-1] - e[-1] - i[-1] - r[-1]] # Sains # Simulate forward n_days = len(r0) hospi = 0. for day in range(n_days): lam = gamma() * r0[day] if day == 14: hospi = hh() / 7 ds = -lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1] de = lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1] - epsilon() * e[-1] di = epsilon() * e[-1] - gamma() * i[-1] - hospi * i[-1] dh = hospi * i[-1] - gg() * h[-1] / 7 - (1 - gg()) * h[-1] / (4. + 2 * torch.tanh((l[-1]-500.)/300.)) dl = (1 - gg()) * h[-1] / (4 + 2 * torch.tanh((l[-1]-500)/300)) - drea() * l[-1] - rrea() * l[-1] dr = gamma() * i[-1] + rrea() * l[-1] + gg() * h[-1] / 7 dm = drea() * l[-1] s.append(s[-1] + ds) e.append(e[-1] + de) i.append(i[-1] + di) h.append(h[-1] + dh) l.append(l[-1] + dl) if l[-1] > 1895: dm = dm + (l[-1] - 1895) l[-1] = torch.tensor(1895.) r.append(r[-1] + dr) m.append(m[-1] + dm) n.append(s[-1] + e[-1] + i[-1] + h[-1] + l[-1] + r[-1]) return s, e, i, h, l, m, r def SEIR_MRS(r0_mrs, n_futures=0, window=6): # Smoothen and extend R0s # Initial conditions alpha = 0.15 / 10 lam = gamma() * 4.3 n = [n0_MRS()] i = [torch.tensor(1)] e = [i[-1] * e0_MRS_factor()] r = [torch.tensor(0.0)] s = [n[-1] - e[-1] - i[-1] - r[-1]] m = [torch.tensor(0.0)] # Simulate forward n_days = len(r0_mrs) for day in range(n_days): lam = gamma() * r0_mrs[day] ds = -lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1] de = lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1] - epsilon() * e[-1] di = epsilon() * e[-1] - (gamma() + alpha) * i[-1] dr = gamma() * i[-1] dm = alpha * i[-1] s.append(s[-1] + ds) e.append(e[-1] + de) i.append(i[-1] + di) r.append(r[-1] + dr) m.append(m[-1] + dm) n.append(s[-1] + e[-1] + i[-1] + r[-1]) return s, e, i, m, r # In[ ]: from covid19be import load_data data_df = load_data() n_days = len(data_df) # (source: [wikipedia covid-19 Belgique](https://fr.wikipedia.org/wiki/Pand%C3%A9mie_de_Covid-19_en_Belgique)) # - 11 Mars arrêt des visites en mrs. (en Wallonie uniquement) (13) # - 16 Mars plus de cours. (18) # - 18 Mars confinement. (20) # - 20 avril magasin de bricolage et pépinieriste ouvrent à nouveau. (22) # # Nos données débutent au 28 février # In[200]: date_r0_switch = 20 date_r0_switch_mrs = 13 nb_dirty_data = 4 n_useful_days = n_days - nb_dirty_data # In[205]: mrs = True noiser = lambda mu: dist.ZeroInflatedPoisson(torch.tensor(.0001), mu + 1) # In[216]: def SEIR_full_model_bis(ret_data=False): n_futures=0 # Simulate r0_1 = pyro.sample("r0_1", dist.Uniform(torch.tensor(0.), torch.tensor(6.))) r0_2 = pyro.sample("r0_2", dist.Uniform(torch.tensor(0.), torch.tensor(3.))) # R0 fluctuates around its means that are given by r0_1 and r0_2 r0 = torch.cat((dist.Uniform(torch.ones(date_r0_switch - 1) * r0_1 - .2, torch.ones(date_r0_switch - 1) * r0_1 + .2)(), dist.Uniform(torch.ones(n_useful_days - date_r0_switch + 1) * r0_2 - .2, torch.ones(n_useful_days - date_r0_switch + 1) * r0_2 + .2)())) s_T, e_T, i_T, h_T, l_T, m_T, r_T = SEIR(r0) #print(r0) if mrs: r0_mrs_1 = pyro.sample("r0_mrs_1", dist.Uniform(.0, 6.)) r0_mrs_2 = pyro.sample("r0_mrs_2", dist.Uniform(.0, 3.)) r0_mrs = torch.cat((dist.Uniform(torch.ones(date_r0_switch_mrs - 1) * r0_mrs_1 - .2, torch.ones(date_r0_switch_mrs - 1) * r0_mrs_1 + .2)(), dist.Uniform(torch.ones(n_useful_days - date_r0_switch_mrs + 1) * r0_mrs_2 - .2, torch.ones(n_useful_days - date_r0_switch_mrs + 1) * r0_mrs_2 + .2)())) _, _, _, m_mrs_T, _ = SEIR_MRS(r0_mrs) for idx in range(n_useful_days): if idx > 16: if h_T[idx] < 0: print(h_T[idx]) pyro.sample("h_%d" % idx, noiser(h_T[idx])) pyro.sample("l_%d" % idx, noiser(l_T[idx])) pyro.sample("m_%d" % idx, noiser(m_T[idx])) if mrs: pyro.sample("m_mrs_%d" % idx, noiser(m_mrs_T[idx])) if mrs and ret_data: return s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0, r0_mrs if ret_data: return s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0 # In[241]: data = {} hospi = data_df['n_hospitalized'] l = data_df['n_icu'] m = data_df['n_deaths'] for idx in range(n_useful_days): if idx > 16: data["h_%d" % idx] = torch.tensor(hospi[idx] - l[idx], dtype=torch.float) data["l_%d" % idx] = torch.tensor(l[idx], dtype=torch.float) data["m_%d" % idx] = m[idx] * frac_dh() if mrs: data["m_mrs_%d" % idx] = m[idx]*(1-frac_dh()) # In[218]: def r0_guide(ret_data=False): a_r0_1 = pyro.param("a_r0_1", torch.tensor(2.), constraint=constraints.positive) a_r0_2 = pyro.param("a_r0_2", torch.tensor(.25), constraint=constraints.positive) b_r0_1 = pyro.param("b_r0_1", torch.tensor(4.), constraint=constraints.positive) b_r0_2 = pyro.param("b_r0_2", torch.tensor(1.25), constraint=constraints.positive) r0_1 = pyro.sample("r0_1", dist.Uniform(a_r0_1, a_r0_1 + b_r0_1)) r0_2 = pyro.sample("r0_2", dist.Uniform(a_r0_2, a_r0_2 + b_r0_2)) if mrs: a_r0_mrs_1 = pyro.param("a_r0_mrs_1", torch.tensor(2.), constraint=constraints.positive) a_r0_mrs_2 = pyro.param("a_r0_mrs_2", torch.tensor(.25), constraint=constraints.positive) b_r0_mrs_1 = pyro.param("b_r0_mrs_1", torch.tensor(4.), constraint=constraints.positive) b_r0_mrs_2 = pyro.param("b_r0_mrs_2", torch.tensor(1.25), constraint=constraints.positive) r0_mrs_1 = pyro.sample("r0_mrs_1", dist.Uniform(a_r0_mrs_1, a_r0_mrs_1 + b_r0_mrs_1)) r0_mrs_2 = pyro.sample("r0_mrs_2", dist.Uniform(a_r0_mrs_2, a_r0_mrs_2 + b_r0_mrs_2)) if ret_data: # R0 fluctuates around its means that are given by r0_1 and r0_2 r0 = torch.cat((dist.Uniform(torch.ones(date_r0_switch - 1) * r0_1 - .2, torch.ones(date_r0_switch - 1) * r0_1 + .2)(), dist.Uniform(torch.ones(n_useful_days - date_r0_switch + 1) * r0_2 - .2, torch.ones(n_useful_days - date_r0_switch + 1) * r0_2 + .2)())) s_T, e_T, i_T, h_T, l_T, m_T, r_T = SEIR(r0) #print(r0) if mrs: r0_mrs = torch.cat((dist.Uniform(torch.ones(date_r0_switch_mrs - 1) * r0_mrs_1 - .2, torch.ones(date_r0_switch_mrs - 1) * r0_mrs_1 + .2)(), dist.Uniform(torch.ones(n_useful_days - date_r0_switch_mrs + 1) * r0_mrs_2 - .2, torch.ones(n_useful_days - date_r0_switch_mrs + 1) * r0_mrs_2 + .2)())) _, _, _, m_mrs_T, _ = SEIR_MRS(r0_mrs) for idx in range(n_useful_days): if idx > 16: pyro.sample("h_%d" % idx, noiser(h_T[idx])) pyro.sample("l_%d" % idx, noiser(l_T[idx])) pyro.sample("m_%d" % idx, noiser(m_T[idx])) if mrs: pyro.sample("m_mrs_%d" % idx, noiser(m_mrs_T[idx])) if mrs and ret_data: return s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0, r0_mrs if ret_data: return s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0 # In[219]: pyro.clear_param_store() # setup the optimizer adam_params = {"lr": 0.001, "betas": (0.90, 0.999)} optimizer = pyro.optim.Adam(adam_params) svi = pyro.infer.SVI(model=pyro.condition(SEIR_full_model_bis, data=data), guide=r0_guide, optim=optimizer, loss=pyro.infer.JitTrace_ELBO()) # In[226]: get_ipython().run_cell_magic('time', '', 'losses, a,b = [], [], []\nnum_steps = 20000\nfor t in range(num_steps):\n #print(svi.step(guess_R0, guess_R0_RMS))\n #print(\'r0 = \',pyro.param("a_r0"), pyro.param("b_r0"))\n losses.append(svi.step())\n if t % 100 == 0:\n print(\'r0 before lock-down \',pyro.param("a_r0_1"), pyro.param("a_r0_1") + pyro.param("b_r0_1"))\n print(\'r0 after lock-down \', pyro.param("a_r0_2"), pyro.param("a_r0_2") + pyro.param("b_r0_2"))\n print(\'r0 mrs before lock-down \',pyro.param("a_r0_mrs_1"), pyro.param("a_r0_mrs_1") + pyro.param("b_r0_mrs_1"))\n print(\'r0 mrs after lock-down \', pyro.param("a_r0_mrs_2"), pyro.param("a_r0_mrs_2") + pyro.param("b_r0_mrs_2"))\n print(losses[-1])\n') # In[230]: n_points = 100 for i in range(n_points): if mrs: s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0_T, r0_mrs_T = svi.guide(True) else: s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0_T = svi.guide(True) hospi_T = np.array(list(map(lambda x: (x[0] + x[1]).detach().numpy(), zip(h_T, l_T)))).reshape(-1, 1) l_T = np.array(list(map(lambda x: x.detach().numpy(), l_T))).reshape(-1, 1) m_T = np.array(list(map(lambda x: x.detach().numpy(), m_T))).reshape(-1, 1) m_mrs_T = np.array(list(map(lambda x: x.detach().numpy(), m_mrs_T))).reshape(-1, 1) r0_T = np.array(list(map(lambda x: x.detach().numpy(), r0_T))).reshape(-1, 1) r0_mrs_T = np.array(list(map(lambda x: x.detach().numpy(), r0_mrs_T))).reshape(-1, 1) if i == 0: h = hospi_T l = l_T m = m_T m_mrs = m_mrs_T r0, r0_mrs = r0_T, r0_mrs_T else: h = np.append(h, hospi_T, axis=1) l = np.append(l, l_T, axis=1) m = np.append(m, m_T, axis=1) m_mrs = np.append(m_mrs, m_mrs_T, axis=1) r0_mrs = np.append(r0_mrs, r0_mrs_T, axis=1) r0 = np.append(r0, r0_T, axis=1) # In[239]: def get_percentiles(arr, color=None,ax=None, label=None): arr_50 = np.percentile(arr, 50, axis=1) arr_5 = np.percentile(arr, 5, axis=1) arr_95 = np.percentile(arr, 95, axis=1) if color is not None: ax.plot(data_df['date'][:len(arr)], arr_50, c=color, label=label) ax.fill_between(data_df['date'][:len(arr)], arr_5, arr_95, color=color, alpha=0.2) return arr_50, arr_5, arr_95 seaborn.set() fig, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True, gridspec_kw={"height_ratios": (4, 1)}) ax[0].plot(data_df['date'], data_df['n_hospitalized'], 'x', color='b', label='Hospitalized') ax[0].plot(data_df['date'], data_df['n_icu'], 'x', color='g', label='ICUs') ax[0].plot(data_df['date'], data_df['n_deaths'] * frac_dh().item(), 'x', color='r', label='Deaths in hospital') ax[0].plot(data_df['date'], data_df['n_deaths']*(1-frac_dh().item()), 'x', color='orange', label='Deaths in MRS') get_percentiles(h, 'b', ax[0]) get_percentiles(l, 'g', ax[0]) get_percentiles(m, 'r', ax[0]) get_percentiles(m_mrs, 'orange', ax[0]) get_percentiles(r0_mrs, 'yellow', ax[1], "R0 in mrs") get_percentiles(r0, 'brown', ax[1], 'R0') ax[0].axvline(data_df['date'][date_r0_switch], 0, 8500, label='Lockdown', c='black', alpha=.8, linestyle ='--') ax[0].axvline(data_df['date'][date_r0_switch_mrs], 0, 8500, label='MRS forbidden', c='black', alpha=.8, linestyle=':') ax[1].axvline(data_df['date'][date_r0_switch], 0, 8500, label='Lockdown', c='black', alpha=.8, linestyle ='--') ax[1].axvline(data_df['date'][date_r0_switch_mrs], 0, 8500, label='MRS forbidden', c='black', alpha=.8, linestyle=':') ax[0].legend() ax[1].legend() plt.savefig('pyro_SEIR.png') plt.show()