### Training RNN using World's data¶

This is based on COVID-19 growth prediction using multivariate long short term memory by Novanto Yudistira

https://arxiv.org/pdf/2005.04809.pdf

https://github.com/VICS-CORE/lstmcorona/blob/master/lstm.py

• We've aligned all countries' inputs rather than taking an absolute timeline. We start when cumulative number of confirmed cases in the country has crossed 100.
• We've normalised data by dividing by a population factor. That way the network can learn a general understanding of the pattern irrespective of the country's population.
• Rather than using the entire timeline as an input as suggested by NYudistira, we're training a fixed window (e.g. 20 days) so that the model learns to predict the future by looking at present data. The problem with fixed window approach is that some countries have peaked, while others have not. Also few countries start early, and some start late.
• The paper uses a multivariate network with confirmed, recovered and deceased data. However this'd increase computation time and hence we're restricting ourselves to a univariate model with confirmed cases as the only parameter.

#### Other ideas¶

• One idea is to train the current net with only the most populous countries' data, since their behaviour would be similar to India's.
• Adding metrics like humidity, population density, lockdown intensity etc might be beneficial and should have some correlation with the growth in cases. But this'd need more computation power.
• Another idea is to train a neuralnet to predict SIR like buckets.
In [ ]:
import pandas as pd
import numpy as np
import requests as rq
import datetime as dt
import torch
import json

tnn = torch.nn
top = torch.optim
from torch.utils import data as tdt

from matplotlib.ticker import MultipleLocator
from matplotlib.dates import DayLocator, AutoDateLocator, ConciseDateFormatter
%matplotlib inline

In [ ]:
# Select the torch device, preferring the first CUDA GPU when available,
# and print GPU diagnostics so the runtime environment is visible in the log.
CUDA = "cuda:0"
CPU = "cpu"

use_gpu = torch.cuda.is_available()
device = torch.device(CUDA if use_gpu else CPU)
if use_gpu:
    cd = torch.cuda.current_device()
    print("Num devices:", torch.cuda.device_count())
    print("Current device:", cd)
    print("Device name:", torch.cuda.get_device_name(cd))
    print("Device props:", torch.cuda.get_device_properties(cd))
    print(torch.cuda.memory_summary(cd))
else:
    print(device)

In [ ]:
# define paths
# DATA_DIR: downloaded CSVs and the cached tensor dataset (ds.pt)
# MODELS_DIR: training checkpoints (*.pt) written by save_checkpoint
DATA_DIR = 'data'
MODELS_DIR = 'models'


### Colab only¶

In [ ]:
# Google Colab setup: mount Drive so data/ and models/ persist across sessions.
from google.colab import drive
drive.mount('/content/drive')

In [ ]:
# Switch into the project folder on Drive (IPython magic).
%cd 'drive/My Drive/CS/colab/'

In [ ]:
# Inspect the Colab VM's CPU (shell escape).
!cat /proc/cpuinfo

In [ ]:
# Inspect available RAM (shell escape).
!cat /proc/meminfo


In [ ]:
# Download the latest OWID COVID-19 dataset into data/.
!curl https://covid.ourworldindata.org/data/owid-covid-data.csv --output data/owid-covid-data.csv

In [ ]:
# Show the CSV header row to confirm column names.
!head -n1 data/owid-covid-data.csv

In [ ]:
# Load the OWID dataset, keeping only the columns we need and parsing dates.
# The opening `df = pd.read_csv(` line was lost in this cell, leaving dangling
# `usecols=cols,` / `parse_dates=dates)` fragments; reconstructed here.
cols = ['location', 'date', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths', 'population']
dates = ['date']
df = pd.read_csv(DATA_DIR + '/owid-covid-data.csv',
                 usecols=cols,
                 parse_dates=dates)
df.sample()


### LSTM¶

In [ ]:
class YudistirNet(tnn.Module):
    """Univariate LSTM forecaster.

    Maps a window of ip_seq_len daily values, shaped (batch, ip_seq_len, 1),
    to op_seq_len predicted values shaped (batch, op_seq_len). Outputs are
    squashed to (0, 1) by a sigmoid to match the population-normalised data.
    """

    def __init__(self, ip_seq_len=1, op_seq_len=1, hidden_size=1, num_layers=1):
        super(YudistirNet, self).__init__()

        self.ip_seq_len = ip_seq_len
        self.op_seq_len = op_seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True -> the LSTM consumes (batch, seq, feature) tensors.
        self.lstm = tnn.LSTM(input_size=1, hidden_size=self.hidden_size,
                             num_layers=self.num_layers, batch_first=True)
        # The linear head sees every timestep's hidden state, flattened.
        self.linear = tnn.Linear(self.hidden_size * self.ip_seq_len, self.op_seq_len)
        self.sigmoid = tnn.Sigmoid()

    def forward(self, ip):
        lstm_out, _ = self.lstm(ip)
        linear_out = self.linear(lstm_out.reshape(-1, self.hidden_size * self.ip_seq_len))
        return self.sigmoid(linear_out.view(-1, self.op_seq_len))

    def predict(self, ip):
        """Inference helper: run forward() with autograd disabled.

        Fix: the evaluation cells call pred.cpu().numpy() on the result;
        without no_grad() the output requires grad and .numpy() raises.
        """
        with torch.no_grad():
            return self.forward(ip)


### Checkpoint methods¶

In [ ]:
def save_checkpoint(epoch, model, optimizer, trn_losses, val_losses, min_val_loss, path=""):
    """Persist full training state so a run can be resumed or evaluated later.

    Writes to `path`, or MODELS_DIR/latest.pt when path is empty.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'trn_losses': trn_losses,
        'val_losses': val_losses,
        'min_val_loss': min_val_loss
    }, path or MODELS_DIR + "/latest.pt")
    print("Checkpoint saved")


def load_checkpoint(path="", device=None):
    """Load a checkpoint written by save_checkpoint.

    Fix: the `def load_checkpoint` header was lost in the export, fusing this
    body into save_checkpoint; restored with the signature the callers use
    (load_checkpoint(device=...) and load_checkpoint(model_path, device=...)).

    Returns (epoch, model_state_dict, optimizer_state_dict, trn_losses,
    val_losses, min_val_loss). min_val_loss defaults to +inf for checkpoints
    saved before the field existed. (np.Inf alias removed in NumPy 2.0.)
    """
    cp = torch.load(path or MODELS_DIR + "/latest.pt", map_location=device)
    return (cp['epoch'], cp['model_state_dict'], cp['optimizer_state_dict'],
            cp['trn_losses'], cp['val_losses'], cp.get('min_val_loss', np.inf))


### Config¶

In [ ]:
# config
IP_SEQ_LEN = 40       # days of history fed to the network per window
OP_SEQ_LEN = 20       # days predicted per forward pass

BATCH_SIZE = 1
VAL_RATIO = 0.3       # fraction of windows held out for validation

HIDDEN_SIZE = 20
NUM_LAYERS = 4
LEARNING_RATE = 0.001
NUM_EPOCHS = 3001

# to continue training on another model, set resume to true
RESUME = False

model = YudistirNet(ip_seq_len=IP_SEQ_LEN, op_seq_len=OP_SEQ_LEN, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
model = model.to(device)

loss_fn = tnn.MSELoss()
# Fix: the training loop steps `optimizer` and checkpoints its state, but it
# was never constructed anywhere. Adam per the reference lstm.py — TODO confirm.
optimizer = top.Adam(model.parameters(), lr=LEARNING_RATE)

In [ ]:
sum(p.numel() for p in model.parameters() if p.requires_grad)


### Prepare dataset¶

In [ ]:
def gen_dataset():
    """Build (input, output) sliding windows from every country's timeline.

    For each country with at least IP_SEQ_LEN + OP_SEQ_LEN days of data since
    crossing 100 cumulative cases, the 7-day-smoothed daily case counts
    (normalised to cases per 1000 people) are cut into overlapping windows.

    Returns (trn_set, val_set) — a random split by VAL_RATIO.
    """
    ip_trn = []
    op_trn = []

    countries = df['location'].unique()
    # Unused for now — candidate list for the "train only on populous
    # countries" experiment described in the notes above.
    pop_countries = ['China', 'United States', 'Indonesia', 'Pakistan', 'Brazil', 'Bangladesh', 'Russia', 'Mexico']

    c = 0
    for country in countries:
        # Skip aggregate rows and the prediction target itself.
        if country in ['World', 'International', 'India']:
            continue
        country_df = df.loc[df.location == country]
        tot_cases_gt_100 = (country_df['total_cases'] >= 100)
        country_df = country_df.loc[tot_cases_gt_100]

        if len(country_df) >= IP_SEQ_LEN + OP_SEQ_LEN:
            c += 1
            pop = country_df['population'].iloc[0]
            print(c, country, len(country_df), pop)
            # Centred 7-day rolling mean smooths weekly reporting artefacts;
            # * 1000 / pop normalises to cases per 1000 inhabitants.
            daily_cases = np.array(country_df['new_cases'].rolling(7, center=True, min_periods=1).mean() * 1000 / pop, dtype=np.float32)

            for i in range(len(country_df) - IP_SEQ_LEN - OP_SEQ_LEN + 1):
                ip_trn.append(daily_cases[i : i+IP_SEQ_LEN])
                op_trn.append(daily_cases[i+IP_SEQ_LEN : i+IP_SEQ_LEN+OP_SEQ_LEN])

    ip_trn = torch.from_numpy(np.array(ip_trn, dtype=np.float32))
    op_trn = torch.from_numpy(np.array(op_trn, dtype=np.float32))
    dataset = tdt.TensorDataset(ip_trn, op_trn)

    val_len = int(VAL_RATIO * len(dataset))
    trn_len = len(dataset) - val_len
    trn_set, val_set = tdt.random_split(dataset, (trn_len, val_len))
    return trn_set, val_set


# Reuse the cached dataset when present; otherwise build and cache it.
try:
    # Fix: `ds` was referenced without ever being loaded — the torch.load line
    # was lost (it is what makes `except FileNotFoundError` meaningful).
    ds = torch.load(DATA_DIR + '/ds.pt')
    trn_set, val_set = ds['trn'], ds['val']
except FileNotFoundError:
    trn_set, val_set = gen_dataset()
    torch.save({'trn': trn_set, 'val': val_set}, DATA_DIR + '/ds.pt')
    print("Saved dataset to ds.pt")
finally:
    print("Training data:", len(trn_set), "Validation data:", len(val_set))

In [ ]:
trn_loader = tdt.DataLoader(trn_set, shuffle=True, batch_size=BATCH_SIZE)


### Train¶

In [ ]:
# ---- Training loop ----
trn_loss_vals = []   # per-epoch average training loss (scaled x10000 for readability)
val_loss_vals = []   # per-epoch average validation loss (same scale)
e = 0
min_val_loss = np.inf   # np.Inf alias removed in NumPy 2.0

# Fix: this cell stepped `optimizer` and iterated `data` without ever creating
# the optimizer or the per-epoch batch loops; both are reconstructed here.
optimizer = top.Adam(model.parameters(), lr=LEARNING_RATE)
val_loader = tdt.DataLoader(val_set, shuffle=False, batch_size=BATCH_SIZE)

if RESUME:
    e, model_dict, optimizer_dict, trn_loss_vals, val_loss_vals, min_val_loss = load_checkpoint(device=device)
    # Fix: the loaded state dicts were previously never applied.
    model.load_state_dict(model_dict)
    optimizer.load_state_dict(optimizer_dict)
    e += 1

# TRAIN
print("BEGIN: [", dt.datetime.now(), "]")
while e < NUM_EPOCHS:
    # -- training pass --
    model.train()
    trn_losses = []
    for data in trn_loader:
        ip, op = data
        ip = ip.to(device)
        op = op.to(device)
        optimizer.zero_grad()                           # clear grads from the previous batch
        preds = model(ip.view(-1, IP_SEQ_LEN, 1))       # predict
        loss = loss_fn(preds, op.view(-1, OP_SEQ_LEN))  # calc loss
        loss.backward()                                 # calc and assign grads
        optimizer.step()                                # update weights
        trn_losses.append(loss.detach())                # detach keeps the graph out of the log
    avg_trn_loss = torch.stack(trn_losses).mean().item() * 10000
    trn_loss_vals.append(avg_trn_loss)

    # -- validation pass --
    model.eval()
    val_losses = []
    with torch.no_grad():   # no gradients needed for validation
        for data in val_loader:
            ip, op = data
            ip = ip.to(device)
            op = op.to(device)
            preds = model(ip.view(-1, IP_SEQ_LEN, 1))
            loss = loss_fn(preds, op.view(-1, OP_SEQ_LEN))
            val_losses.append(loss)
    avg_val_loss = torch.stack(val_losses).mean().item() * 10000
    val_loss_vals.append(avg_val_loss)

    if e % 10 == 0:
        print("[", dt.datetime.now(), "] epoch:", f"{e:3}", "avg_val_loss:", f"{avg_val_loss: .5f}", "avg_trn_loss:", f"{avg_trn_loss: .5f}")
    if e % 100 == 0:
        save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals, min_val_loss, MODELS_DIR + "/latest-e" + str(e) + ".pt")
    if avg_val_loss <= min_val_loss:   # track best model by validation loss
        min_val_loss = avg_val_loss
        save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals, min_val_loss, MODELS_DIR + "/best-e" + str(e) + ".pt")
    e += 1
print("END: [", dt.datetime.now(), "]")


### Load saved model for evaluation¶

In [ ]:
# Pick the checkpoint to evaluate. Fix: model_path was only assigned in a
# commented line, so the cell crashed with NameError; default to latest.pt.
model_path = MODELS_DIR + "/latest.pt"
# model_path = MODELS_DIR + "/IP20_OP10_H10_L4_E2001_LR001.pt"
e, md, _, trn_loss_vals, val_loss_vals, _ = load_checkpoint(model_path, device=device)
# Fix: the loaded weights were previously never applied to the model.
model.load_state_dict(md)
print(e)
model.eval()


### Plot losses¶

In [ ]:
# Plot 10-epoch rolling means of the training and validation losses.
df_loss = pd.DataFrame({
    'trn_loss': trn_loss_vals,
    'val_loss': val_loss_vals
})
for col in ('trn_loss', 'val_loss'):
    df_loss[col] = df_loss[col].rolling(10).mean()
_ = df_loss.plot(
    y=['trn_loss', 'val_loss'],
    title='Loss per epoch',
    subplots=True,
    figsize=(5, 6),
    sharex=False,
    logy=True
)


### Evaluate fit¶

In [ ]:
c = "India"
pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000

all_preds = []
pred_vals = []
out_vals = []

test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

for i in range(len(test_data) - IP_SEQ_LEN - OP_SEQ_LEN + 1):
ip = torch.tensor(test_data[i : i+IP_SEQ_LEN])
op = torch.tensor(test_data[i+IP_SEQ_LEN : i+IP_SEQ_LEN+OP_SEQ_LEN])
ip = ip.to(device)
op = op.to(device)

pred = model.predict(ip.view(1, IP_SEQ_LEN, 1))
if i==0: # prepend first input
out_vals.extend(ip.view(IP_SEQ_LEN).cpu().numpy() * pop_fct)
pred_vals.extend([np.NaN] * IP_SEQ_LEN)
all_preds.append(pred.view(OP_SEQ_LEN).cpu().numpy() * pop_fct)
pred_vals.append(pred.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)
out_vals.append(op.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)

# last N-1 values
out_vals.extend(op.view(OP_SEQ_LEN).cpu().numpy()[1:] * pop_fct)
pred_vals.extend(([np.NaN] * OP_SEQ_LEN)[1:]) # pad with NaN

cmp_df = pd.DataFrame({
'actual': out_vals,
'predicted0': pred_vals
})

# set date
start_date = df.loc[(df.location==c) & (df.total_cases>=100)]['date'].iloc[0]
end_date = start_date + dt.timedelta(days=cmp_df.index[-1])
cmp_df['Date'] = pd.Series([start_date + dt.timedelta(days=i) for i in range(len(cmp_df))])

# plot noodles
ax=None
i=IP_SEQ_LEN
mape=[]
for pred in all_preds:
cmp_df['predicted_cases'] = np.NaN
cmp_df.loc[i:i+OP_SEQ_LEN-1, 'predicted_cases'] = pred
ax = cmp_df.plot(x='Date', y='predicted_cases', ax=ax, legend=False)
ape = np.array(100 * ((cmp_df['actual'] - cmp_df['predicted_cases']).abs() / cmp_df['actual']))
#     mape.append(ape.mean())
mape.append(ape[~np.isnan(ape)])
i+=1

total = np.zeros(OP_SEQ_LEN)
for m in mape:
total += m
elwise_mape = total / len(mape)
print("Day wise accuracy:", 100 - elwise_mape)
acc = f"{100 - sum(elwise_mape)/len(elwise_mape):0.2f}%"
# acc = f"{100 - sum(mape)/len(mape):0.2f}%"

# plot primary lines
ax = cmp_df.plot(
x='Date',
y=['actual', 'predicted0'],
figsize=(20,8),
lw=5,
title=c + ' | Daily predictions | ' + acc,
ax=ax
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)


### Test (predict) using OWID data¶

In [ ]:
c = "India"
n_days_prediction = 200

pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000
test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

in_data = test_data[-IP_SEQ_LEN:]
out_data = np.array([], dtype=np.float32)
for i in range(int(n_days_prediction / OP_SEQ_LEN)):
ip = torch.tensor(
in_data,
dtype=torch.float32
)
ip = ip.to(device)
pred = model.predict(ip.view(1, IP_SEQ_LEN, 1))
in_data = np.append(in_data[-IP_SEQ_LEN+OP_SEQ_LEN:], pred.cpu().numpy())
out_data = np.append(out_data, pred.cpu().numpy())

orig_df = pd.DataFrame({
'actual': test_data * pop_fct
})
fut_df = pd.DataFrame({
'predicted': out_data * pop_fct
})
# print(fut_df['predicted'].astype('int').to_csv(sep='|', index=False))
orig_df = orig_df.append(fut_df, ignore_index=True, sort=False)
orig_df['total'] = (orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)).cumsum()

start_date = df.loc[(df.location==c) & (df.total_cases>=100)]['date'].iloc[0]
orig_df['Date'] = pd.Series([start_date + dt.timedelta(days=i) for i in range(len(orig_df))])
ax = orig_df.plot(
x='Date',
y=['actual', 'predicted'],
title=c + ' daily cases',
figsize=(10,6),
grid=True
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)
# orig_df['total'] = orig_df['total'].astype('int')
# orig_df['predicted'] = orig_df['predicted'].fillna(0).astype('int')
# print(orig_df.tail(n_days_prediction))

# arrow
# peakx = 172
# peak = orig_df.iloc[peakx]
# peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['predicted']))
# _ = ax.annotate(
#     peak_desc,
#     xy=(peak['Date'] - dt.timedelta(days=1), peak['predicted']),
#     xytext=(peak['Date'] - dt.timedelta(days=45), peak['predicted'] * .9),
#     arrowprops={},
#     bbox={'facecolor':'white'}
# )

# _ = ax.axvline(x=peak['Date'], linewidth=1, color='r')


### Statewise prediction¶

In [ ]:
# Fetch statewise cumulative confirmed counts from the covid19india API.
r = rq.get('https://api.covid19india.org/v3/min/timeseries.min.json')
ts = r.json()

# Flatten the nested {state: {date: {...}}} JSON into (state, date, total) rows.
data = [
    (state, date, ts[state][date]['total'].get('confirmed', 0))
    for state in ts
    for date in ts[state]
]

states_df = pd.DataFrame(data, columns=['state', 'date', 'total'])
states_df['date'] = pd.to_datetime(states_df['date'])
first_case_date = states_df['date'].min()  # earliest reported date across all states

In [ ]:
# http://www.populationu.com/india-population
# State code -> display name and approximate population, used to normalise
# per-capita case counts the same way the country-level training data was.
STT_INFO = {
    'AN' : {"name": "Andaman & Nicobar Islands", "popn": 450000},
    'AP' : {"name": "Andhra Pradesh", "popn": 54000000},
    # NOTE(review): 30000000 looks far too high for Arunachal Pradesh (~1.5M);
    # value left unchanged pending confirmation against the source above.
    'AR' : {"name": "Arunachal Pradesh", "popn": 30000000},
    'AS' : {"name": "Assam", "popn": 35000000},   # fixed typo: was "Asaam"
    'BR' : {"name": "Bihar", "popn": 123000000},
    'CH' : {"name": "Chandigarh", "popn": 1200000},
    'CT' : {"name": "Chhattisgarh", "popn": 29000000},
    'DL' : {"name": "Delhi", "popn": 19500000},
    'DN' : {"name": "Dadra & Nagar Haveli and Daman & Diu", "popn": 700000},
    'GA' : {"name": "Goa", "popn": 1580000},
    'GJ' : {"name": "Gujarat", "popn": 65000000},
    'HP' : {"name": "Himachal Pradesh", "popn": 7400000},
    'HR' : {"name": "Haryana", "popn": 28000000},
    'JH' : {"name": "Jharkhand", "popn": 38000000},
    'JK' : {"name": "Jammu & Kashmir", "popn": 13600000},
    'KA' : {"name": "Karnataka", "popn": 67000000},
    'KL' : {"name": "Kerala", "popn": 36000000},
    'LA' : {"name": "Ladakh", "popn": 325000},
    'MH' : {"name": "Maharashtra", "popn": 122000000},
    'ML' : {"name": "Meghalaya", "popn": 3400000},
    'MN' : {"name": "Manipur", "popn": 3000000},
    'MZ' : {"name": "Mizoram", "popn": 1200000},
    'NL' : {"name": "Nagaland", "popn": 2200000},
    'OR' : {"name": "Odisha", "popn": 46000000},
    'PB' : {"name": "Punjab", "popn": 30000000},
    'PY' : {"name": "Puducherry", "popn": 1500000},
    'RJ' : {"name": "Rajasthan", "popn": 80000000},
    'TG' : {"name": "Telangana", "popn": 39000000},
    'TN' : {"name": "Tamil Nadu", "popn": 77000000},
    'TR' : {"name": "Tripura", "popn": 4100000},
    'UP' : {"name": "Uttar Pradesh", "popn": 235000000},
    'UT' : {"name": "Uttarakhand", "popn": 11000000},
    'WB' : {"name": "West Bengal", "popn": 98000000},
#     'SK' : {"name": "Sikkim", "popn": 681000},
#     'UN' : {"name": "Unassigned", "popn": 40000000}, #avg pop
#     'LD' : {"name": "Lakshadweep", "popn": 75000}
}

# uncomment for India
# STT_INFO = {
#     'TT' : {"name": "India", "popn": 1387155000}
# }


#### Dummy state data: fruit country¶

In [ ]:
# Dummy fixture for sanity-testing the pipeline: one synthetic "state" whose
# daily total grows 3% per day from 2020-03-01 until today.
# (SET 1, an alternative 10-state "fruit country" fixture with total = 100,
# previously lived here commented out.)
STT_INFO = {
    'Z': {"name": "FruitCountry1000x", "popn": 10000000},
}
total = 1000

rows = []
day = dt.datetime(day=1, month=3, year=2020)
today = dt.datetime.now()
while day <= today:
    for code in STT_INFO:
        rows.append((code, day, total))
    total *= 1.03
    day += dt.timedelta(days=1)

states_df = pd.DataFrame(rows, columns=['state', 'date', 'total'])
states_df['date'] = pd.to_datetime(states_df['date'])
states_df.tail()


#### Predict¶

In [ ]:
def expand(df):
    '''Fill missing dates in an irregular timeline.

    Reindexes the frame onto a contiguous daily range between its min and max
    dates. NOTE(review): the original return statement was lost in the export;
    inserted dates are forward-filled here, the natural choice for cumulative
    totals — confirm against the upstream reference implementation.
    '''
    min_date = df['date'].min()
    max_date = df['date'].max()
    idx = pd.date_range(min_date, max_date)

    df.index = pd.DatetimeIndex(df.date)
    df = df.drop(columns=['date'])
    return df.reindex(idx).reset_index().rename(columns={'index': 'date'}).ffill()

def prefill(df, min_date):
    '''Fill zeros from min_date (the first national case) to df.date.min().

    Expects a single-state frame. Returns it reindexed onto the contiguous
    daily range [min_date, df.date.max()], with missing totals set to 0 and
    the state code propagated into the new rows.
    '''
    assert(len(df.state.unique()) == 1)
    s = df.state.unique().item()
    # (dropped a no-op `min_date = min_date` self-assignment)
    max_date = df['date'].max()
    idx = pd.date_range(min_date, max_date)

    df.index = pd.DatetimeIndex(df.date)
    df = df.drop(columns=['date'])
    return df.reindex(idx).reset_index().rename(columns={'index':'date'}).fillna({'state':s, 'total':0})

In [ ]:
# ---- Per-state rollout: forecast every state, collect API payload ----
prediction_offset = 1 # how many days of data to skip
n_days_prediction = 200 # number of days for prediction
# NOTE(review): timeline length is taken from the 'TT' (all-India) series;
# with the dummy fixture (only state 'Z') this slice is empty — confirm intent.
n_days_data = len(expand(states_df.loc[states_df['state']=='TT']))
assert(n_days_prediction%OP_SEQ_LEN == 0)

agg_days = n_days_data - prediction_offset + n_days_prediction # days in the aggregate curve (actual + predicted)
states_agg = np.zeros(agg_days)

ax = None
api = {}   # state -> date -> delta counts, exported for the API
for state in STT_INFO:
    pop_fct = STT_INFO[state]["popn"] / 1000

    # Fix: was `prefill(expand(state_df), ...)` — state_df was read before
    # assignment; the state's own slice of states_df must be expanded.
    state_df = prefill(expand(states_df.loc[states_df['state'] == state]), first_case_date)
    state_df['daily'] = state_df['total'] - state_df['total'].shift(1).fillna(0)
    test_data = np.array(state_df['daily'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

    # iterative rollout, OP_SEQ_LEN days at a time
    in_data = test_data[-IP_SEQ_LEN:]
    out_data = np.array([], dtype=np.float32)
    for i in range(int(n_days_prediction / OP_SEQ_LEN)):
        ip = torch.tensor(
            in_data,
            dtype=torch.float32
        ).to(device)
        try:
            # detach() so .numpy() cannot fail if predict() tracks gradients
            pred = model.predict(ip.view(-1, IP_SEQ_LEN, 1)).detach()
        except Exception as err:
            # Fix: execution previously continued with a stale/undefined
            # `pred`; abandon this state's rollout instead.
            print(state, err)
            break
        in_data = np.append(in_data[-IP_SEQ_LEN+OP_SEQ_LEN:], pred.cpu().numpy())
        out_data = np.append(out_data, pred.cpu().numpy())

    sn = STT_INFO[state]['name']
    orig_df = pd.DataFrame({
        'actual': np.array(test_data * pop_fct, dtype=int)   # np.int was removed in NumPy 1.24
    })
    fut_df = pd.DataFrame({
        'predicted': np.array(out_data * pop_fct, dtype=int)
    })
    # print(fut_df.to_csv(sep='|'))
    # Fix: DataFrame.append was removed in pandas 2.0; concat is the replacement.
    orig_df = pd.concat([orig_df, fut_df], ignore_index=True, sort=False)
    orig_df[sn] = orig_df['actual'].fillna(0) + orig_df['predicted'].fillna(0)
    orig_df['total'] = orig_df[sn].cumsum()
    states_agg += np.array(orig_df[sn][-agg_days:].fillna(0))

    # generate date col for orig_df from state_df
    start_date = state_df['date'].iloc[0]
    orig_df['Date'] = pd.to_datetime([(start_date + dt.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(len(orig_df))])

    # print state, peak date, peak daily cases, cumulative since beginning
    peak = orig_df.loc[orig_df[sn].idxmax()]
    print(sn, "|", peak['Date'].strftime("%b %d"), "|", int(peak[sn]), "|", int(orig_df['total'].iloc[-1]))

    # derive deceased/recovered/active series for the API export
    # (0.028 fatality fraction with 7/14-day lags — heuristic, not fitted)
    orig_df['deceased_daily'] = orig_df[sn] * 0.028
    orig_df['recovered_daily'] = orig_df[sn].shift(14, fill_value=0) - orig_df['deceased_daily'].shift(7, fill_value=0)
    orig_df['active_daily'] = orig_df[sn] - orig_df['recovered_daily'] - orig_df['deceased_daily']

    api[state] = {}
    for idx, row in orig_df[-agg_days:].iterrows():
        row_date = row['Date'].strftime("%Y-%m-%d")
        api[state][row_date] = {
            "delta": {
                "confirmed": int(row[sn]),
                "deceased": int(row['deceased_daily']),
                "recovered": int(row['recovered_daily']),
                "active": int(row['active_daily'])
            }
        }

    # plot state chart
    ax = orig_df.plot(
        x='Date',
        y=[sn],
        title='Daily Cases',
        figsize=(15,10),
        grid=True,
        ax=ax,
        lw=3
    )
    mn_l = DayLocator()
    ax.xaxis.set_minor_locator(mn_l)
    mj_l = AutoDateLocator()
    mj_f = ConciseDateFormatter(mj_l, show_offset=False)
    ax.xaxis.set_major_formatter(mj_f)

# plot aggregate chart across all states
cum_df = pd.DataFrame({
    'states_agg': states_agg
})
last_date = orig_df['Date'].iloc[-1].to_pydatetime()
# Fix: with agg_days rows, dates run start_date .. start_date+(agg_days-1);
# subtracting agg_days (as before) put every date one day too early.
start_date = last_date - dt.timedelta(days=agg_days - 1)
cum_df['Date'] = pd.to_datetime([(start_date + dt.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(len(cum_df))])
ax = cum_df.plot(
    x='Date',
    y=['states_agg'],
    title='Aggregate daily cases',
    figsize=(15,10),
    grid=True,
    lw=3
)
mn_l = DayLocator()
ax.xaxis.set_minor_locator(mn_l)
mj_l = AutoDateLocator()
mj_f = ConciseDateFormatter(mj_l, show_offset=False)
ax.xaxis.set_major_formatter(mj_f)

# (optional) annotate the aggregate peak — retained from the original notebook
# peakx = 171
# peak = cum_df.iloc[peakx]
# peak_desc = peak['Date'].strftime("%d-%b") + "\n" + str(int(peak['states_agg']))
# _ = ax.annotate(
#     peak_desc,
#     xy=(peak['Date'] + dt.timedelta(days=1), peak['states_agg']),
#     xytext=(peak['Date'] + dt.timedelta(days=45), peak['states_agg'] * .9),
#     arrowprops={},
#     bbox={'facecolor':'white'}
# )
# _ = ax.axvline(x=peak['Date'], linewidth=1, color='r')


#### Export JSON for API¶

In [ ]:
# Build the all-India ('TT') aggregate by summing every state's daily deltas.
api['TT'] = {}
for state in api:
    if state == 'TT':
        continue
    for date in api[state]:
        agg = api['TT'].setdefault(date, {'delta': {}, 'total': {}})
        for k in ['delta']: #'total'
            for metric in ('confirmed', 'deceased', 'recovered', 'active'):
                agg[k][metric] = agg[k].get(metric, 0) + api[state][date][k][metric]

# Write the full per-state + aggregate payload for the API.
with open("predictions.json", "w") as f:
    f.write(json.dumps(api, sort_keys=True))


#### Export data for video player¶

In [ ]:
# aggregate predictions into compact per-day counters for the video player
api['TT'] = {}
for state in api:
    if state == 'TT':
        continue
    for date in api[state]:
        api['TT'][date] = api['TT'].get(date, {})
        api['TT'][date]['c'] = api['TT'][date].get('c', 0) + api[state][date]['delta']['confirmed']
        api['TT'][date]['d'] = api['TT'][date].get('d', 0) + api[state][date]['delta']['deceased']
        api['TT'][date]['r'] = api['TT'][date].get('r', 0) + api[state][date]['delta']['recovered']
        api['TT'][date]['a'] = api['TT'][date].get('a', 0) + api[state][date]['delta']['active']

# cumulative (disabled)
# t = {'c':0, 'd':0, 'r':0, 'a':0}
# for date in sorted(api['TT'].keys()):
#     for k in ['c', 'd', 'r', 'a']:
#         api['TT'][date][k] += t[k] # add cum to today
#         t[k] = api['TT'][date][k] # update cum

# Key the export by the date the prediction starts from.
k = (states_df.date.max().to_pydatetime() - dt.timedelta(days=prediction_offset)).strftime("%Y-%m-%d")
# Merge into any existing vp.json rather than clobbering previous runs.
try:
    with open("vp.json", "r") as f:
        # Fix: the original `with` block had no body — the read was lost
        out = json.load(f)
except Exception:
    out = {}

with open("vp.json", "w") as f:
    out[k] = {'TT': api['TT']}
    f.write(json.dumps(out, sort_keys=True))


#### CSV export of video player output¶

In [ ]:
# Export the aggregated ('TT') confirmed-cases series for run k as CSV.
df_csv = pd.DataFrame(out[k]['TT']).transpose()
df_csv['c'].to_csv('vp_' + k + '.csv')