from IPython.display import HTML
from IPython.display import display
display(HTML("<style>.container { width:70% !important; }</style>"))
%matplotlib inline
import numpy as np, scipy, scipy.stats as stats, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
import statsmodels, statsmodels.api as sm
import sympy, sympy.stats
import pymc3 as pm
import daft
import xarray, numba, arviz as az
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# pd.set_option('display.float_format', lambda x: '%.2f' % x)
np.set_printoptions(edgeitems=10)
np.set_printoptions(linewidth=1000)
np.set_printoptions(suppress=True)
np.core.arrayprint._line_width = 180
SEED = 42
np.random.seed(SEED)
sns.set()
import warnings
warnings.filterwarnings("ignore")
This blog post is part of the series Monte Carlo Methods.
You can find this blog post on weisser-zwerg.dev or on GitHub, either as HTML or via nbviewer.
I will not explain the details of why Markov chain Monte Carlo (MCMC) works, because other people have already done a very good job at it; have a look at the Resources section below. I personally found the following two books very helpful, because they explain the mechanics of MCMC in a discrete set-up, which at least for me helps in building an intuition.
All the other mentioned references are very helpful, too.
A Markov chain starts with an initial distribution over states $\pi_0(X)$. If you start from a single concrete state, this is a Dirac delta: $\pi_0(X) = \delta(X=x_0)$. In addition you have a so-called transition operator $\mathcal{T}(x\to x')$. This transition operator is simply a conditional probability in a different notation, $\mathcal{T}(x\to x')\equiv p(x'\,|\,x)$, and it is also called a kernel: $\mathcal{K}(x\to x')\equiv \mathcal{T}(x\to x')\equiv p(x'\,|\,x)$. If you apply the transition operator / kernel several times in sequence and marginalize over all earlier state variables, you get a sequence of distributions
$\pi_0(X)\to \pi_1(X)\to \dots \to \pi_N(X)$,
where each step is given by
$ \begin{array}{rcl} \pi_1(x')&=&\displaystyle\int\pi_0(x)\cdot p(x'\,|\,x)\,dx=\int\pi_0(x)\cdot \mathcal{T}(x\to x')\,dx=\int\pi_0(x)\cdot \mathcal{K}(x\to x')\,dx=\int p(x, x')\,dx \end{array} $
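To make this concrete, here is a minimal sketch on a toy 3-state chain (my own illustrative example, not from the references): in the discrete case, applying the transition operator just means multiplying the current distribution vector by the transition matrix.
# Toy 3-state transition matrix: row i holds p(x' | x=i); each row sums to 1.
T = np.array([[0.9, 0.1, 0.0],
              [0.2, 0.6, 0.2],
              [0.0, 0.3, 0.7]])
pi = np.array([1.0, 0.0, 0.0])  # pi_0: a Dirac delta on state 0
for _ in range(5):
    pi = pi @ T  # pi_{n+1}(x') = sum_x pi_n(x) T(x -> x')
    print(pi)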
Just as a recap, the following vocabulary is typically used for the conditions that have to hold to make MCMC work.
In the end the goal is to construct a Markov chain that has the target distribution as its stationary (aka invariant) / equilibrium distribution and that converges to this equilibrium distribution in the long run, no matter where you start.
Stationarity is defined by:
$ \begin{array}{rcl} \pi_S(X=x')&=&\displaystyle\sum_{x\in Val(X)}\pi_S(X=x)\mathcal{T}(x\to x')\\ &=&\displaystyle\int\pi_S(x)\mathcal{T}(x\to x')dx \end{array} $
The first line is for the discrete case and the second one for the continuous case. The subscript $S$ in $\pi_S$ stands for stationary.
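In the discrete case the stationary distribution is a left eigenvector of the transition matrix with eigenvalue 1. A quick numerical check on the toy matrix from above:
# Stationary distribution = left eigenvector of T with eigenvalue 1.
eigvals, eigvecs = np.linalg.eig(T.T)   # eigenvectors of T transposed = left eigenvectors of T
i = np.argmin(np.abs(eigvals - 1.0))    # pick the eigenvalue closest to 1
pi_s = np.real(eigvecs[:, i])
pi_s = pi_s / pi_s.sum()                # normalize to a probability vector
print(pi_s)
print(np.allclose(pi_s @ T, pi_s))      # stationarity: pi_s T = pi_s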
So when constructing a transition operator for a target distribution, you start by restricting yourself to transition operators that satisfy detailed balance with respect to that target distribution, because then the target distribution is guaranteed to be the stationary / invariant distribution.
The following equation is the detailed-balance equation:
$ \begin{array}{rcl} \displaystyle\pi_S(X=x)\mathcal{T}(x\to x')&=&\displaystyle\pi_S(X=x')\mathcal{T}(x'\to x)\\ \end{array} $
A Markov chain that respects detailed balance is said to be reversible. Reversibility implies that $\pi_S$ is a stationary distribution of $\mathcal{T}$: summing (or integrating) both sides of the detailed-balance equation over $x$ recovers the stationarity equation.
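A sketch of that implication on the toy chain (which, being a birth-death chain, is reversible by construction): check detailed balance entry-wise, then sum over $x$ to recover stationarity.
# Detailed balance, entry-wise: pi_s(x) T(x -> x') == pi_s(x') T(x' -> x).
lhs = pi_s[:, None] * T    # pi_s(x)  T(x -> x')
rhs = pi_s[None, :] * T.T  # pi_s(x') T(x' -> x)
print(np.allclose(lhs, rhs))
# Summing detailed balance over x gives back the stationarity equation:
print(np.allclose(lhs.sum(axis=0), pi_s))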
In addition we use homogeneity. A Markov chain is homogeneous if the transition operator is the same (constant) for every transition we make: $\mathcal{T}_1=\mathcal{T}_2=\dots=\mathcal{T}_N=\mathcal{T}$. It can be shown that a homogeneous Markov chain that possesses a stationary distribution (guaranteed via detailed balance) will converge to this single equilibrium distribution from any starting point, subject only to weak restrictions on the stationary distribution and the transition probabilities (Neal 1993).
In the discrete case, regularity (see Probabilistic Graphical Models: Principles and Techniques) plus detailed balance guarantees convergence to the stationary distribution.
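The convergence claim is easy to see numerically on the toy chain: iterate the same homogeneous kernel from two very different starting distributions and watch both approach $\pi_S$.
# Homogeneity + a well-behaved kernel: any starting point converges to pi_s.
for pi0 in (np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])):
    pi = pi0
    for _ in range(100):
        pi = pi @ T
    print(pi, np.allclose(pi, pi_s, atol=1e-6))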
My goal for this blog post is to look deeper into how to combine elements of algorithms in the sense of fine-grained composable abstractions (FCA), which does not get a lot of attention in the other resources.
A good starting point for that is to look at the detailed balance equation.
# Flip a coin 9 times: 6 heads (1) and 3 tails (0)
samples = np.array([0,0,0,1,1,1,1,1,1])
def fn(a, b):
    # Unnormalized posterior: Bernoulli likelihood of the samples times a Beta(a, b) prior.
    return lambda p: stats.bernoulli(p).pmf(samples).prod() * stats.beta(a, b).pdf(p)
# convert from omega, kappa parametrization of the beta distribution to the alpha, beta parametrization
def ok2ab(omega, kappa):
    return omega * (kappa - 2) + 1, (1 - omega) * (kappa - 2) + 1
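In this parametrization $\omega$ is the mode and $\kappa$ the concentration of the beta distribution; a quick sanity check:
# For Beta(a, b) with a, b > 1 the mode is (a-1)/(a+b-2), which recovers omega;
# a + b recovers kappa.
a, b = ok2ab(0.7, 20)
print(a, b, (a - 1) / (a + b - 2), a + b)  # -> 13.6 6.4 0.7 20.0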
@numba.jit(nopython=True)
def bernoulli(p, samples):
    # Manual Bernoulli log-pmf, summed over all samples (numba-compatible).
    r = np.zeros_like(samples, dtype=numba.float64)
    for i in range(len(r)):
        if samples[i] < 0.5:  # sample == 0
            r[i] = np.log(1 - p)
        else:                 # sample == 1
            r[i] = np.log(p)
    return np.sum(r)
bernoulli(0.001, samples)
# verify that our implementation delivers the same results as the scipy implementation
stats.bernoulli(0.001).logpmf(samples).sum()
@numba.jit(nopython=True)
def logpdf(p):
    # Log posterior up to a constant: with a flat Beta(1, 1) prior it equals the log likelihood.
    return bernoulli(p, samples)
@numba.jit(nopython=True)
def mcmc(q_rvs, unif_rvs, trace, logpdf):
    p = 0.5
    for i in range(len(trace)):
        rv = q_rvs[i]
        unifrand = unif_rvs[i]
        p_new = p + rv
        # Only correct because rv is drawn from a symmetric proposal distribution;
        # otherwise the Hastings correction q(x | x') / q(x' | x) would be missing.
        # For p_new outside [0, 1] the log pdf is nan (or -inf), so the proposal is rejected.
        log_hastings_ratio = logpdf(p_new) - logpdf(p)
        if log_hastings_ratio >= 0.0 or unifrand < np.exp(log_hastings_ratio):
            p = p_new
        trace[i] = p
    return trace
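The sampler above is plain Metropolis: because the Gaussian proposal is symmetric, the $q$ terms cancel. For an asymmetric proposal the log acceptance ratio needs the Hastings correction; here is a minimal, hypothetical sketch of that general form (not used below, for illustration only):
# General Metropolis-Hastings log acceptance ratio with an asymmetric proposal q.
# log_q(a, b) is assumed to return log q(b | a), the log density of proposing b from a.
def log_accept_ratio(logpdf, log_q, x, x_new):
    return (logpdf(x_new) - logpdf(x)
            + log_q(x_new, x)    # log q(x | x_new): density of moving back
            - log_q(x, x_new))   # log q(x_new | x): density of the proposed move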
N = 10000
# Use two independent seeds so the proposal and acceptance streams are uncorrelated.
q_rvs = stats.norm(0, 0.3).rvs(size=N, random_state=np.random.RandomState(SEED))
unif_rvs = stats.uniform.rvs(size=N, random_state=np.random.RandomState(SEED + 1))
trace = np.zeros(N)
mcmc(q_rvs, unif_rvs, trace, logpdf)
trace
datadict = {'p': trace}
dataset = az.convert_to_inference_data(datadict)
dataset
az.summary(dataset)
az.plot_trace(dataset)
def kdeplot(lds, parameter_name=None, x_min=None, x_max=None, ax=None, kernel='gau'):
    # Plot a univariate kernel density estimate of lds on the given (or a new) axis.
    if parameter_name is None and isinstance(lds, pd.Series):
        parameter_name = lds.name
    kde = sm.nonparametric.KDEUnivariate(lds)
    kde.fit(kernel=kernel, fft=False, gridsize=2**10)
    if x_min is None:
        x_min = kde.support.min()
    else:
        x_min = min(kde.support.min(), x_min)
    if x_max is None:
        x_max = kde.support.max()
    else:
        x_max = max(kde.support.max(), x_max)
    x = np.linspace(x_min, x_max, 100)
    y = kde.evaluate(x)
    if ax is None:
        plt.figure(figsize=(6, 3), dpi=80, facecolor='w', edgecolor='k')
        ax = plt.subplot(1, 1, 1)
    ax.plot(x, y, lw=2)
    ax.set_xlabel(parameter_name)
    ax.set_ylabel('Density')
plt.figure(figsize=(15, 8), dpi=80, facecolor='w', edgecolor='k')
ax = plt.subplot(1, 1, 1)
kdeplot(trace, x_min=0.0, x_max=1.0, ax=ax)
x = np.linspace(0.0, 1.0, 100)
# Exact posterior via Beta-Bernoulli conjugacy: Beta(1 + 6 heads, 1 + 3 tails).
y = stats.beta(1 + 6, 1 + 3).pdf(x)
ax.plot(x, y)
with pm.Model() as model:
    p = pm.Beta('p', 1.0, 1.0)
    yl = pm.Bernoulli('yl', p, observed=samples)
    prior = pm.sample_prior_predictive()
    posterior = pm.sample(return_inferencedata=True)
    posterior_pred = pm.sample_posterior_predictive(posterior)
pm.summary(posterior)
# Empirical CDF of the self-made MCMC trace: equal weights, sorted by p, cumulative sum.
ldf = pd.DataFrame(datadict)
ldf['w'] = 1.0 / len(ldf)
ldf = ldf.sort_values('p').set_index('p')
ldf['c'] = ldf['w'].cumsum()
ldf1 = ldf
ldf1
# Same empirical CDF construction for the first PyMC3 chain.
ldf = posterior['posterior']['p'].loc[dict(chain=0)].to_dataframe()
ldf = ldf.drop('chain', axis=1)
ldf['w'] = 1.0 / len(ldf)
ldf = ldf.sort_values('p').set_index('p')
ldf['c'] = ldf['w'].cumsum()
ldf2 = ldf
ldf2
plt.figure(figsize=(15, 8), dpi=80, facecolor='w', edgecolor='k')
ax = plt.subplot(1, 1, 1)
x = np.linspace(0.0,1.0,100)
y = stats.beta(1+6,1+3).cdf(x)
ax.plot(x,y, label='exact')
ldf1.loc[0.0:1.0, 'c'].plot(ax=ax, label='self-made MCMC')
ldf2.loc[0.0:1.0, 'c'].plot(ax=ax, label='PyMC3')
ax.legend()
# pip install numpyro[cpu]
import numpyro, numpyro.distributions, numpyro.infer
import jax
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)
def coin_flip_example(y=None):
    p = numpyro.sample('p', numpyro.distributions.Beta(1, 1))
    numpyro.sample('obs', numpyro.distributions.Bernoulli(p), obs=y)
nuts_kernel = numpyro.infer.NUTS(coin_flip_example)
sample_kwargs = dict(
    sampler=nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
    chain_method="parallel"
)
mcmc = numpyro.infer.MCMC(**sample_kwargs)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, y=samples, extra_fields=('potential_energy',))
mcmc.print_summary()
pyro_trace = az.from_numpyro(mcmc)
pyro_trace
az.summary(pyro_trace)
az.plot_trace(pyro_trace)
Text books:
Daphne Koller, Nir Friedman: Probabilistic Graphical Models: Principles and Techniques (MIT Press), cited above for the regularity condition.
Tutorial:
Blog posts:
Papers:
Radford M. Neal (1993): Probabilistic Inference Using Markov Chain Monte Carlo Methods, Technical Report CRG-TR-93-1, University of Toronto.