# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import sys
import warnings
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import optimizers
import numpyro
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive
import numpyro.distributions as dist
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Use seaborn's 'darkgrid' theme for the figures.
sns.set_style('darkgrid')
# Select the CPU platform for NumPyro/JAX.
numpyro.set_platform('cpu')
# Ask NumPyro to expose 8 host (CPU) devices.
numpyro.set_host_device_count(8)
# Root JAX PRNG key (seed 42).
rng_key = jax.random.PRNGKey(42)
# Parameters of the synthetic linear model: y = slope * x + intercept + noise,
# where the noise is Gaussian with standard deviation `noise_std`.
ground_truth_params = dict(
    slope=2.32,
    intercept=4.11,
    noise_std=0.5,
)
# Build the synthetic dataset.
# Split off a fresh JAX subkey and reuse it as the seed of NumPy's global RNG.
rng_key, rng_key_ = jax.random.split(rng_key)
np.random.seed(rng_key_)

n = 51  # Number of samples

# Noise-free linear relation between x and y
slope_true = ground_truth_params["slope"]
intercept_true = ground_truth_params["intercept"]


def fn(x_):
    """Evaluate the noise-free ground-truth line at ``x_``."""
    return x_ * slope_true + intercept_true


# The noise is drawn before the features; keep this order so the same seed
# reproduces the same dataset.
err = ground_truth_params["noise_std"] * np.random.randn(n)  # Gaussian noise
x_data = np.random.uniform(-1, 1, n)  # Independent variable x
y_data = fn(x_data) + err  # Dependent variable
# Scatter the noisy samples and overlay the noise-free ground-truth line.
x_line = [-1, 1]
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, color='tab:blue', label='data: $(x,y)$')
plt.plot(
    x_line,
    [fn(x_end) for x_end in x_line],
    linestyle='-',
    color='black',
    label=f'$y = {intercept_true:.2f} + {slope_true:.2f} x$',
)
plt.xlim(-1, 1)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Noisy data samples from linear line')
plt.legend()
plt.show()
#
def model(x, y):
    """Bayesian linear regression: ``y ~ Normal(intercept + slope * x, noise_std)``.

    Priors are ``Normal(0, 10)`` on ``slope`` and ``intercept`` and
    ``Exponential(1)`` on ``noise_std``.

    Args:
        x: Feature values; its leading dimension sets the size of the ``obs`` plate.
        y: Observed targets, passed as ``obs`` to the ``y`` sample site.
    """
    # Priors (the order of the sample statements is kept as-is)
    weight = numpyro.sample('slope', dist.Normal(0., 10.))
    bias = numpyro.sample('intercept', dist.Normal(0., 10.))
    sigma = numpyro.sample('noise_std', dist.Exponential(1.))

    # Likelihood over the observations
    n_obs = x.shape[0]
    with numpyro.plate('obs', n_obs):
        y_loc = numpyro.deterministic('y_loc', bias + weight * x)
        numpyro.sample('y', dist.Normal(y_loc, sigma), obs=y)
def guide(x, y):
    """Mean-field variational guide for ``model``.

    ``slope`` and ``intercept`` get independent Normal factors with learnable
    location and (positive) scale. ``noise_std`` gets a LogNormal factor with
    learnable location and (positive) scale in log space.

    An ``Exponential(rate)`` factor for ``noise_std`` is not usable here: it
    has non-zero density at 0 for every rate, so the expectation of
    ``1 / noise_std**2`` under the guide diverges and the expected Gaussian
    log-likelihood in the ELBO is unbounded. In practice that shows up as
    exploding losses/gradients. A LogNormal has finite inverse moments and
    can concentrate away from zero.

    Args:
        x: Unused; present so the guide has the same signature as ``model``.
        y: Unused; present so the guide has the same signature as ``model``.
    """
    slope_loc = numpyro.param("slope_loc", 0.)
    slope_scale = numpyro.param(
        "slope_scale", 0.01, constraint=dist.constraints.positive)
    numpyro.sample('slope', dist.Normal(slope_loc, slope_scale))

    intercept_loc = numpyro.param("intercept_loc", 0.)
    intercept_scale = numpyro.param(
        "intercept_scale", 0.01, constraint=dist.constraints.positive)
    numpyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))

    # Location/scale of log(noise_std); loc=0 starts the median at 1.
    noise_std_loc = numpyro.param("noise_std_loc", 0.)
    noise_std_scale = numpyro.param(
        "noise_std_scale", 0.1, constraint=dist.constraints.positive)
    numpyro.sample('noise_std', dist.LogNormal(noise_std_loc, noise_std_scale))
# Learning rate schedule
def cosine_annealing(lr_min, lr_max, num_steps, i):
    """Cosine-annealed learning rate at step ``i``.

    The value equals ``lr_max`` at ``i = 0`` and decays along half a cosine
    period towards ``lr_min`` at ``i = num_steps``.

    Args:
        lr_min: Lower bound of the learning rate.
        lr_max: Upper bound of the learning rate.
        num_steps: Number of steps spanning the half cosine period.
        i: Step index (scalar or array).

    Returns:
        The learning rate(s) for ``i``.
    """
    half_range = 0.5 * (lr_max - lr_min)
    cosine_term = 1 + jnp.cos(jnp.pi * i / num_steps)
    return lr_min + half_range * cosine_term
# Number of optimisation steps and learning-rate bounds
num_steps = 5000
lr_min = 1e-4
lr_max = 2e-3

# Precompute the learning rate for every step index in [0, num_steps)
iterations = jnp.arange(num_steps)
lr_steps = cosine_annealing(
    lr_min=lr_min, lr_max=lr_max, num_steps=num_steps, i=iterations)


def lr_schedule(idx):
    """Look up the precomputed learning rate for step ``idx``."""
    return lr_steps[idx]
%%time
# Use clipped Optimizer to deal with unstable gradients
# http://num.pyro.ai/en/stable/optimizers.html#clippedadam
optimizer = numpyro.optim.ClippedAdam(step_size=lr_schedule, clip_norm=1.0)
# Set up the inference algorithm
svi = SVI(
    model=model,
    guide=guide,
    optim=optimizer,
    loss=TraceMeanField_ELBO()
)
# Run for exactly as many steps as the learning-rate table was built for;
# reusing `num_steps` keeps the schedule and the optimisation loop in sync.
svi_result = svi.run(
    jax.random.PRNGKey(0),
    num_steps=num_steps,
    x=x_data,
    y=y_data
)
100%|█████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:01<00:00, 3510.10it/s, init loss: 8736.8174, avg. loss [4751-5000]: 2151657.2500]
CPU times: user 2.14 s, sys: 80.6 ms, total: 2.22 s Wall time: 2.18 s
# Plot the per-step SVI losses on a symmetric-log y-axis.
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(svi_result.losses)
ax.set(title="losses", yscale="symlog")
plt.show()
# Sample the latent variables from the guide, using the optimised
# variational parameters.
num_posterior_samples = 2000
svi_predictive = Predictive(
    guide, params=svi_result.params, num_samples=num_posterior_samples)

predictive_key = jax.random.PRNGKey(0)
posterior_samples = svi_predictive(predictive_key, x=x_data, y=y_data)
# One histogram panel per sampled parameter, with the sample mean (black)
# and the ground-truth value (red) marked as vertical lines.
fig, axs = plt.subplots(1, len(posterior_samples), figsize=(12, 4))
# plt.subplots returns a bare Axes instead of an array when there is only a
# single panel; normalise so the zip below works in both cases.
axs = np.atleast_1d(axs)
for ax, (param_name, param_samples) in zip(axs, posterior_samples.items()):
    sns.histplot(param_samples, kde=True, stat='probability', ax=ax)
    ax.set_xlabel(param_name)
    ax.set_title(f"Samples from {param_name!r}")
    ax.axvline(np.mean(param_samples), color="black", label="mean")
    ax.axvline(ground_truth_params[param_name], color="red", label="true")
# Every panel carries the same two labelled lines, so the legend entries are
# taken from the last panel drawn.
fig.legend(*ax.get_legend_handles_labels(), bbox_to_anchor=(0., 0.7, 1.13, -.0))
plt.show()
# Print, per parameter: true value, median, central 95% interval and mean±std
# of the posterior samples.
for param_name, param_samples in posterior_samples.items():
    true_value = ground_truth_params[param_name]
    q_low, q_high = np.quantile(param_samples, (.025, .975))
    median = np.median(param_samples)
    mean, std = np.mean(param_samples), np.std(param_samples)
    summary = (
        f"{param_name}: true={true_value:+.2f} \tmedian={median:+.2f} "
        f"\t95%-interval:{q_low:+.2f} - {q_high:+.2f} \t"
        f"mean:{mean:+.2f}±{std:.2f}"
    )
    print(summary)
intercept: true=+4.11 median=+4.18 95%-interval:+3.93 - +4.43 mean:+4.18±0.13 noise_std: true=+0.50 median=+0.80 95%-interval:+0.02 - +3.97 mean:+1.12±1.12 slope: true=+2.32 median=+2.21 95%-interval:+1.73 - +2.71 mean:+2.21±0.25
# Regression line given by the posterior means of slope and intercept,
# evaluated at the two ends of the x-range.
mean_slope = np.mean(posterior_samples["slope"])
mean_intercept = np.mean(posterior_samples["intercept"])
y_mean_pred = jnp.array([-1., 1]) * mean_slope + mean_intercept

# Compare the mean fit against the data and the ground-truth line.
x_ends = [-1, 1]
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, color='tab:blue', label='data: $(x,y)$')
plt.plot(
    x_ends,
    [fn(x_end) for x_end in x_ends],
    linestyle='-',
    color='black',
    label='true',
)
plt.plot(
    x_ends,
    y_mean_pred,
    linestyle='-',
    color='red',
    label=f'$pred = {mean_intercept:.2f} + {mean_slope:.2f} x$',
)
plt.xlim(-1, 1)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Mean fit vs ground-truth data')
plt.legend()
plt.show()
#