#!/usr/bin/env python # coding: utf-8 # # Linear regression with SVI # In[1]: # Imports get_ipython().run_line_magic('matplotlib', 'inline') get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'svg'") # In[2]: import sys import warnings import numpy as np import jax import jax.numpy as jnp from jax.experimental import optimizers import numpyro from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive import numpyro.distributions as dist import matplotlib import matplotlib.pyplot as plt import seaborn as sns # In[3]: sns.set_style('darkgrid') # In[4]: numpyro.set_platform('cpu') numpyro.set_host_device_count(8) # In[5]: rng_key = jax.random.PRNGKey(42) # In[6]: ground_truth_params = { "slope" : 2.32, "intercept": 4.11, "noise_std": 0.5 } # # Create Dataset # In[7]: # Define the data rng_key, rng_key_ = jax.random.split(rng_key) np.random.seed(rng_key_) # Generate random data n = 51 # Number of samples # Underlying linear relation slope_true = ground_truth_params["slope"] intercept_true = ground_truth_params["intercept"] fn = lambda x_: x_ * slope_true + intercept_true # Noise err = ground_truth_params["noise_std"] * np.random.randn(n) # Noise # Features and output x_data = np.random.uniform(-1, 1, n) # Independent variable x y_data = fn(x_data) + err # Dependent variable # Show data plt.figure(figsize=(7, 4), dpi=100) plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue') plt.plot( [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-', label=f'$y = {intercept_true:.2f} + {slope_true:.2f} x$') plt.xlim((-1, 1)) plt.xlabel('$x$') plt.ylabel('$y$') plt.title('Noisy data samples from linear line') plt.legend() plt.show() # # # Define model and variational distribution # $$ # \mu_i = \text{intercept} + \text{slope} * x_i \\ # y_i \sim \mathcal{N}(\mu_i, \sigma) \quad (i = 1, \ldots, n) # $$ # In[8]: def model(x, y): slope = numpyro.sample('slope', dist.Normal(0., 10.)) intercept = numpyro.sample('intercept', dist.Normal(0., 10.)) noise_std = numpyro.sample('noise_std', dist.Exponential(1.)) with numpyro.plate('obs', x.shape[0]): y_loc = numpyro.deterministic('y_loc', intercept + slope * x) numpyro.sample('y', dist.Normal(y_loc, noise_std), obs=y) # In[9]: def guide(x, y): slope_loc = numpyro.param("slope_loc", 0.) slope_scale = numpyro.param("slope_scale", 0.01, constraint=dist.constraints.positive) slope = numpyro.sample('slope', dist.Normal(slope_loc, slope_scale)) intercept_loc = numpyro.param("intercept_loc", 0.) intercept_scale = numpyro.param("intercept_scale", 0.01, constraint=dist.constraints.positive) intercept = numpyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale)) noise_std_rate = numpyro.param("noise_std_rate", 1., constraint=dist.constraints.positive) noise_std = numpyro.sample('noise_std', dist.Exponential(noise_std_rate)) # ## Fit SVI # In[10]: # Learning rate schedule def cosine_annealing(lr_min, lr_max, num_steps, i): return lr_min + 0.5 * (lr_max - lr_min) * (1 + jnp.cos(jnp.pi * i / num_steps)) num_steps = 5000 lr_max = 2e-3 lr_min = 1e-4 iterations = jnp.arange(num_steps) lr_steps = cosine_annealing(lr_min, lr_max, num_steps, iterations) def lr_schedule(idx): return lr_steps[idx] # In[11]: get_ipython().run_cell_magic('time', '', '\n# Use clipped Optimizer to deal with unstable gradients\n# http://num.pyro.ai/en/stable/optimizers.html#clippedadam\noptimizer = numpyro.optim.ClippedAdam(step_size=lr_schedule, clip_norm=1.0)\n\n# setup the inference algorithm\nsvi = SVI(\n model=model,\n guide=guide,\n optim=optimizer,\n loss=TraceMeanField_ELBO()\n)\n\n# Run\nsvi_result = svi.run(\n jax.random.PRNGKey(0),\n num_steps=5000,\n x=x_data,\n y=y_data\n)\n') # In[12]: fig, ax = plt.subplots(1, 1, figsize=(7, 3)) ax.plot(svi_result.losses) ax.set_title("losses") ax.set_yscale("symlog") plt.show() # In[13]: svi_predictive = Predictive( guide, params=svi_result.params, num_samples=2000 ) posterior_samples = svi_predictive( jax.random.PRNGKey(0), x=x_data, y=y_data ) # In[14]: fig, axs = plt.subplots(1, len(posterior_samples), figsize=(12, 4)) for ax, (param_name, param_samples) in zip(axs, posterior_samples.items()): d = sns.histplot(param_samples, kde=True, stat='probability', ax=ax) ax.set_xlabel(param_name) ax.set_title(f"Samples from {param_name!r}") ax.axvline(np.mean(param_samples), color="black", label="mean") ax.axvline( ground_truth_params[param_name], color="red", label="true") fig.legend(*ax.get_legend_handles_labels(), bbox_to_anchor=(0., 0.7, 1.13, -.0)) plt.show() # In[15]: for param_name, param_samples in posterior_samples.items(): param_gt = ground_truth_params[param_name] param_mean = np.mean(param_samples) param_std = np.std(param_samples) param_median = np.median(param_samples) param_quantile_low, param_quantile_high = np.quantile(param_samples, (.025, .975)) print(f"{param_name}: true={param_gt:+.2f} \tmedian={param_median:+.2f} \t95%-interval:{param_quantile_low:+.2f} - {param_quantile_high:+.2f} \t" f"mean:{param_mean:+.2f}±{param_std:.2f}") # In[16]: mean_slope = np.mean(posterior_samples["slope"]) mean_intercept = np.mean(posterior_samples["intercept"]) y_mean_pred = jnp.array([-1., 1]) * mean_slope + mean_intercept # Show mean fit vs data plt.figure(figsize=(7, 4), dpi=100) plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue') plt.plot( [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-', label='true' ) plt.plot( [-1, 1], y_mean_pred, color='red', linestyle='-', label=f'$pred = {mean_intercept:.2f} + {mean_slope:.2f} x$' ) # plt.scatter(x_data, y_mean_pred, label='predicted$', color='tab:red') plt.xlim((-1, 1)) plt.xlabel('$x$') plt.ylabel('$y$') plt.title('Mean fit vs ground-truth data') plt.legend() plt.show() #