To take advantage of modern high-performance computing facilities, such as clusters with hundreds of CPUs, we recommend using MPI instead of multiprocessing.
To do this, we will use the ChainManager included in zeus.
To run this example, copy the script given below into a file called 'test_mpi.py' and run the following command in the terminal:
mpiexec -n 8 python3 test_mpi.py
This will spawn 8 MPI processes and divide them into 2 independent chains of 10 walkers each. Unfortunately, MPI is not compatible with Jupyter notebooks.
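As an aside that is not part of test_mpi.py, the arithmetic behind "8 processes, 2 chains" can be sketched in plain Python. The helper assign_chains below is hypothetical and only illustrates one natural way to split ranks into equal contiguous groups; the assignment that ChainManager performs internally may differ.

```python
def assign_chains(nprocs, nchains):
    """Map each MPI rank to a chain index using equal, contiguous groups.

    Hypothetical helper for illustration only; it is not part of zeus.
    """
    if nprocs % nchains != 0:
        raise ValueError("nprocs must be a multiple of nchains")
    per_chain = nprocs // nchains  # processes available to each chain
    return {rank: rank // per_chain for rank in range(nprocs)}


# Group the ranks by the chain they would serve under this assumption.
mapping = assign_chains(8, 2)
groups = {}
for rank, chain in mapping.items():
    groups.setdefault(chain, []).append(rank)
print(groups)  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

Under this assumption, each of the 2 chains has 4 processes to evaluate the log-probability of its 10 walkers. The script for test_mpi.py follows.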
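import numpy as np
import zeus
from zeus import ChainManager

ndim = 5
nwalkers = 2 * ndim
nsteps = 100
nchains = 2

# Log-probability of a standard normal distribution (up to a constant)
def log_prob(x):
    return -0.5 * np.sum(x**2.0)

# Initial positions of the walkers
start = np.random.randn(nwalkers, ndim)

with ChainManager(nchains) as cm:
    # The rank is used below to give each chain its own output file
    rank = cm.get_rank

    sampler = zeus.EnsembleSampler(nwalkers, ndim, log_prob, pool=cm.get_pool)
    sampler.run_mcmc(start, nsteps)

    # Flatten the walkers into a single array of samples, discarding burn-in
    chain = sampler.get_chain(flat=True, discard=0.5)

    np.save('chain_'+str(rank)+'.npy', chain)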
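Once the run has finished, each chain has written its own file. A minimal post-processing sketch is given below, assuming the files are named chain_0.npy and chain_1.npy (that is, that the ranks are 0 and 1) and that all chains have the same shape. The helpers load_chains and gelman_rubin are hypothetical names introduced here and are not part of zeus. Because the saved chains are flattened over walkers, the resulting statistic is only a rough convergence check.

```python
import numpy as np


def load_chains(nchains, pattern="chain_{}.npy"):
    """Load one flattened chain per file, e.g. chain_0.npy, chain_1.npy."""
    return [np.load(pattern.format(rank)) for rank in range(nchains)]


def gelman_rubin(chains):
    """Gelman-Rubin potential scale reduction factor, one value per dimension.

    chains: sequence of arrays that all have shape (nsamples, ndim).
    """
    x = np.stack(chains)                              # (nchains, nsamples, ndim)
    n = x.shape[1]                                    # samples per chain
    within = x.var(axis=1, ddof=1).mean(axis=0)       # mean within-chain variance
    between = n * x.mean(axis=1).var(axis=0, ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * within + between / n      # pooled variance estimate
    return np.sqrt(var_hat / within)


# Typical use after the MPI run (run in the directory holding the files):
#   chains = load_chains(2)
#   print(gelman_rubin(chains))     # values close to 1 suggest agreement
#   samples = np.concatenate(chains, axis=0)
```

Values noticeably above 1 in any dimension suggest that the independent chains have not yet settled on the same distribution, in which case running for more steps is the usual remedy.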