# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import sys
import warnings
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive
import numpyro.distributions as dist
from numpyro import handlers
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)
ground_truth_params = {
    "slope": 2.32,
    "intercept": 4.11,
    "noise_std": 0.5
}
# Define the data
np.random.seed(42)
# Generate random data
n = 51 # Number of samples
# Linear relation
slope_true = ground_truth_params["slope"]
intercept_true = ground_truth_params["intercept"]
fn = lambda x_: x_ * slope_true + intercept_true
# Gaussian observation noise
err = ground_truth_params["noise_std"] * np.random.randn(n)
# Features and output
x_data = np.random.uniform(-1, 1, n) # Independent variable x
y_data = fn(x_data) + err # Dependent variable
# Show data
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')
plt.plot(
    [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-',
    label=f'$y = {intercept_true:.2f} + {slope_true:.2f} x$')
plt.xlim((-1, 1))
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Noisy samples from a linear function')
plt.legend()
plt.show()
# Define the probabilistic model and the variational guide
def model(x, y):
    # Priors over the regression parameters and the observation noise
    slope = numpyro.sample('slope', dist.Normal(0., 10.))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    noise_std = numpyro.sample('noise_std', dist.Exponential(1.))
    # Likelihood: observations are conditionally independent given the parameters
    with numpyro.plate('obs', x.shape[0]):
        y_loc = numpyro.deterministic('y_loc', intercept + slope * x)
        numpyro.sample('y', dist.Normal(y_loc, noise_std), obs=y)
def guide(x, y):
    # Variational posterior over the slope: a Normal with learnable location and scale
    slope_loc = numpyro.param("slope_loc", 0.)
    slope_scale = numpyro.param("slope_scale", 0.01, constraint=dist.constraints.positive)
    slope = numpyro.sample('slope', dist.Normal(slope_loc, slope_scale))
    # Variational posterior over the intercept
    intercept_loc = numpyro.param("intercept_loc", 0.)
    intercept_scale = numpyro.param("intercept_scale", 0.01, constraint=dist.constraints.positive)
    intercept = numpyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))
    # Variational posterior over the noise: a LogNormal keeps noise_std positive
    noise_std_log_loc = numpyro.param("noise_std_log_loc", 0.1)
    noise_std_scale = numpyro.param("noise_std_scale", 0.01, constraint=dist.constraints.positive)
    noise_std = numpyro.sample('noise_std', dist.LogNormal(noise_std_log_loc, noise_std_scale))
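# A quick prior predictive check (an optional sketch, not part of the original
# flow): sampling from the model without conditioning on y shows what the
# priors imply before any data is seen.
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(jax.random.PRNGKey(1), x=x_data, y=None)
print(prior_predictions['y'].shape)  # (100, n): one draw of y per prior sample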
# Learning rate schedule
def cosine_annealing(lr_min, lr_max, num_steps, i):
    # Decays from lr_max at step 0 down to lr_min at step num_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + jnp.cos(jnp.pi * i / num_steps))
num_steps = 5000
lr_max = 2e-3
lr_min = 1e-4
iterations = jnp.arange(num_steps)
lr_steps = cosine_annealing(lr_min, lr_max, num_steps, iterations)
def lr_schedule(idx):
    return lr_steps[idx]
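# Sanity check (an added line): the schedule should start at lr_max and decay
# to roughly lr_min by the final step.
print(lr_schedule(0), lr_schedule(num_steps - 1))  # ~2e-3 and ~1e-4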
# Use clipped Optimizer to deal with unstable gradients
# http://num.pyro.ai/en/stable/optimizers.html#clippedadam
optimizer = numpyro.optim.ClippedAdam(step_size=lr_schedule, clip_norm=1.0)
# setup the inference algorithm
svi = SVI(
    model=model,
    guide=guide,
    optim=optimizer,
    loss=TraceMeanField_ELBO(num_particles=1)
)
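# Aside (not in the original): TraceMeanField_ELBO exploits the fully factorized
# guide to compute KL terms analytically where possible, which lowers gradient
# variance. The plain Trace_ELBO imported above would also work, e.g.:
#   svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=8))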
# Run the optimization loop
svi_result = svi.run(
    jax.random.PRNGKey(0),
    num_steps=num_steps,
    x=x_data,
    y=y_data
)
100%| 5000/5000 [00:01<00:00, 3625.13it/s, init loss: 448.5081, avg. loss [4751-5000]: 45.7353]
fig, ax = plt.subplots(1, 1, figsize=(7, 3))
ax.plot(svi_result.losses)
ax.set_title("losses")
ax.set_yscale("symlog")
plt.show()
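# The fitted variational parameters can also be read off directly (an added
# check): the *_loc values are the means of the approximate posteriors.
for name, value in sorted(svi_result.params.items()):
    print(f"{name:>20}: {float(value):.4f}")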
# Sample from the fitted guide, i.e. the approximate posterior
svi_predictive = Predictive(
    guide,
    params=svi_result.params,
    num_samples=2000
)
posterior_samples = svi_predictive(
    jax.random.PRNGKey(0),
    x=x_data,
    y=y_data
)
fig, axs = plt.subplots(1, len(posterior_samples), figsize=(12, 4))
for ax, (param_name, param_samples) in zip(axs, posterior_samples.items()):
    sns.histplot(param_samples, kde=True, stat='probability', ax=ax)
    ax.set_xlabel(param_name)
    ax.set_title(f"Samples from {param_name!r}")
    ax.axvline(np.mean(param_samples), color="black", label="mean")
    ax.axvline(ground_truth_params[param_name], color="red", label="true")
fig.legend(*ax.get_legend_handles_labels(), bbox_to_anchor=(0., 0.7, 1.0, -.0))
plt.show()
for param_name, param_samples in posterior_samples.items():
    param_gt = ground_truth_params[param_name]
    param_mean = np.mean(param_samples)
    param_std = np.std(param_samples)
    param_median = np.median(param_samples)
    param_quantile_low, param_quantile_high = np.quantile(param_samples, (.025, .975))
    print(f"{param_name:>13}: true={param_gt:.2f}\t median={param_median:.2f}\t "
          f"95%-interval: {param_quantile_low:+.2f} - {param_quantile_high:+.2f}\t "
          f"mean:{param_mean:.2f}±{param_std:.2f}")
    intercept: true=4.11	 median=4.00	 95%-interval: +3.82 - +4.19	 mean:4.00±0.10
    noise_std: true=0.50	 median=0.49	 95%-interval: +0.37 - +0.63	 mean:0.49±0.07
        slope: true=2.32	 median=2.30	 95%-interval: +1.97 - +2.65	 mean:2.30±0.17
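# The guide is fully factorized (mean-field), so it cannot represent
# correlations between parameters; the sample correlation (an added check)
# should be near zero by construction.
corr = np.corrcoef(np.stack([posterior_samples['slope'], posterior_samples['intercept']]))
print(f"slope/intercept sample correlation: {corr[0, 1]:.3f}")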
mean_slope = np.mean(posterior_samples["slope"])
mean_intercept = np.mean(posterior_samples["intercept"])
y_mean_pred = jnp.array([-1., 1]) * mean_slope + mean_intercept
# Show mean fit vs data
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')
plt.plot(
    [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-',
    label='true'
)
plt.plot(
    [-1, 1], y_mean_pred, color='red', linestyle='-',
    label=f'$pred = {mean_intercept:.2f} + {mean_slope:.2f} x$'
)
plt.xlim((-1, 1))
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Mean fit vs ground-truth data')
plt.legend()
plt.show()
def plot_predictions(x_samples, predictions, name):
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 8))
    # Plot the distribution of the mean parameter mu_y
    y_mu_mean = jnp.mean(predictions['y_loc'], 0)
    y_mu_pct = jnp.percentile(predictions['y_loc'], q=np.array([5., 95., 0.5, 99.5]), axis=0)
    for i in range(min(10, predictions['y_loc'].shape[0])):
        yi = predictions['y_loc'][i]
        label = 'samples' if i == 0 else None
        ax1.plot(x_samples, yi, color='tab:gray', linestyle='-', alpha=0.5, label=label)
    ax1.plot(x_samples, y_mu_mean, color='tab:blue', linestyle='-', label=r'mean($\mu_y$)', linewidth=2)
    ax1.fill_between(x_samples, y_mu_pct[0], y_mu_pct[1], color='tab:blue', alpha=0.2, label=r'$\mu_y \; 90\%$')
    ax1.fill_between(x_samples, y_mu_pct[2], y_mu_pct[3], color='tab:blue', alpha=0.1, label=r'$\mu_y \; 99\%$')
    ax1.set_xlim((-1, 1))
    ax1.set_xlabel('$x$')
    ax1.set_ylabel('$y$')
    ax1.set_title(f'{name} parameter distribution')
    ax1.legend(loc='lower right')
    # Plot the predictive distribution of y
    y_mean = jnp.mean(predictions['y'], 0)
    y_pct = jnp.percentile(predictions['y'], q=np.array([5., 95., 0.5, 99.5]), axis=0)
    # Plot samples
    for i in range(min(100, predictions['y'].shape[0])):
        yi = predictions['y'][i]
        label = 'samples' if i == 0 else None
        ax2.plot(x_samples, yi, color='tab:blue', marker='o', alpha=0.03, label=label)
    ax2.plot(x_samples, y_mean, 'k-', label='mean($y$)')
    ax2.fill_between(x_samples, y_pct[0], y_pct[1], color='k', alpha=0.2, label=r'$y \; 90\%$')
    ax2.fill_between(x_samples, y_pct[2], y_pct[3], color='k', alpha=0.1, label=r'$y \; 99\%$')
    ax2.set_xlim((-1, 1))
    ax2.set_xlabel('$x$')
    ax2.set_ylabel('$y$')
    ax2.set_title(f'{name} predictive distribution')
    ax2.legend(loc='lower right')
    plt.tight_layout()
    plt.show()
# Get posterior predictive samples
# https://forum.pyro.ai/t/svi-version-of-mcmc-get-samples/3069/4
posterior_predictive = Predictive(
    model=model,
    guide=guide,
    params=svi_result.params,
    num_samples=1000
)
x_samples = np.linspace(-1, 1, 100)
posterior_predictions = posterior_predictive(jax.random.PRNGKey(42), x=x_samples, y=None)
plot_predictions(x_samples, posterior_predictions, 'Posterior')
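# Cross-check (an added sketch, not in the original): compare the variational
# fit against MCMC. set_host_device_count(8) above allows the chains to run
# in parallel on CPU.
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(jax.random.PRNGKey(2), x=x_data, y=y_data)
mcmc.print_summary()  # medians and intervals should roughly match the SVI table above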