import warnings
import arviz as az
import numpy as np
import pandas as pd
import pymc3 as pm
import theano.tensor as tt
import xarray as xr
%matplotlib inline
warnings.simplefilter(action="ignore", category=FutureWarning)
az.rcParams["stats.ic_pointwise"] = True
There are many situations where one model can be used for several prediction tasks at the same time. Hierarchical models or models with multiple observations are examples of such cases. With two observations for example, the same model can be used to predict only the first observation, only the second or both observations at the same time.
Before estimating the predictive accuracy, there are two important questions to answer: what is the predictive task we are interested in, and whether or not the exchangeability criterion is met. This section shows several alternative ways to define the predictive task using the same model.
We are going to analyze data from the 2018-2019 season of Spain's highest women's football league. We will start by loading the already cleaned up data. It is a dataframe summarizing all the matches of the season. Each row represents a match. You can see the head of the dataframe below.
df = pd.read_csv("18-19_df.csv")
df.head()
|  | home_team | away_team | home_goals | away_goals |
|---|---|---|---|---|
| 0 | Atlético de Madrid | Athletic Club | 3 | 0 |
| 1 | Barcelona | Athletic Club | 2 | 1 |
| 2 | R.C.D. Espanyol | Athletic Club | 1 | 2 |
| 3 | Fundación Albacete | Athletic Club | 0 | 1 |
| 4 | Granadilla | Athletic Club | 3 | 1 |
The model used is taken from this blog post, which was later added as an example notebook to the PyMC docs. This notebook describes the model only concisely and does not discuss its implementation, in order to focus on the calculation of information criteria. To read more about the model, please refer to those two posts and the references therein.
We are trying to model a league in which all teams play against each other twice. We indicate the number of goals scored by the home and the away team in the $g$-th game of the season ($n$ matches) as $y_{g,h}$ and $y_{g,a}$ respectively. The model assumes the goals scored by a team follow a Poisson distribution:
$$ y_{g,j} | \theta_{g,j} \sim \text{Poiss}(\theta_{g,j}) $$where $j \in \{h, a\}$ represents either the home or the away team. We will therefore start with a model containing two observation vectors: $\mathbf{y}_h = (y_{1,h}, y_{2,h}, \dots, y_{n,h})$ and $\mathbf{y}_a = (y_{1,a}, \dots, y_{n,a})$. In order to take into account each team's scoring and defensive power, as well as the advantage of playing at home, we use different formulas for $\theta_{g,h}$ and $\theta_{g,a}$:
$$ \begin{align} \theta_{g,h} &= \alpha + home + atts_{home\_team} + defs_{away\_team}\\ \theta_{g,a} &= \alpha + atts_{away\_team} + defs_{home\_team} \end{align} $$The expected number of goals scored by the home team, $\theta_{g,h}$, depends on an intercept $\alpha$, on $home$, which quantifies the home advantage, on the attacking power of the home team and on the defensive power of the away team. Similarly, the expected number of goals scored by the away team, $\theta_{g,a}$, also depends on the intercept but not on the home advantage; consequently, it uses the attacking power of the away team and the defensive power of the home team.
Summing up and including the priors, our base model is the following one:
$$ \begin{align} \alpha &\sim \text{Normal}(0,5) \\ home &\sim \text{Normal}(0,5) \\ sd_{att} &\sim \text{HalfStudentT}(3,2.5) \\ sd_{def} &\sim \text{HalfStudentT}(3,2.5) \\ atts_* &\sim \text{Normal}(0,sd_{att}) \\ defs_* &\sim \text{Normal}(0,sd_{def}) \\ \mathbf{y}_h &\sim \text{Poiss}(\theta_h) \\ \mathbf{y}_a &\sim \text{Poiss}(\theta_a) \end{align} $$where $\theta_j$ has been defined above, $atts = atts_* - \text{mean}(atts_*)$ and $defs$ is defined like $atts$.
df = pd.read_csv("18-19_df.csv")
home_team_idxs, team_names = pd.factorize(df.home_team, sort=True)
away_team_idxs, _ = pd.factorize(df.away_team, sort=True)
num_teams = len(team_names)
df
|  | home_team | away_team | home_goals | away_goals |
|---|---|---|---|---|
| 0 | Atlético de Madrid | Athletic Club | 3 | 0 |
| 1 | Barcelona | Athletic Club | 2 | 1 |
| 2 | R.C.D. Espanyol | Athletic Club | 1 | 2 |
| 3 | Fundación Albacete | Athletic Club | 0 | 1 |
| 4 | Granadilla | Athletic Club | 3 | 1 |
| ... | ... | ... | ... | ... |
| 235 | Rayo Vallecano | Valencia | 1 | 1 |
| 236 | Real Betis | Valencia | 4 | 0 |
| 237 | Real Sociedad | Valencia | 6 | 0 |
| 238 | Sevilla F.C. | Valencia | 2 | 2 |
| 239 | Sporting Huelva | Valencia | 0 | 2 |

240 rows × 4 columns
coords = {"team": team_names, "match": np.arange(len(df))}
with pm.Model(coords=coords) as m_base:
# constant data
home_team = pm.Data("home_team", home_team_idxs, dims="match")
away_team = pm.Data("away_team", away_team_idxs, dims="match")
# global model parameters
home = pm.Normal('home', mu=0, sigma=5)
sd_att = pm.HalfStudentT('sd_att', nu=3, sigma=2.5)
sd_def = pm.HalfStudentT('sd_def', nu=3, sigma=2.5)
intercept = pm.Normal('intercept', mu=0, sigma=5)
# team-specific model parameters
atts_star = pm.Normal("atts_star", mu=0, sigma=sd_att, dims="team")
defs_star = pm.Normal("defs_star", mu=0, sigma=sd_def, dims="team")
atts = atts_star - tt.mean(atts_star)
defs = defs_star - tt.mean(defs_star)
home_theta = tt.exp(intercept + home + atts[home_team] + defs[away_team])
away_theta = tt.exp(intercept + atts[away_team] + defs[home_team])
# likelihood of observed data
home_goals = pm.Poisson('home_goals', mu=home_theta, observed=df.home_goals, dims="match")
away_goals = pm.Poisson('away_goals', mu=away_theta, observed=df.away_goals, dims="match")
with m_base:
idata = pm.sample(draws=2000, random_seed=1375, return_inferencedata=True)
# define helpers to make code less verbose
log_lik = idata.log_likelihood
const = idata.constant_data
Auto-assigning NUTS sampler... Initializing NUTS using jitter+adapt_diag... Multiprocess sampling (4 chains in 4 jobs) NUTS: [defs_star, atts_star, intercept, sd_def, sd_att, home]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 7 seconds.
idata
<xarray.Dataset> Dimensions: (chain: 4, draw: 2000, team: 16) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 1993 1994 1995 1996 1997 1998 1999 * team (team) object 'Athletic Club' 'Atlético de Madrid' ... 'Valencia' Data variables: home (chain, draw) float64 0.2115 0.2047 0.2854 ... 0.3334 0.1549 intercept (chain, draw) float64 0.1675 0.128 0.1678 ... 0.1153 0.1571 atts_star (chain, draw, team) float64 0.1237 0.6615 ... -0.6013 -0.2507 defs_star (chain, draw, team) float64 -0.1521 -0.8256 ... 0.05384 0.2952 sd_att (chain, draw) float64 0.4699 0.3658 0.5048 ... 0.4023 0.3256 sd_def (chain, draw) float64 0.4336 0.5163 0.2876 ... 0.426 0.3412 Attributes: created_at: 2020-07-13T15:58:56.530478 arviz_version: 0.9.0 inference_library: pymc3 inference_library_version: 3.9.2 sampling_time: 7.236523389816284 tuning_steps: 1000
<xarray.Dataset> Dimensions: (chain: 4, draw: 2000, match: 240) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999 * match (match) int64 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239 Data variables: home_goals (chain, draw, match) float64 -1.567 -1.369 ... -1.363 -1.036 away_goals (chain, draw, match) float64 -0.577 -1.226 ... -1.613 -1.703 Attributes: created_at: 2020-07-13T15:58:57.815138 arviz_version: 0.9.0 inference_library: pymc3 inference_library_version: 3.9.2
<xarray.Dataset> Dimensions: (chain: 4, draw: 2000) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 ... 1994 1995 1996 1997 1998 1999 Data variables: step_size (chain, draw) float64 0.6336 0.6336 ... 0.4508 0.4508 depth (chain, draw) int64 3 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 3 3 3 tree_size (chain, draw) float64 7.0 7.0 7.0 7.0 ... 7.0 7.0 7.0 7.0 energy (chain, draw) float64 743.1 743.8 737.4 ... 747.7 738.9 diverging (chain, draw) bool False False False ... False False False max_energy_error (chain, draw) float64 1.28 -0.9798 0.8729 ... -1.79 0.7297 lp (chain, draw) float64 -731.8 -723.1 ... -723.3 -722.2 mean_tree_accept (chain, draw) float64 0.492 0.9858 ... 0.9949 0.8221 energy_error (chain, draw) float64 1.28 -0.9798 ... -1.447 0.1291 step_size_bar (chain, draw) float64 0.5741 0.5741 ... 0.5413 0.5413 Attributes: created_at: 2020-07-13T15:58:56.537859 arviz_version: 0.9.0 inference_library: pymc3 inference_library_version: 3.9.2 sampling_time: 7.236523389816284 tuning_steps: 1000
<xarray.Dataset> Dimensions: (match: 240) Coordinates: * match (match) int64 0 1 2 3 4 5 6 7 ... 233 234 235 236 237 238 239 Data variables: home_goals (match) int32 3 2 1 0 3 2 1 2 0 1 0 2 ... 0 3 0 0 1 1 1 4 6 2 0 away_goals (match) int32 0 1 2 1 1 0 3 0 2 1 0 2 ... 2 0 0 0 1 4 1 0 0 2 2 Attributes: created_at: 2020-07-13T15:58:57.817213 arviz_version: 0.9.0 inference_library: pymc3 inference_library_version: 3.9.2
<xarray.Dataset> Dimensions: (match: 240) Coordinates: * match (match) int64 0 1 2 3 4 5 6 7 ... 232 233 234 235 236 237 238 239 Data variables: home_team (match) int32 1 2 9 3 4 5 6 7 8 10 ... 4 5 6 7 8 10 11 12 13 14 away_team (match) int32 0 0 0 0 0 0 0 0 0 0 ... 15 15 15 15 15 15 15 15 15 Attributes: created_at: 2020-07-13T15:58:57.818392 arviz_version: 0.9.0 inference_library: pymc3 inference_library_version: 3.9.2
Due to the presence of two likelihoods in our model, we cannot call `az.loo` or `az.waic` straight away because the predictive task to evaluate is ambiguous. The calculation of information criteria requires pointwise likelihood values, $p(y_i|\theta)$, with $y_i$ indicating the $i$-th observation and $\theta$ representing all the parameters in the model. We therefore need to define $y_i$: what does one observation represent in our model?
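To make this concrete, here is a small self-contained sketch, with toy numbers unrelated to the football model, of what a pointwise log-likelihood matrix is: one value per posterior draw and per observation (SciPy, a PyMC3 dependency, is assumed available).

```python
import numpy as np
from scipy import stats

# toy posterior: 3 draws of a Poisson rate theta
theta_draws = np.array([1.5, 2.0, 2.5])
# toy data: 4 observed counts
y = np.array([0, 1, 2, 3])

# pointwise log likelihood: one value per (draw, observation) pair,
# i.e. a (n_draws, n_obs) matrix like the ones az.loo consumes
pointwise_ll = stats.poisson.logpmf(y[None, :], theta_draws[:, None])
print(pointwise_ll.shape)  # (3, 4)
```

In the real model, ArviZ stores exactly this kind of matrix (with chain and draw dimensions) for each observed variable in the `log_likelihood` group.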
As introduced above, this model alone can tackle several predictive tasks. Each predictive task is identified by the definition of one observation, which at the same time defines how pointwise likelihood values are to be calculated. Here are some examples:

* predicting the goals scored by the home team in a new match;
* predicting the outcome of a new match, i.e. the goals scored by both the home and the away team;
* predicting the goals scored by one team, either home or away, in a new match;
* predicting the whole season of a new team.
There are even more examples of predictive tasks for which this particular model can be of use. However, it is important to keep in mind that this model predicts the number of goals scored. Its results can be used to estimate probabilities of victory and other derived quantities, but calculating the likelihood of these derived quantities may not be straightforward. And as you can see above, there isn't one unique predictive task: it all depends on the specific question you're interested in. As is often the case in statistics, the answer to these questions lies outside the model; you must tell the model what to do, not the other way around.
Even though we know that the predictive task is ambiguous, we will start by trying to calculate `az.loo` on `idata`, and then work through the examples above plus a couple more to show how this kind of task would be performed with ArviZ. But before that, let's see what ArviZ says when you naively ask it for the LOO of a multi-likelihood model:
az.loo(idata)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-8-aaad86725ae1> in <module> ----> 1 az.loo(idata) ~/venvs/arviz-dev/lib/python3.6/site-packages/arviz/stats/stats.py in loo(data, pointwise, var_name, reff, scale) 643 """ 644 inference_data = convert_to_inference_data(data) --> 645 log_likelihood = _get_log_likelihood(inference_data, var_name=var_name) 646 pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise 647 ~/venvs/arviz-dev/lib/python3.6/site-packages/arviz/stats/stats_utils.py in get_log_likelihood(idata, var_name) 413 if len(var_names) > 1: 414 raise TypeError( --> 415 "Found several log likelihood arrays {}, var_name cannot be None".format(var_names) 416 ) 417 return idata.log_likelihood[var_names[0]] TypeError: Found several log likelihood arrays ['home_goals', 'away_goals'], var_name cannot be None
As expected, ArviZ has no way of knowing what predictive task we have in mind so it raises an error.
In this particular case, we are interested in predicting the goals scored by the home team. We will still use the goals scored by the away team to fit the model, but we won't take them into account when assessing the predictive accuracy. Below is an illustration of how cross-validation would be performed to assess the predictive accuracy in this particular case:
This can also be seen from a mathematical point of view. We can write the pointwise log likelihood in the following way so it defines the predictive task at hand:
$$ p(y_i|\theta) = p(y_{i,h}|\theta_{i,h}) = \text{Poiss}(y_{i,h}; \theta_{i,h}) $$with $i$ being the match indicator ($g$) in this case. These are precisely the values stored in the `home_goals` variable of the `log_likelihood` group of `idata`. We can tell ArviZ to use these values via the `var_name` argument.
az.loo(idata, var_name="home_goals")
Computed from 8000 by 240 log-likelihood matrix Estimate SE elpd_loo -372.15 11.52 p_loo 14.80 - ------ Pareto k diagnostic values: Count Pct. (-Inf, 0.5] (good) 240 100.0% (0.5, 0.7] (ok) 0 0.0% (0.7, 1] (bad) 0 0.0% (1, Inf) (very bad) 0 0.0% The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if you rely on a specific value. A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
Another option is to be interested in the outcome of the matches. In our current model, the outcome of a match is neither who wins nor the aggregate number of goals scored by both teams; the outcome is the goals scored by the home team and by the away team, both quantities at the same time. Below is an illustration of how cross-validation would be used to assess the predictive accuracy in this situation:
One observation in this situation is therefore a vector with two components: $y_i = (y_{i,h}, y_{i,a})$. Like above, we have $n$ observations. The pointwise likelihood is therefore a product:
$$ p(y_i|\theta) = p(y_{i,h}|\theta_{i,h})p(y_{i,a}|\theta_{i,a}) = \text{Poiss}(y_{i,h}; \theta_{i,h})\text{Poiss}(y_{i,a}; \theta_{i,a}) $$with $i$ still being equal to the match indicator $g$. Therefore, we have $n$ observations like in the previous example, but each observation has two components.
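Because the two Poisson components of an observation are conditionally independent given $\theta$, the log of this product is just the sum of the two log-pmfs. A toy check, with made-up rates and scores rather than actual posterior draws:

```python
import numpy as np
from scipy import stats

# one hypothetical match: home scores 2, away scores 1,
# with made-up rates standing in for a single posterior draw
y_h, y_a = 2, 1
theta_h, theta_a = 1.8, 1.1

# log of the joint pmf equals the sum of the per-team log-pmfs
ll_match = stats.poisson.logpmf(y_h, theta_h) + stats.poisson.logpmf(y_a, theta_a)
joint = stats.poisson.pmf(y_h, theta_h) * stats.poisson.pmf(y_a, theta_a)
print(np.isclose(ll_match, np.log(joint)))  # True
```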
We can calculate the product as a sum of logarithms and store the result in a new variable inside the `log_likelihood` group.
log_lik["matches"] = log_lik.home_goals + log_lik.away_goals
az.loo(idata, var_name="matches")
Computed from 8000 by 240 log-likelihood matrix Estimate SE elpd_loo -716.45 15.81 p_loo 27.37 - ------ Pareto k diagnostic values: Count Pct. (-Inf, 0.5] (good) 240 100.0% (0.5, 0.7] (ok) 0 0.0% (0.7, 1] (bad) 0 0.0% (1, Inf) (very bad) 0 0.0% The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if you rely on a specific value. A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
Another example described above is being interested in the goals scored per match and per team. In this situation, each observation is a scalar once again. The expression of the likelihood is basically the same as in the first example (both cases are scalars); the difference lies in the index, which does not make it any less significant:
$$ p(y_i|\theta) = p(y_{i}|\theta_{i}) = \text{Poiss}(y_{i}; \theta_{i}) $$with $i$ not being equal to the match indicator $g$ anymore. Now, we will consider $i$ as an index iterating over the values in
$$\big\{(1,h), (2,h), \dots, (n-1,h), (n,h), (1,a), (2,a) \dots (n-1,a), (n,a)\big\}$$Therefore, unlike in previous cases, we have $2n$ observations.
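The reindexing from $n$ two-component observations to $2n$ scalar ones can be sketched with toy xarray objects before touching the real log-likelihood arrays (random values, with shapes chosen purely for illustration):

```python
import numpy as np
import xarray as xr

# toy pointwise log likelihoods: 2 chains x 3 draws x 4 matches each
dims = ("chain", "draw", "match")
rng = np.random.default_rng(0)
home = xr.DataArray(rng.normal(size=(2, 3, 4)), dims=dims)
away = xr.DataArray(rng.normal(size=(2, 3, 4)), dims=dims)

# stack along the match dimension and rename it: 2n scalar observations
goals = xr.concat((home, away), "match").rename({"match": "goal"})
print(dict(goals.sizes))  # {'chain': 2, 'draw': 3, 'goal': 8}
```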
We can obtain the pointwise log likelihood corresponding to this case by concatenating the pointwise log likelihoods of `home_goals` and `away_goals`. Then, like in the previous case, we store the result in a new variable inside the `log_likelihood` group.
log_lik["goals"] = xr.concat((log_lik.home_goals, log_lik.away_goals), "match").rename({"match": "goal"})
az.loo(idata, var_name="goals")
Computed from 8000 by 480 log-likelihood matrix Estimate SE elpd_loo -716.46 17.38 p_loo 27.41 - ------ Pareto k diagnostic values: Count Pct. (-Inf, 0.5] (good) 480 100.0% (0.5, 0.7] (ok) 0 0.0% (0.7, 1] (bad) 0 0.0% (1, Inf) (very bad) 0 0.0% The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if you rely on a specific value. A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
The last example covered here is estimating the predictive accuracy at the group level. This can be useful to assess the accuracy of predicting the whole season of a new team; in addition, it can also be used to evaluate the hierarchical part of the model.
Although theoretically possible, importance sampling tends to fail at the group level because the excluded observations are too informative. See this post for more details.
In this situation, we could describe the cross-validation as excluding one team at a time. When we exclude a team, we exclude all the matches played by that team: not only the goals scored by the team but the whole match. Here is the illustration:
In the first column, we are excluding "Levante U.D.", which in the rows shown appears only once. In the second one, we are excluding "Athletic Club", which appears twice. This goes on following the order of appearance in the away team column.
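The per-team aggregation below relies on xarray's ability to group a DataArray by another, named DataArray. A toy sketch of the pattern, with hypothetical team indices and values:

```python
import numpy as np
import xarray as xr

# toy pointwise log likelihood over 6 matches
toy_ll = xr.DataArray(np.arange(6.0), dims="match")
# home-team index of each match (3 hypothetical teams: 0, 1, 2)
home_team = xr.DataArray([0, 1, 0, 2, 1, 0], dims="match", name="home_team")

# group the per-match values by team, sum within each group,
# then rename the resulting dimension
per_team = toy_ll.groupby(home_team).sum().rename({"home_team": "team"})
print(per_team.values)  # [7. 5. 3.]  (team 0: 0+2+5, team 1: 1+4, team 2: 3)
```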
groupby_sum_home = log_lik.groupby(const.home_team).sum().rename({"home_team": "team"})
groupby_sum_away = log_lik.groupby(const.away_team).sum().rename({"away_team": "team"})
log_lik["teams_match"] = (
groupby_sum_home.home_goals + groupby_sum_home.away_goals +
groupby_sum_away.home_goals + groupby_sum_away.away_goals
)
az.loo(idata, var_name="teams_match")
/home/oriol/venvs/arviz-dev/lib/python3.6/site-packages/arviz/stats/stats.py:683: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
Computed from 8000 by 16 log-likelihood matrix Estimate SE elpd_loo -1436.08 17.96 p_loo 50.30 - There has been a warning during the calculation. Please check the results. ------ Pareto k diagnostic values: Count Pct. (-Inf, 0.5] (good) 0 0.0% (0.5, 0.7] (ok) 0 0.0% (0.7, 1] (bad) 13 81.2% (1, Inf) (very bad) 3 18.8% The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if you rely on a specific value. A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
# alternative grouping: sum, for each team, only the log likelihood of the goals
# scored *by* that team; not sure this predictive task would make any sense though
home_goals_team = log_lik.home_goals.groupby(const.home_team).sum().rename({"home_team": "team"})
away_goals_team = log_lik.away_goals.groupby(const.away_team).sum().rename({"away_team": "team"})
log_lik["teams"] = home_goals_team + away_goals_team
az.loo(idata, var_name="teams")
/home/oriol/venvs/arviz-dev/lib/python3.6/site-packages/arviz/stats/stats.py:683: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.7 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations. "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
Computed from 8000 by 16 log-likelihood matrix Estimate SE elpd_loo -718.41 27.93 p_loo 25.65 - There has been a warning during the calculation. Please check the results. ------ Pareto k diagnostic values: Count Pct. (-Inf, 0.5] (good) 0 0.0% (0.5, 0.7] (ok) 3 18.8% (0.7, 1] (bad) 11 68.8% (1, Inf) (very bad) 2 12.5% The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if you rely on a specific value. A higher log-score (or a lower deviance) indicates a model with better predictive accuracy.
%load_ext watermark
%watermark -n -u -v -iv -w
xarray 0.16.0 pandas 1.0.5 arviz 0.9.0 pymc3 3.9.2 numpy 1.19.0 last updated: Mon Jul 13 2020 CPython 3.6.9 IPython 7.16.1 watermark 2.0.2