#!/usr/bin/env python
# coding: utf-8

# # Gaussian Mixture Model
#
# This is a modified version of the [Pyro tutorial for Gaussian Mixture Model](https://pyro.ai/examples/gmm.html).

# In[1]:

import os
from collections import defaultdict

import torch
from matplotlib import pyplot
get_ipython().run_line_magic('matplotlib', 'inline')

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, init_to_uniform
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

smoke_test = ('CI' in os.environ)  # shorten the run on CI
assert pyro.__version__.startswith('1.8.0')

# ## Overview
#
# We work with a tiny synthetic dataset: three distinct 3D points, each repeated five times.

# In[2]:

data = torch.tensor([
    [0.2784, -0.3830, 0.8980],
    [0.2784, -0.3830, 0.8980],
    [0.2784, -0.3830, 0.8980],
    [0.2784, -0.3830, 0.8980],
    [0.2784, -0.3830, 0.8980],
    [0.2581, 0.4620, -1.2788],
    [0.2581, 0.4620, -1.2788],
    [0.2581, 0.4620, -1.2788],
    [0.2581, 0.4620, -1.2788],
    [0.2581, 0.4620, -1.2788],
    [1.0734, 0.4766, 0.5579],
    [1.0734, 0.4766, 0.5579],
    [1.0734, 0.4766, 0.5579],
    [1.0734, 0.4766, 0.5579],
    [1.0734, 0.4766, 0.5579],
])

# Here is what the data points look like:

# In[3]:

ax = pyplot.axes(projection='3d')
# Data for three-dimensional scattered points.
ax.scatter3D(data[:, 0], data[:, 1], data[:, 2]);

# ## Training a MAP estimator

# In[4]:

K = 3  # Fixed number of components.

@config_enumerate
def model(data):
    # Global variables: mixture weights and a 3D location for each component.
    weights = pyro.sample('weights', dist.Dirichlet(1 / K * torch.ones(K)))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.MultivariateNormal(torch.zeros(3), 100 * torch.eye(3)))

    with pyro.plate('data', len(data)):
        # Local variables: one component assignment per data point.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.MultivariateNormal(locs[assignment], 0.01 * torch.eye(3)), obs=data)

# In[5]:

optim = pyro.optim.Adam({'lr': 0.05, 'betas': [0.95, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

# In[6]:

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs']),
                             init_loc_fn=init_to_uniform)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))

# In[7]:

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(400 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')  # progress dots; newline every 100 steps

# In[8]:

pyplot.figure(figsize=(10, 3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

# In[9]:

pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

# Here are the learned parameters:

# In[10]:

map_estimates = global_guide(data)
weights = map_estimates['weights']
locs = map_estimates['locs']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
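
# With the MAP estimates in hand, we can also recover the most likely mixture
# assignment of each data point. The sketch below follows the classification step
# from the original Pyro tutorial: replay the guide's trained global variables into
# the model, then use `infer_discrete` (with `temperature=0` for MAP rather than
# sampled assignments) to fill in the enumerated `assignment` site.

# In[11]:

from pyro.infer import infer_discrete

guide_trace = poutine.trace(global_guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)   # replay the globals

def classifier(data, temperature=0):
    inferred_model = infer_discrete(trained_model, temperature=temperature,
                                    first_available_dim=-2)  # avoid conflict with data plate
    trace = poutine.trace(inferred_model).get_trace(data)
    return trace.nodes['assignment']['value']

print(classifier(data))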
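
# As a quick visual sanity check, we can redraw the scatter plot with each point
# colored by its inferred component (a minimal sketch reusing `classifier` from above).

# In[12]:

assignment = classifier(data)
ax = pyplot.axes(projection='3d')
ax.scatter3D(data[:, 0], data[:, 1], data[:, 2], c=assignment.numpy(), cmap='viridis');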