Metropolis Sampling

Marcel Lüthi, Department of Mathematics and Computer Science, University of Basel

In this notebook we will present an implementation of the Metropolis algorithm and experiment with it. As a target distribution we will use a bivariate normal distribution. This simple setup is ideally suited for exploring the properties of the algorithm: we know what the true target distribution looks like, we can change its shape, and we can visualize the samples.

The intuition that we develop while experimenting with this simple example will help us understand the more complicated cases that we will explore later in this course.

Preparation

Before we start, we need to download the plotting library EvilPlot and make it available in the Jupyter notebook. This may take some time when you execute the following cells for the first time.

In [23]:
import $ivy.`io.github.cibotech::evilplot:0.8.1`
import com.cibo.evilplot.plot._
import com.cibo.evilplot.plot.renderers.PointRenderer

import com.cibo.evilplot.plot.aesthetics.DefaultTheme._
import com.cibo.evilplot.numeric.Point
import com.cibo.evilplot.colors._

import breeze.linalg.{DenseVector, DenseMatrix}
import breeze.stats.distributions.MultivariateGaussian 

implicit class ShowJupyter(plot: com.cibo.evilplot.geometry.Drawable) {
  def show() : Image = {
    Image.fromRenderedImage(plot.asBufferedImage, Image.PNG)  
  }
}
Out[23]:
import $ivy.$                                   

import com.cibo.evilplot.plot._

import com.cibo.evilplot.plot.renderers.PointRenderer


import com.cibo.evilplot.plot.aesthetics.DefaultTheme._

import com.cibo.evilplot.numeric.Point

import com.cibo.evilplot.colors._


import breeze.linalg.{DenseVector, DenseMatrix}

import breeze.stats.distributions.MultivariateGaussian 


defined class ShowJupyter

We also initialize a global random number generator, which we use whenever we need a new random number in our algorithm.

In [24]:
// We keep random number generator around as global state
val rng = new scala.util.Random()
Out[24]:
rng: scala.util.Random = scala.util.Random@...
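
If reproducible runs are desired, the generator can instead be constructed with a fixed seed. This is optional and not done in this notebook:

// optional: seed the generator to make the sampling runs reproducible
val rng = new scala.util.Random(42)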

The Metropolis algorithm

The Metropolis algorithm works by simulating a random path through the states on which the target distribution is defined. Starting from a given state, a possible new state is proposed using a proposal function. The proposed state is then evaluated under the target distribution and accepted or rejected, depending on the ratio of the target probability of the new state to that of the old state.
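
More precisely, if x denotes the current state and x' the proposed state, the new state is accepted with probability α = min(1, p(x') / p(x)), where p is the target distribution; if the proposal is rejected, the chain simply stays at x. This is exactly the acceptance test that the implementation below performs.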

This motivates the following definitions:

In [25]:
type State = DenseVector[Double]  // The state is represented as a vector
Out[25]:
defined type State
In [26]:
type Proposal = State => State    // The proposal function produces a new state from a given state
Out[26]:
defined type Proposal
In [27]:
type DistributionEvaluator = State => Double // The distribution evaluator computes the probability (density) of a given state
Out[27]:
defined type DistributionEvaluator

In order to understand and be able to visualize how the Metropolis algorithm walks through the state-space, we introduce a logger. The logger keeps track of the sequence of states, together with the state that was proposed in each step:

In [28]:
class Logger {
    case class StepInfo(state : State, proposedState : State) {
        def accepted : Boolean = state == proposedState
    }

    private val stepSeq = collection.mutable.Buffer[StepInfo]()
    
    def steps : Seq[StepInfo] = stepSeq.toSeq
    
    def logStep(state : State, proposedState : State) : Unit = stepSeq.append(StepInfo(state, proposedState))
}
Out[28]:
defined class Logger
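
Since the logger records both the current and the proposed state for every step, it can also tell us how often proposals are accepted. The following small helper is not part of the original algorithm, but it will be convenient for the exercises at the end of the notebook:

// Fraction of logged steps in which the proposed state was accepted
def acceptanceRate(logger : Logger) : Double =
    logger.steps.count(_.accepted).toDouble / logger.steps.length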

With these definitions, we are ready to implement the Metropolis sampler.

In [29]:
def metropolisSampler(p : DistributionEvaluator, q : Proposal, initialState : State, logger : Logger) : Iterator[State] = {

    // Simulates one step
    def nextStep(currentState : State) : State = {
        
        // propose a new state from the given state
        val proposedState = q(currentState) 
        
        // accept based on the ratio of probabilities between 
        // the new and the old state
        val r = rng.nextDouble()
        val alpha = scala.math.min(1.0, p(proposedState) / p(currentState))
        val nextState = if (r < alpha) proposedState else currentState
        logger.logStep(nextState, proposedState)
        nextState
    }
    
    // create an iterator starting from the initial state
    Iterator.iterate(initialState)(nextStep)
}
Out[29]:
defined function metropolisSampler

This is it: this simple code is a complete implementation of the celebrated Metropolis algorithm! Note that metropolisSampler does not perform any computation yet; it returns a lazy Iterator, and the chain is only simulated once we request samples from it.

Toy example: sampling from a bivariate normal

We will now run the algorithm on a toy example, where we sample from a bivariate normal distribution. We start by defining the target probability distribution.

In [30]:
val bivariateNormal = MultivariateGaussian(
    mean = DenseVector(9.0, 10.0), 
    covariance = DenseMatrix((2.0, 0.5), (0.5, 1.0))
)
Out[30]:
bivariateNormal: MultivariateGaussian = MultivariateGaussian(
  DenseVector(9.0, 10.0),
  2.0  0.5  
0.5  1.0  
)

Let's plot some samples from the distribution:

In [32]:
val samples = for (_ <- 0 until 100) yield bivariateNormal.sample()
ScatterPlot(    
    samples.map(s => Point(s(0), s(1)))    
).xAxis()
 .yAxis()
 .frame()
 .xLabel("x")
 .xbounds(0, 20)
 .yLabel("y")
 .ybounds(0, 20)
 .rightLegend()
 .render()
 .show()
Out[32]:
samples: IndexedSeq[DenseVector[Double]] = Vector(
  DenseVector(7.950063059653292, 9.665357409794472),
  DenseVector(7.897476790412558, 9.197888519410425),
  DenseVector(6.210966685420566, 9.93020822662213),
  DenseVector(8.192650396042627, 9.397569656684123),
  DenseVector(8.336046271557342, 10.28007826687203),
  DenseVector(8.590456087934829, 10.604253910618507),
  DenseVector(10.742780009679185, 9.97268152768115),
  DenseVector(8.243260612290735, 8.345027389250161),
  DenseVector(8.772700290567586, 10.779154820458052),
  DenseVector(8.89060983679405, 9.141380544285607),
  DenseVector(9.575247444854442, 11.176331926542087),
  DenseVector(9.420488498784927, 9.768752730729066),
  DenseVector(7.450690764697665, 10.747945203145385),
  DenseVector(9.130535550533367, 10.848343191504156),
  DenseVector(9.248312306711691, 10.833523108927844),
  DenseVector(10.398547752687266, 11.19834270465996),
  DenseVector(7.977392156846392, 10.180300442365315),
  DenseVector(8.87026699526358, 9.505827787246515),
  DenseVector(8.391855425238063, 9.1013800778294),
  DenseVector(8.718448241569766, 10.059654354269696),
  DenseVector(8.182811836545667, 9.127419062841977),
  DenseVector(7.522046764414963, 8.837648197788488),
  DenseVector(8.366007965721682, 8.641553647723578),
  DenseVector(7.582405824003544, 8.21372724419766),
  DenseVector(6.466502909744587, 10.519600882120297),
  DenseVector(10.89851720652511, 8.460125068073468),
  DenseVector(11.370071765385482, 11.545188548738976),
  DenseVector(8.73120207808771, 8.849972901815843),
  DenseVector(7.321254703338562, 10.534840852820686),
  DenseVector(9.67655185728469, 9.372375052841612),
  DenseVector(7.067027164189062, 9.346475828194185),
  DenseVector(9.256796610174717, 8.89612716519381),
  DenseVector(10.16640003320734, 10.759403325636672),
  DenseVector(11.404955599592181, 10.433078521829183),
  DenseVector(7.232789375007407, 10.530078177368562),
  DenseVector(8.576056217192264, 10.640261281061594),
  DenseVector(8.838894356703719, 11.097272025529975),
  DenseVector(8.612431988952201, 9.225312915579183),
...

We also need a proposal generator. We define a simple random walk proposal, which chooses the direction and length of the next step randomly.

In [33]:
def randomWalkProposal(x : State) : State = {
    val stepLength = 0.5
    
    val step = DenseVector(rng.nextGaussian() * stepLength, rng.nextGaussian() * stepLength)
    x + step
}
Out[33]:
defined function randomWalkProposal
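
In other words, the proposal moves to x' = x + ε, where each coordinate of ε is drawn independently from a normal distribution with mean 0 and standard deviation 0.5. Note that this proposal is symmetric: proposing x' from x is exactly as likely as proposing x from x'. The plain acceptance ratio p(x') / p(x) used above relies on this symmetry; an asymmetric proposal would require the more general Metropolis-Hastings correction.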

Now we can draw samples using our sampler.

In [34]:
val logger = new Logger()
val samples = metropolisSampler(bivariateNormal.pdf, randomWalkProposal, DenseVector(10, 10), logger).take(500).toSeq
Out[34]:
logger: Logger = Logger@...
samples: Seq[State] = List(
  DenseVector(10.0, 10.0),
  DenseVector(9.720643015666251, 9.708472943256385),
  DenseVector(9.998134353362762, 10.006841187977123),
  DenseVector(10.237948537418532, 10.971645154560111),
  DenseVector(9.819643666968798, 10.819400876162973),
  DenseVector(9.840398761481561, 11.027326740439163),
  DenseVector(9.342556894605615, 10.850817105130197),
  DenseVector(10.17022244894715, 10.630751992854817),
  DenseVector(10.332602351809195, 9.93399996773307),
  DenseVector(10.310662447412549, 10.342190308832237),
  DenseVector(10.319540619957818, 10.120953241811405),
  DenseVector(10.80636188910151, 10.46997008396131),
  DenseVector(10.201127157611145, 11.031051385564199),
  DenseVector(10.340306593391059, 9.295506509111483),
  DenseVector(10.12415641544559, 9.306995653846224),
  DenseVector(10.12415641544559, 9.306995653846224),
  DenseVector(10.12415641544559, 9.306995653846224),
  DenseVector(10.432626266146578, 10.168951370556062),
  DenseVector(10.432626266146578, 10.168951370556062),
  DenseVector(9.464697815618106, 9.668210122512415),
  DenseVector(9.560270122912694, 10.092292781541644),
  DenseVector(9.529050861129745, 9.597895628050676),
  DenseVector(9.46772933619005, 9.520881377074707),
  DenseVector(10.276135235393053, 9.456376788567347),
  DenseVector(10.27364467318171, 9.616746008306068),
  DenseVector(9.87479780392965, 10.547434187671335),
  DenseVector(9.561350665632311, 10.769140952872622),
  DenseVector(9.561350665632311, 10.769140952872622),
  DenseVector(8.887118238654468, 10.618118072661842),
  DenseVector(8.375465599708871, 9.97640655127039),
  DenseVector(8.241176034456126, 10.168522847702429),
  DenseVector(8.630096038976866, 10.067092689786765),
  DenseVector(8.85288470445321, 10.202495488649314),
  DenseVector(8.251616569722923, 10.286334057878408),
  DenseVector(7.8124533328398655, 10.31764175044807),
  DenseVector(8.025697559028835, 9.839905086650246),
  DenseVector(7.952883255778813, 9.647423746082575),
  DenseVector(8.454397672701985, 10.541076878847173),
...

Computing the mean and the covariance from the samples shows that they approximate those of the target distribution rather well.

In [35]:
val mean = samples.reduce(_ + _) * (1.0 / samples.length)
val cov = samples.map(x => (x - mean) * (x - mean).t).reduce(_ + _) * (1.0 / samples.length)
Out[35]:
mean: DenseVector[Double] = DenseVector(8.677396841080336, 10.262956957766043)
cov: DenseMatrix[Double] = 1.3069347024111155  0.4019547020016448  
0.4019547020016448  1.0351467562971002  
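
Both quantities are Monte Carlo estimates of expectations under the target distribution: for a function f, E[f(X)] is approximated by the average of f over the samples, (1/N) ∑ᵢ f(xᵢ), with f(x) = x for the mean and f(x) = (x - mean) * (x - mean).t for the covariance, exactly as computed above. For comparison, the true mean is (9, 10) and the true covariance has variances 2.0 and 1.0 and covariance 0.5.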

We can also check this visually by plotting the samples:

In [36]:
def plotSamples(samples : Seq[State]) : Image = {

    ScatterPlot(
        samples.map(s => Point(s(0), s(1)))
    )
    .xAxis()
    .xbounds(0, 20)
    .yAxis()
    .ybounds(0, 20)
    .frame()
    .xLabel("x")
    .yLabel("y")
    .rightLegend()
    .render()
    .show()
}
Out[36]:
defined function plotSamples
In [37]:
plotSamples(samples)

To get a deeper understanding of how the Metropolis algorithm works, it is interesting to visualize not only the accepted samples, but also those that were rejected, and possibly the path that was taken. This is achieved by the following plot function, which shows the accepted and rejected samples in different colors and can also draw lines to show the path that was explored:

In [38]:
def plotLoggedSamples(logger : Logger, plotLines : Boolean = false) : Image = {

    val acceptedPlot = ScatterPlot.series(
        logger.steps.filter(s => s.accepted).map(s => Point(s.state(0), s.state(1))),
        "accepted", 
        HTMLNamedColors.blue, 
        pointSize = Some(3)
    )
    val rejectedPlot = ScatterPlot.series(
        logger.steps.filterNot(s => s.accepted).map(s => Point(s.proposedState(0), s.proposedState(1))),
        "rejected", 
        HTMLNamedColors.red ,
        pointSize = Some(3)  
    )
    val linePlot =  LinePlot(
        logger.steps.flatMap(s => 
                        if (!s.accepted) {        
                            Seq(Point(s.proposedState(0), s.proposedState(1)), Point(s.state(0), s.state(1)))
                        } else {
                            Seq(Point(s.state(0), s.state(1)))
                        })
    )
    
    val plots =
        if (plotLines) Seq(acceptedPlot, rejectedPlot, linePlot)
        else Seq(acceptedPlot, rejectedPlot)
    
    Overlay.fromSeq(plots)
    .xAxis()
    .xbounds(0, 20)
    .yAxis()
    .ybounds(0, 20)
    .frame()
    .xLabel("x")
    .yLabel("y")
    .rightLegend()
    .render()
    .show()
}
Out[38]:
defined function plotLoggedSamples

We can now plot our samples and start experimenting:

In [39]:
plotLoggedSamples(logger, plotLines = false)
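
To also see the path that the sampler took, including the excursions to rejected proposals, we can set plotLines to true:

plotLoggedSamples(logger, plotLines = true)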

Exercises

  • We have started sampling close to the mean of the target distribution. What happens when we start at the point (0, 0)?
    • Visualize the resulting samples using the plot functions.
    • How could we mitigate the situation?
  • Play with different step lengths in the proposal.
    • What happens to the acceptance and rejection rate?
    • How well is the target distribution approximated after a fixed number of samples?
  • Experiment with different target distributions.
    • Make the variance larger and smaller.
    • Change the correlation.
    • Assign probability 0 to all states whose x value is larger than 5.
      • Do you need to normalize the pdf? Why or why not?
  • Define a function that computes the expected value of a given function, using the samples.
  • How could you get samples from the marginal distributions?