!git clone https://github.com/pyro-ppl/pyro.git %cd /content/pyro !pip install .[extras] import os import logging import urllib.request from collections import OrderedDict import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import torch import pyro import pyro.distributions as dist from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist from pyro.ops.tensor_utils import convolve %matplotlib inline pyro.enable_validation(True) torch.set_default_dtype(torch.double) class CovidModel(CompartmentalModel): def __init__(self, population, new_cases, new_recovered, new_deaths): ''' population (int) – Total population = S + E + I + R. ''' assert len(new_cases) == len(new_recovered) == len(new_deaths) compartments = ("S", "E", "I", "D") # R is implicit. duration = len(new_cases) super().__init__(compartments, duration, population) self.new_cases = new_cases self.new_deaths = new_deaths self.new_recovered = new_recovered def global_model(self): tau_i = pyro.sample("rec_time", dist.Normal(15.0, 3.0)) tau_e = pyro.sample("incub_time", dist.Normal(5.0, 1.0)) # R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) R0 = pyro.sample("R0", dist.Normal(2.5, 0.5)) rho = pyro.sample("rho", dist.Beta(10, 10)) # About 50% response rate. mort_rate = pyro.sample("mort_rate", dist.Beta(2, 50)) # About 2% mortality rate. rec_rate = pyro.sample("rec_rate",dist.Beta(10, 10)) # About 50% recovery rate. return R0, tau_e, tau_i, rho, mort_rate, rec_rate def initialize(self, params): # Start with a single infection. return {"S": self.population - 1, "E": 0, "I": 1, "D": 0} def transition(self, params, state, t): R0, tau_e, tau_i, rho, mort_rate, rec_rate = params # Sample flows between compartments. S2E = pyro.sample("S2E_{}".format(t), infection_dist(individual_rate=R0 / tau_i, num_susceptible=state["S"], num_infectious=state["I"], population=self.population)) E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e )) I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i)) I2D = pyro.sample("I2D_{}".format(t), binomial_dist(state["I"], mort_rate / tau_i)) # Update compartments with flows. state["S"] = state["S"] - S2E state["E"] = state["E"] + S2E - E2I state["I"] = state["I"] + E2I - I2R - I2D state["D"] = state["D"] + I2D # Condition on observations. t_is_observed = isinstance(t, slice) or t < self.duration pyro.sample("new_cases_{}".format(t), binomial_dist(S2E, rho), obs=self.new_cases[t] if t_is_observed else None) pyro.sample("new_deaths_{}".format(t), binomial_dist(I2D, 1), obs=self.new_deaths[t] if t_is_observed else None) pyro.sample("new_recovered_{}".format(t), binomial_dist(I2R, rho), obs=self.new_recovered[t] if t_is_observed else None) def compute_flows(self, prev, curr, t): S2E = prev["S"] - curr["S"] # S can only go to E. I2D = curr["D"] - prev["D"] # D can only have come from I. # We deduce the remaining flows by conservation of mass: # curr - prev = inflows - outflows E2I = prev["E"] - curr["E"] + S2E I2R = prev["I"] - curr["I"] + E2I - I2D return { "S2E_{}".format(t): S2E, "E2I_{}".format(t): E2I, "I2D_{}".format(t): I2D, "I2R_{}".format(t): I2R, } # function to make the time series of confirmed and daily confirmed cases for a specific country def create_country (country, start_date, end_date, state = False) : url = 'https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/full_grouped.csv' data = pd.read_csv(url) data.Date = pd.to_datetime(data.Date) if state : df = data.loc[data["Province/State"] == country, ["Province/State", "Date", "Confirmed", "Deaths", "Recovered", "Active", "New cases", "New deaths", "New recovered"]] else : df = data.loc[data["Country/Region"] == country, ["Country/Region", "Date", "Confirmed", "Deaths", "Recovered", "Active", "New cases", "New deaths", "New recovered"]] df.columns = ["country", "date", "confirmed", "deaths", "recovered", "active", "new_cases", "new_deaths", "new_recovered"] # group by country and date df = df.groupby(['country','date'])['confirmed', 'deaths', 'recovered',"active", "new_cases", "new_deaths", "new_recovered"].sum().reset_index() # convert date string to datetime df.date = pd.to_datetime(df.date) df = df.sort_values(by = "date") df = df[df.date >= start_date] df = df[df.date <= end_date] active = df['active'].tolist() recovered = df['recovered'].tolist() deaths = df['deaths'].tolist() new_cases = df['new_cases'].tolist() new_recovered = df['new_recovered'].tolist() new_deaths = df['new_deaths'].tolist() active = torch.tensor(list(map(float, active))).view(len(active),1) recovered = torch.tensor(list(map(float, recovered))).view(len(recovered),1) deaths = torch.tensor(list(map(float, deaths))).view(len(deaths),1) new_cases = torch.tensor(list(map(float, new_cases))).view(len(new_cases),1) new_recovered = torch.tensor(list(map(float, new_recovered))).view(len(new_recovered),1) new_deaths = torch.tensor(list(map(float, new_deaths))).view(len(new_deaths),1) return_data = { 'active':active, 'recovered':recovered, 'deaths':deaths, 'new_cases':new_cases, 'new_recovered': new_recovered, 'new_deaths':new_deaths } return return_data # Parameters country = "Japan" start_date = "2020-02-01" end_date = "2020-04-01" population = 126500000 data = create_country(country, start_date, end_date) %%time model = CovidModel(population, data["new_cases"], data["new_recovered"], data["new_deaths"] ) mcmc = model.fit_mcmc(num_samples=500, warmup_steps = 200) mcmc.summary()