#!/usr/bin/env python
# coding: utf-8

# # Causal Inference
# 
# *(Better displayed in [nbviewer](https://nbviewer.jupyter.org/) as red warnings in font tags may not be displayed on GitHub)*
# 
# This notebook runs a set of illustrative examples of causal inference taken from [Ferenc Huszar's blogpost](https://www.inference.vc/causal-inference-2-illustrating-interventions-in-a-toy-example/). For the explanation and interpretation of the results shown below, please refer to the excellent blogpost.

# ## Setup

# Importing libraries.

# In[1]:

import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Setting parameters for running inference algorithms.

# In[2]:

n_samples = 1000

# Defining a couple of support functions for visualization.

# In[3]:

def jointplot(x, y, color, title):
    g = sns.jointplot(x, y, color=color)
    g.annotate(stats.pearsonr)
    g.set_axis_labels(xlabel='x', ylabel='y')
    g.fig.suptitle(title)

def kdeplot(y, color, title):
    g = sns.kdeplot(y, color=color, label=title)

# ## Observational models

# We define the three basic programs/models as presented [here](https://www.inference.vc/causal-inference-2-illustrating-interventions-in-a-toy-example/) (see the blogpost for a clear representation of these models).

# In[4]:

def model1():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['x'], trace['y']

def model2():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(n_samples)
    return trace['x'], trace['y']

def model3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', z)
        trace = pm.sample(n_samples)
    return trace['x'], trace['y']

# We use PyMC3 to sample and plot the joint distribution $P(X,Y)$ for the three models.

# In[5]:

x, y = model1()
jointplot(x, y, color='blue', title='Observational P(X,Y) for model1')

x, y = model2()
jointplot(x, y, color='green', title='Observational P(X,Y) for model2')

x, y = model3()
jointplot(x, y, color='red', title='Observational P(X,Y) for model3')

# As expected, the three models show the same **observational** joint distribution. Indeed, in all three models $(X,Y)$ is jointly Gaussian with $E[X]=0$, $E[Y]=1$, $Var(X)=1$, $Var(Y)=4$ and $Cov(X,Y)=1$.

# ## Observational model under conditioning

# We now analyze the behaviour of these three models under the **observation** $X=3$. We again use PyMC3 to redefine the models with the conditioning $X=3$. (Notice that the models have been slightly reformulated in order to get rid of the *Deterministic* object, which cannot be conditioned on in PyMC3.)

# In[6]:

def model1_observe_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1, observed=3)
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']

def model2_observe_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        # the conditional noise of x given y is sqrt(3)/2, as in the structural equation of model2
        x = pm.Normal('x', mu=(y-1)/4.0, sd=np.sqrt(3)/2.0, observed=3)
        trace = pm.sample(n_samples)
    return trace['y']

def model3_observe_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1, observed=3)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', z)
        trace = pm.sample(n_samples)
    return trace['y']

# We again use PyMC3 to sample from these models and estimate $P(Y \vert X=3)$ for the three models.
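# As a sanity check (an analytic aside added here, not part of the original blogpost), standard Gaussian conditioning on the joint moments noted above gives $Y \vert X=3 \sim N(4, 3)$ for all three models. A minimal sketch of this closed-form computation, against which the MCMC estimates below can be compared:

# In[ ]:

# Illustrative addition (not in the original notebook): closed-form Gaussian conditioning
# E[Y|X=x] = mu_y + (cov_xy / var_x) * (x - mu_x),  Var[Y|X=x] = var_y - cov_xy**2 / var_x
mu_x, mu_y = 0.0, 1.0
var_x, var_y, cov_xy = 1.0, 4.0, 1.0
x_obs = 3.0
cond_mean = mu_y + cov_xy / var_x * (x_obs - mu_x)
cond_var = var_y - cov_xy**2 / var_x
print('Analytic P(Y | X=3): mean = {0}, variance = {1}'.format(cond_mean, cond_var))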
# In[7]:

y1 = model1_observe_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Observational P(X,Y | X=3) for model1')

y2 = model2_observe_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Observational P(X,Y | X=3) for model2')

y3 = model3_observe_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Observational P(X,Y | X=3) for model3')

plt.figure()
plt.title('Observational P(Y | X=3)')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()

# The **observational** conditional distributions $P(Y \vert X=3)$ of the three models are (approximately) the same. (Note that *model2* requires the correct conditional noise $\sqrt{3}/2$ for $X$ given $Y$; with a smaller value its estimate comes out slightly shifted.)

# ## Interventional model under do-action

# We now analyze the behaviour of these three models under the **intervention** $X=3$. We redefine the models and force $X=3$ (formally, this amounts to redefining the model under *mutilation*, with the *structural equation* for $X$ fixed to 3).

# In[8]:

def model1_do_X_3():
    with pm.Model():
        x = 3
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']

def model2_do_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = 3
        trace = pm.sample(n_samples)
    return trace['y']

def model3_do_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = 3
        trace = pm.sample(n_samples)
    return trace['y']

# We now sample from the interventional models and estimate $P(Y \vert do(X=3))$.

# In[9]:

y1 = model1_do_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Interventional P(X,Y | do(X=3)) for model1')

y2 = model2_do_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Interventional P(X,Y | do(X=3)) for model2')

y3 = model3_do_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Interventional P(X,Y | do(X=3)) for model3')

plt.figure()
plt.title('Interventional P(Y | do(X=3))')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()

# The **interventional** distributions $P(Y \vert do(X=3))$ are not the same anymore.

# ## Interventional model evaluated from the observational model via do-calculus

# We now evaluate the three models under the **intervention** $X=3$ relying on the **observational** model transformed via *do-calculus*. We redefine the models as those derived by *do-calculus* and then observe their behaviour under the *conditioning* $X=3$.

# In[10]:

def model1_docalculus_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1, observed=3)
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']

def model2_docalculus_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']

def model3_docalculus_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']

# We now sample from the models computed via *do-calculus* and estimate $P(Y \vert X=3)$.
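# As a reference point (an illustrative addition, not part of the original notebook), the interventional distributions can also be obtained by simulating the mutilated structural equations directly with NumPy, without MCMC: under $do(X=3)$, *model1* gives $Y \sim N(4,3)$, while *model2* and *model3* keep $Y \sim N(1,4)$ since $X$ is not a parent of $Y$ there. Both the MCMC estimates above and the do-calculus estimates below should be close to these. A minimal sketch:

# In[ ]:

# Illustrative addition (not in the original notebook): direct simulation of the mutilated SEMs
rng_n0 = np.random.randn(n_samples)   # exogenous noise n0 ~ N(0,1)
rng_z = np.random.randn(n_samples)    # exogenous variable z ~ N(0,1) (model3 only)
y_do_model1 = 3 + 1 + np.sqrt(3)*rng_n0        # x is clamped to 3 and feeds into y
y_do_model2 = 1 + 2*rng_n0                     # y does not depend on x
y_do_model3 = rng_z + 1 + np.sqrt(3)*rng_n0    # y depends on z only
for name, samples in [('model1', y_do_model1), ('model2', y_do_model2), ('model3', y_do_model3)]:
    print('{0}: mean = {1:.2f}, variance = {2:.2f}'.format(name, samples.mean(), samples.var()))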
# In[11]:

y1 = model1_docalculus_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Interventional P(X,Y | do(X=3)) for model1')

y2 = model2_docalculus_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Interventional P(X,Y | do(X=3)) for model2')

y3 = model3_docalculus_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Interventional P(X,Y | do(X=3)) for model3')

plt.figure()
plt.title('Interventional P(Y | do(X=3)) evaluated from observational data via do-calculus')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()

# The estimated distributions match the estimates of $P(Y \vert do(X=3))$ obtained above. In other words, we can estimate the **interventional** distribution $P(Y \vert do(X=3))$ as the conditional **observational** distribution $P(Y \vert X=3)$ on the model derived via *do-calculus*.

# ## Individual counterfactual

# We now examine individual **counterfactuals**. To do this, we sample from a model, and then we perform the **intervention** $X=3$ *while keeping everything else unmodified* (formally, this amounts to sampling from the model, keeping the values of the *exogenous nodes* of the *SEM*, performing the intervention, and then computing the counterfactual of interest).

# In[12]:

def model1_counterfactual_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(1, chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    counterfactual_x = 3
    # re-evaluate the structural equation for y with x fixed to 3 and n0 kept at its factual value
    counterfactual_y = counterfactual_x + 1 + np.sqrt(3)*factual_n0
    return factual_x, factual_y, counterfactual_x, counterfactual_y

def model2_counterfactual_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(1, chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    factual_n1 = trace['n1']
    # x is not a parent of y in model2, so the counterfactual value of y equals the factual one
    counterfactual_y = factual_y
    counterfactual_x = 3
    return factual_x, factual_y, counterfactual_x, counterfactual_y

def model3_counterfactual_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', z)
        trace = pm.sample(1, chains=1)
    factual_z = trace['z']
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    # z and n0 keep their factual values; x is not a parent of y, so y is unchanged
    counterfactual_y = factual_z + 1 + np.sqrt(3)*factual_n0
    counterfactual_x = 3
    return factual_x, factual_y, counterfactual_x, counterfactual_y

# We now run some simple simulations to evaluate how the values of the variables change.
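# The counterfactual functions above implement the usual *abduction-action-prediction* recipe: infer the exogenous noise from the factual sample, fix $X$ by intervention, and re-evaluate the structural equation for $Y$. A minimal sketch of the same recipe for *model1* in plain NumPy (an illustrative addition, not part of the original notebook; the factual pair is arbitrary):

# In[ ]:

# Illustrative addition (not in the original notebook): abduction-action-prediction for model1
factual_x_ex, factual_y_ex = 1.0, 2.5                          # an arbitrary factual pair (x, y)
n0_abducted = (factual_y_ex - factual_x_ex - 1) / np.sqrt(3)   # abduction: invert y = x + 1 + sqrt(3)*n0
x_star = 3.0                                                   # action: do(X=3)
y_star = x_star + 1 + np.sqrt(3)*n0_abducted                   # prediction: re-run the structural equation
print('Abducted n0: {0:.3f}, counterfactual y*: {1:.3f}'.format(n0_abducted, y_star))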
# In[13]:

factual_x, factual_y, counterfactual_x, counterfactual_y = model1_counterfactual_X_3()
print('MODEL1 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL1 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

factual_x, factual_y, counterfactual_x, counterfactual_y = model1_counterfactual_X_3()
print('MODEL1 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL1 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

factual_x, factual_y, counterfactual_x, counterfactual_y = model2_counterfactual_X_3()
print('MODEL2 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL2 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

factual_x, factual_y, counterfactual_x, counterfactual_y = model2_counterfactual_X_3()
print('MODEL2 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL2 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

factual_x, factual_y, counterfactual_x, counterfactual_y = model3_counterfactual_X_3()
print('MODEL3 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL3 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

factual_x, factual_y, counterfactual_x, counterfactual_y = model3_counterfactual_X_3()
print('MODEL3 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL3 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))

# Notice how the **counterfactual** value of $Y$ changes in *model1* with respect to its **factual** value; this means that, in *model1*, the value of $Y$ would change if we were to perform the intervention $X=3$ while keeping everything else the same.
# 
# By contrast, the **counterfactual** value of $Y$ does NOT change in *model2* and *model3* with respect to its **factual** value; this means that, in *model2* and *model3*, the value of $Y$ would NOT change if we were to perform the intervention $X=3$ while keeping everything else the same.
# 
# These results make sense when we consider that, under the intervention $do(X=3)$, the variables $X$ and $Y$ become independent in *model2* and *model3* (formally, this can be seen by performing the *mutilation* of the *SEM* graph). It is not surprising, then, that if we intervene on $X$, the value of $Y$ simply remains the same (because of its independence from $X$).

# ## Counterfactuals

# We now consider the distribution of **counterfactuals** $P(Y^* \vert X^*=3)$ (using the notation in the [blogpost](https://www.inference.vc/causal-inference-3-counterfactuals/)). We use the same approach as before.
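# As an added analytic aside (hedged on the Gaussian structure of the models, not part of the original notebook): since no factual evidence is conditioned on here, the exogenous noise keeps its prior distribution, so *model1* yields $Y^* = 3 + 1 + \sqrt{3}\,n_0 \sim N(4,3)$, while *model2* and *model3* yield $Y^* = Y \sim N(1,4)$; the KDE estimates below should approximate these distributions.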
# In[14]:

def model1_counterfactual_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y', x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples*2, chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    counterfactual_x = 3*np.ones(n_samples*2)
    # re-evaluate the structural equation for y with x fixed to 3 and n0 kept at its factual values
    counterfactual_y = counterfactual_x + 1 + np.sqrt(3)*factual_n0
    return factual_x, factual_y, counterfactual_x, counterfactual_y

def model2_counterfactual_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(n_samples*2, chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    factual_n1 = trace['n1']
    # x is not a parent of y in model2, so the counterfactual values of y equal the factual ones
    counterfactual_y = factual_y
    counterfactual_x = 3*np.ones(n_samples*2)
    return factual_x, factual_y, counterfactual_x, counterfactual_y

def model3_counterfactual_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y', z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', z)
        trace = pm.sample(n_samples*2, chains=1)
    factual_z = trace['z']
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    # z and n0 keep their factual values; x is not a parent of y, so y is unchanged
    counterfactual_y = factual_z + 1 + np.sqrt(3)*factual_n0
    counterfactual_x = 3*np.ones(n_samples*2)
    return factual_x, factual_y, counterfactual_x, counterfactual_y

# We sample and evaluate the *observational* distribution $P(Y \vert X)$ and the related *counterfactual* distribution $P(Y^* \vert X^*=3)$ under the intervention $do(X=3)$.

# In[15]:

factual_x1, factual_y1, counterfactual_x1, counterfactual_y1 = model1_counterfactual_X_3()
factual_x2, factual_y2, counterfactual_x2, counterfactual_y2 = model2_counterfactual_X_3()
factual_x3, factual_y3, counterfactual_x3, counterfactual_y3 = model3_counterfactual_X_3()

plt.figure()
plt.title('Observational P(Y|X)')
kdeplot(factual_y1, color='blue', title='model1')
kdeplot(factual_y2, color='green', title='model2')
kdeplot(factual_y3, color='red', title='model3')
plt.legend()

plt.figure()
plt.title('Counterfactual P(Y* | X*=3)')
kdeplot(counterfactual_y1, color='blue', title='model1')
kdeplot(counterfactual_y2, color='green', title='model2')
kdeplot(counterfactual_y3, color='red', title='model3')
plt.legend()

# Consistent with the previous result, the marginal distribution of $Y$ does not change between the **observational** and the **counterfactual** model for *model2* and *model3*, due to the fact that under intervention $X$ and $Y$ are independent.
# 
# For *model1*, instead, we do register a difference between the **observational** and the **counterfactual** distribution, due to the fact that intervening on $X$ does affect the outcome of $Y$.