This is a modified version of the Pyro tutorial on Gaussian Mixture Models.
import os
from collections import defaultdict
import torch
from matplotlib import pyplot
%matplotlib inline
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, init_to_uniform
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.0')
data = torch.tensor([
[0.2784, -0.3830, 0.8980],
[0.2784, -0.3830, 0.8980],
[0.2784, -0.3830, 0.8980],
[0.2784, -0.3830, 0.8980],
[0.2784, -0.3830, 0.8980],
[0.2581, 0.4620, -1.2788],
[0.2581, 0.4620, -1.2788],
[0.2581, 0.4620, -1.2788],
[0.2581, 0.4620, -1.2788],
[0.2581, 0.4620, -1.2788],
[1.0734, 0.4766, 0.5579],
[1.0734, 0.4766, 0.5579],
[1.0734, 0.4766, 0.5579],
[1.0734, 0.4766, 0.5579],
[1.0734, 0.4766, 0.5579]
]
)
Here is what the data points look like. Note that the 15 points form three groups of five identical points each, so a mixture with K = 3 components should recover them exactly.
ax = pyplot.axes(projection='3d')
# Data for three-dimensional scattered points
ax.scatter3D(data[:, 0], data[:, 1], data[:, 2], cmap='viridis');
K = 3 # Fixed number of components.
@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(1/K * torch.ones(K)))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.MultivariateNormal(torch.zeros(3), 100 * torch.eye(3)))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.MultivariateNormal(locs[assignment], 0.01 * torch.eye(3)), obs=data)
optim = pyro.optim.Adam({'lr': 0.05, 'betas': [0.95, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
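Because assignment is a discrete latent variable, TraceEnum_ELBO sums it out exactly in parallel rather than sampling it. To sanity-check the enumerated shapes, we can inspect a trace of the model; this is an optional sketch using Pyro's poutine.enum and Trace.format_shapes, where first_available_dim=-2 corresponds to max_plate_nesting=1 above.
# Optional sanity check: print how enumeration reshapes each sample site.
# With max_plate_nesting=1, the first dim free for enumeration is -2.
trace = poutine.trace(poutine.enum(model, first_available_dim=-2)).get_trace(data)
trace.compute_log_prob()
print(trace.format_shapes())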
def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs']),
                             init_loc_fn=init_to_uniform)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)
# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))
seed = 78, initial_loss = 500.1433410644531
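As an aside, AutoDelta learns a single point estimate (a Delta distribution) for each exposed latent site, which makes SVI perform MAP estimation. For intuition only, here is a rough hand-written equivalent of the guide above; the parameter names weights_map and locs_map are illustrative, and the initial values differ from init_to_uniform.
from pyro.distributions import constraints

def manual_guide(data):
    # One learnable point estimate (Delta) per global latent site.
    weights_map = pyro.param('weights_map', torch.ones(K) / K,
                             constraint=constraints.simplex)
    locs_map = pyro.param('locs_map', torch.randn(K, 3))
    pyro.sample('weights', dist.Delta(weights_map, event_dim=1))
    with pyro.plate('components', K):
        pyro.sample('locs', dist.Delta(locs_map, event_dim=1))
Passing manual_guide to SVI in place of global_guide would optimize the same MAP objective.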
# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))
losses = []
for i in range(400 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');
pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');
Here are the learned parameters:
map_estimates = global_guide(data)
weights = map_estimates['weights']
locs = map_estimates['locs']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
weights = [0.3333379  0.33333713 0.33332494]
locs = [[ 0.2581018   0.46195865 -1.2787634 ]
 [ 1.0733963   0.47656873  0.55789876]
 [ 0.27835828 -0.38301384  0.89797974]]
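Finally, the discrete cluster assignments can be recovered from the MAP estimates by replaying the model against the trained guide and applying pyro.infer.infer_discrete, following the pattern in the original Pyro GMM tutorial; this is a sketch, where temperature=0 selects the most likely component for each point.
from pyro.infer import infer_discrete

guide_trace = poutine.trace(global_guide).get_trace(data)  # record MAP values
trained_model = poutine.replay(model, trace=guide_trace)   # condition model on them
inferred_model = infer_discrete(trained_model, temperature=0,
                                first_available_dim=-2)
trace = poutine.trace(inferred_model).get_trace(data)
print('assignment = {}'.format(trace.nodes['assignment']['value']))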