Florent Leclercq,
Institut d'Astrophysique de Paris,
florent.leclercq@iap.fr
import numpy as np
import os
import jax
import jax.numpy as jnp
from cycler import cycler
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LogNorm, SymLogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
np.random.seed(123456)
plt.rcParams.update({'lines.linewidth': 2})
plt.rcParams.update({'text.usetex': True})
plt.rcParams.update({'text.latex.preamble': r"\usepackage{amsmath}\usepackage{upgreek}"})
plt.rcParams.update({'font.family': 'serif'})
plt.rcParams.update({'font.size': 15})
dir="./plots/HMC_nonlinear_model/"
os.makedirs(dir, exist_ok=True)
# Download and define the Planck color map
from matplotlib.colors import ListedColormap
import os
if not os.path.isfile("data/Planck_Parchment_RGB.txt"):
!wget https://raw.githubusercontent.com/zonca/paperplots/master/data/Planck_Parchment_RGB.txt --directory-prefix=data/
planck = ListedColormap(np.loadtxt("data/Planck_Parchment_RGB.txt")/255.)
planck.set_bad("C7") # color of missing pixels
N = 32
L = 1.0 # box size
A_s = 6e-9 # power spectrum normalisation, arbitrary units
n_s = 0.96 # scalar spectral index
f_NL = 2000.0 # non-linear coupling parameter
D1 = 1.732e7 # growth factor of fluctuations at z=0, arbitrary units
def build_power_spectrum(N, A_s=1.0, n_s=0.96):
# real‐space grid spacing
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2*np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2*np.pi # length N//2+1
kx, ky = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
k_safe = jnp.where(k == 0, 1.0, k) # avoid division by zero at the mean mode
# build Pk safely (zero at k=0)
Pk = A_s * k_safe**(n_s - 1.0)
Pk = jnp.where(k == 0, 0.0, Pk) # set mean mode to zero power
return Pk
Pkgrid = build_power_spectrum(N, A_s, n_s)
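As a quick sanity check (not in the original notebook, but it follows directly from the construction above), the rfft2 grid has shape (N, N//2+1) and the mean mode carries zero power:
assert Pkgrid.shape == (N, N // 2 + 1)
assert Pkgrid[0, 0] == 0.0 # mean mode has zero power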
@jax.jit
def phiL_from_real_noise(white_noise):
"""
Given real-space white noise, construct a Gaussian random field with power spectrum
P(k) = A_s * k^(n_s - 1).
Args:
white_noise: real array of shape [N, N] (from jax.random.normal(key, (N, N)))
A_s, n_s: parameters of power spectrum
Returns:
field: real array [N, N]
"""
N = white_noise.shape[0]
noise_k = jnp.fft.rfft2(white_noise)
# Multiply by sqrt of power spectrum
filtered_k = noise_k * jnp.sqrt(Pkgrid)
# IFFT back to real space
field = jnp.fft.irfft2(filtered_k).real
return field
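Optionally, the construction can be verified numerically: averaging $|\mathrm{FFT}(\Phi_\mathrm{L})|^2 / N^2$ over many white-noise realisations should approach Pkgrid away from the zeroed mean mode (a minimal sketch; 200 realisations is assumed to be enough to see the trend):
keys = jax.random.split(jax.random.PRNGKey(0), 200)
phis = jax.vmap(lambda key: phiL_from_real_noise(jax.random.normal(key, (N, N))))(keys)
emp_Pk = jnp.mean(jnp.abs(jax.vmap(jnp.fft.rfft2)(phis))**2, axis=0) / N**2
# emp_Pk / Pkgrid should scatter around 1 (excluding k = 0, where Pkgrid = 0)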
def build_transfer_function(N):
"""Transfer function from primordial potential to density contrast."""
# real‐space grid spacing
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2*np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2*np.pi # length N//2+1
kx, ky = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
k_safe = jnp.where(k == 0, 1.0, k) # avoid division by zero at the mean mode
# shape parameter, characterising the wavenumber at equality
Omega_b = 0.049
Omega_m = 0.315
h = 0.674
fb = Omega_b / Omega_m
shape = Omega_m * h * np.exp(-Omega_b - np.sqrt(2.0*h) * fb)
# transfer function, BBKS style
q = k_safe / shape
alpha = 2.34; beta = 3.89; gamma = 16.1; delta = 5.46; epsilon = 6.71
aux = 1.0 + beta * q + (gamma * q)**2 + (delta * q)**3 + (epsilon * q)**4
T = jnp.log(1.0 + alpha * q) / (alpha * q) * aux**-0.25
T *= q**(1./2.) # correct scaling for delta(k) = D1 * T(k) * Phi_NL(k): we want P_delta(k) \propto k^n_s T(k)**2 \propto k * T(k)**2 * P_phi(k)
T = jnp.where(k == 0, 0.0, T) # set mean mode to zero transfer function
return T
Tgrid = build_transfer_function(N)
@jax.jit
def delta_from_phi(phi, f_NL=f_NL, D1=D1):
"""
Given real-space Phi, compute Phi_NL and then delta such that
delta(k) = D1 * T(k) * Phi_NL(k), and inverse FFT to real-space delta.
Args:
phi: [N, N] real array, the field Phi in real space
f_NL: scalar
Returns:
delta_real: [N, N] real array, delta in real space
"""
N = phi.shape[0]
# 1. Construct Phi_NL
Phi_NL = phi + f_NL * phi**2
# 2. FFT to Fourier space
PhiNL_k = jnp.fft.rfft2(Phi_NL)
# 3. Multiply by transfer function T(k)
delta_k = D1 * Tgrid * PhiNL_k
delta_k = delta_k.at[0, 0].set(0.0) # Explicitly zero the mean mode
# 4. IFFT back to real space
delta_real = jnp.fft.irfft2(delta_k).real
return delta_real
@jax.jit
def data_model(white_noise, f_NL=f_NL, D1=D1):
"""
Given real-space white noise, construct a Gaussian random field with power spectrum
P(k) = A_s * k^(n_s - 1), and compute the corresponding delta field.
Args:
white_noise: real array of shape [N, N] (from jax.random.normal(key, (N, N)))
f_NL: scalar
D1: scalar
Returns:
delta_real: [N, N] real array, delta in real space
"""
phi = phiL_from_real_noise(white_noise)
delta = delta_from_phi(phi, f_NL=f_NL, D1=D1)
return delta
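Putting the pieces together, the forward model implemented by data_model is, in Fourier space,
$$\hat{\Phi}_\mathrm{L}(\mathbf{k}) = \sqrt{P(k)}\, \hat{s}(\mathbf{k}), \qquad \Phi_\mathrm{NL} = \Phi_\mathrm{L} + f_\mathrm{NL}\, \Phi_\mathrm{L}^2, \qquad \hat{\delta}(\mathbf{k}) = D_1\, T(k)\, \hat{\Phi}_\mathrm{NL}(\mathbf{k}),$$
where $s$ is the real-space white noise field and hats denote the 2D (r)FFT.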
def P_of_k(k, A_s=A_s, n_s=n_s):
"""
Power spectrum P(k) = A_s * k^(n_s - 1).
Args:
k: scalar or array of wavenumbers
A_s: scalar, amplitude of the power spectrum
n_s: scalar, spectral index
Returns:
Pk: scalar or array of power spectrum values at k
"""
k = np.atleast_1d(k)
Pk = A_s * k**(n_s - 1.0)
Pk[k == 0] = 0.0 # Set mean mode to zero power
return Pk
def T_of_k(k):
"""
Transfer function T(k) from primordial potential to density contrast.
Args:
k: scalar or array of wavenumbers
Returns:
T: scalar or array of transfer function values at k
"""
k = np.atleast_1d(k)
# shape parameter, characterising the wavenumber at equality
Omega_b = 0.049
Omega_m = 0.315
h = 0.674
fb = Omega_b / Omega_m
shape = Omega_m * h * np.exp(-Omega_b - np.sqrt(2.0*h) * fb)
# transfer function, BBKS style
q = k / shape
alpha = 2.34; beta = 3.89; gamma = 16.1; delta = 5.46; epsilon = 6.71
aux = 1.0 + beta * q + (gamma * q)**2 + (delta * q)**3 + (epsilon * q)**4
T = (np.log(1.0 + alpha * q) / (alpha * q) * aux**-0.25)
T *= q**(1./2.) # correct scaling for delta(k) = D1 * T(k) * Phi_NL(k), P_delta(k) \propto k^n_s T(k)**2 \propto k * T(k)**2 * P_phi(k)
T[k == 0] = 0.0 # Set mean mode to zero transfer function
return T
k_modes = np.logspace(-4, 0, 100)
Pk = P_of_k(k_modes)
Tk = T_of_k(k_modes)
plt.loglog(k_modes, 1e8 * Pk, color="C0", label="$P(k)$")
plt.loglog(k_modes, 1e2 * Tk**2, color="C1", linestyle="--", label=r"$D_1^2 \, T(k)^2$")
plt.xlabel(r"$k$ [2$\pi$/L]")
plt.ylabel(r"$P(k)$ [arbitrary units]")
plt.legend(loc='best')
plt.grid()
plt.savefig(dir+'power_spectrum.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
key = jax.random.PRNGKey(12)
white_noise = jax.random.normal(key, (N, N))
phi = phiL_from_real_noise(white_noise)
delta = delta_from_phi(phi)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.3)
phiNL=phi+f_NL*phi**2
# visualize the phi field
vmax = max(-phi.min(), phi.max(), -phiNL.min(), phiNL.max())
vmin = -vmax
im0 = ax0.imshow(phi, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0.set_title(r'$\Phi_\mathrm{L}$')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the phiNL field
im1 = ax1.imshow(phiNL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1.set_title(r'$\Phi_\mathrm{NL}=\Phi_\mathrm{L}+f_\mathrm{NL}\Phi_\mathrm{L}^2$')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the density contrast field
im2 = ax2.imshow(delta, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax2.set_title('Density contrast $\\delta$')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'phi_phiNL_delta.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'phi_phiNL_delta.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
def make_noise_field():
"""
Build a 32×32 array “field” of noise variances:
- Border of width 3 pixels everywhere: value = 1e0
- Central 26×26 block: two regions of different noise variances
- A low‐noise region (1e-5)
- A medium‐noise region (5e-5) on one side of the main diagonal
- A 2×2 high‐noise patch (1e0) around pixel (24, 24)
Returns
-------
field : ndarray, shape (32,32)
The resulting noise‐variance field.
"""
N = 32
field = np.zeros((N, N), dtype=float)
mask = np.zeros((N, N), dtype=bool)
# 1) High‐noise border
bw = 3
high = 1
field[:bw, :] = high
field[-bw:, :] = high
field[:, :bw] = high
field[:, -bw:] = high
mask[:bw, :] = True
mask[-bw:, :] = True
mask[:, :bw] = True
mask[:, -bw:] = True
# 2) Central block (size 26×26)
low = 1e-5
med = 5e-5
for i in range(bw,N-bw):
for j in range(bw,N-bw):
if i-j<-3:
field[i,j] = med
else:
field[i,j] = low
# 3) Sprinkle a 2×2 high‐noise patch
ci = 24
cj = 24
field[ci-1:ci+1, cj-1:cj+1] = high
mask[ci-1:ci+1, cj-1:cj+1] = True
return field, mask
noise_variance_field, mask = make_noise_field()
invN = np.diag(1.0 / noise_variance_field.flatten())
invN[np.where(invN <= 1)] = 0.0 # zero out the inverse variance where the variance is >= 1, i.e. mask the high-noise pixels
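As a quick check of the construction (an illustrative snippet, not in the original notebook), the variance field should contain exactly the three values used above:
vals, counts_per_val = np.unique(noise_variance_field, return_counts=True)
print(dict(zip(vals.tolist(), counts_per_val.tolist()))) # expect keys 1e-05, 5e-05 and 1.0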
Data model: $d=\delta(s)+n$, where the noise $n$ is Gaussian with zero mean and pixel-wise variance given by the noise variance field above.
# noise: one realisation
noise = np.random.normal(size=(N, N)) * np.sqrt(noise_variance_field)
noise_v = np.ma.masked_where(mask, noise) # for visualization purposes only
# data = delta + noise
data = delta + noise
data_v = np.ma.masked_where(mask, data) # for visualization purposes only
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# visualize the delta field
im0 = ax0.imshow(delta, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax0.set_title('Groundtruth $\\delta$')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the noise field
vmin, vmax = np.min(noise_v), np.max(noise_v)
linthresh = 1e-5
linscale = 1.0
norm = SymLogNorm(linthresh=linthresh, linscale=linscale,
vmin=-max(-vmin,vmax), vmax=max(-vmin,vmax), base=10)
cmap = plt.get_cmap('PiYG').copy() # copy before modifying a registered colormap
cmap.set_bad('C7') # color of missing pixels
im1 = ax1.imshow(noise_v, cmap=cmap, norm=norm, origin='lower')
ax1.set_title('Noise')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the data field
im2 = ax2.imshow(data_v, vmin=-max(-data_v.min(),data_v.max()), vmax=max(-data_v.min(),data_v.max()), origin='lower', cmap=planck)
ax2.set_title('Data')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'delta_noise_data.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'delta_noise_data.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
def log_prior(signal):
"""
Compute the log-prior of the signal (white noise field).
Args:
signal: [N, N] array, the signal (white noise field)
Returns:
log_prior: scalar, the log-prior value
"""
# Assuming a Gaussian prior with zero mean and unit variance
log_prior = -0.5 * jnp.sum(signal**2)
return log_prior
def log_likelihood(signal, data, noise_variance_field):
"""
Compute the (unnormalised) log-likelihood of the data given the signal and noise variance field.
Args:
signal: [N, N] array, the signal (white noise field)
data: [N, N] array, the observed data
noise_variance_field: [N, N] array, the noise variance field
Returns:
log_likelihood: scalar, the log-likelihood value
"""
delta = data_model(signal)
residual = data - delta
log_likelihood = -0.5 * jnp.sum(residual**2 / noise_variance_field)
return log_likelihood
def log_posterior(signal, data, noise_variance_field):
"""
Compute the (unnormalised) log-posterior of the signal given the data and noise variance field.
Args:
signal: [N, N] array, the signal (white noise field)
data: [N, N] array, the observed data
noise_variance_field: [N, N] array, the noise variance field
Returns:
log_posterior: scalar, the log-posterior value
"""
log_likelihood_value = log_likelihood(signal, data, noise_variance_field)
log_prior_value = log_prior(signal)
return log_likelihood_value + log_prior_value
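Explicitly, these three functions implement
$$\log p(s) = -\tfrac{1}{2} s^\top s + \text{const}, \qquad \log p(d \mid s) = -\tfrac{1}{2} \sum_p \frac{\big[d_p - \delta(s)_p\big]^2}{N_p} + \text{const},$$
with $N_p$ the noise variance in pixel $p$, and $\log p(s \mid d) = \log p(d \mid s) + \log p(s) + \text{const}$.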
# Test the functions with the groundtruth white noise field
log_prior(white_noise), log_likelihood(white_noise, data, noise_variance_field), log_posterior(white_noise, data, noise_variance_field)
(Array(-466.80988, dtype=float32), Array(-554.5358, dtype=float32), Array(-1021.3457, dtype=float32))
# Generate a new white noise field and compute the log-likelihood, log-prior, and log-posterior
white_noise_2 = jax.random.normal(jax.random.PRNGKey(2), (N, N)) # generate a new white noise field
log_likelihood(white_noise_2, data, noise_variance_field), log_prior(white_noise_2), log_posterior(white_noise_2, data, noise_variance_field)
(Array(-1487899.9, dtype=float32), Array(-509.3981, dtype=float32), Array(-1488409.2, dtype=float32))
# Compute the gradients of log_prior, log_likelihood, and log_posterior w.r.t. white noise using JAX autodiff
def d_log_prior_d_white_noise(white_noise):
"""
Compute the gradient of `log_prior` w.r.t. all elements of `white_noise`.
Args:
white_noise: [N, N] array
Returns:
grad: [N, N] array, the gradient of log_prior w.r.t. white_noise
"""
grad = -white_noise # Gradient of Gaussian prior with zero mean
return grad
def d_log_likelihood_d_white_noise_autodiff(white_noise, data, noise_variance_field):
"""
Compute the gradient of `log_likelihood` w.r.t. all elements of `white_noise` using JAX autodiff.
This function uses JAX's automatic differentiation to compute the gradient efficiently.
Args:
white_noise: [N, N] array
data: [N, N] array
noise_variance_field: [N, N] array
Returns:
grad: [N, N] array, the gradient of log_likelihood w.r.t. white_noise
"""
grad = jax.grad(log_likelihood)(white_noise, data, noise_variance_field) # shape (N,N)
return grad
def d_log_posterior_d_white_noise_autodiff(white_noise, data, noise_variance_field):
"""
Compute the gradient of `log_posterior` w.r.t. all elements of `white_noise` using JAX autodiff.
This function uses JAX's automatic differentiation to compute the gradient efficiently.
Args:
white_noise: [N, N] array
data: [N, N] array
noise_variance_field: [N, N] array
Returns:
grad: [N, N] array, the gradient of log_posterior w.r.t. white_noise
"""
grad = jax.grad(log_posterior)(white_noise, data, noise_variance_field) # shape (N,N)
return grad
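Since jax.grad re-traces the model on every call, one would typically jit the gradient once when it is evaluated repeatedly (an optional micro-optimisation, not used in the original):
d_log_posterior_grad_jit = jax.jit(jax.grad(log_posterior))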
# Compute the gradients of log_prior, log_likelihood, and log_posterior w.r.t. white noise using finite differences
def d_log_prior_d_white_noise_fd(white_noise, epsilon=1e-3):
"""
Compute the gradient of `log_prior` w.r.t. all elements of `white_noise` using finite differences.
Args:
white_noise: [N, N] array
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
N = white_noise.shape[0]
grad_fd = jnp.zeros_like(white_noise)
for i in range(N):
for j in range(N):
white_noise_up = white_noise.at[i, j].add(+epsilon)
white_noise_down = white_noise.at[i, j].add(-epsilon)
f_up = log_prior(white_noise_up)
f_down = log_prior(white_noise_down)
grad_fd = grad_fd.at[i, j].set((f_up - f_down) / (2 * epsilon))
return grad_fd
def d_log_likelihood_d_white_noise_fd(white_noise, data, noise_variance_field, epsilon=1e-3):
"""
Compute the gradient of `log_likelihood` w.r.t. all elements of `white_noise` using finite differences.
Args:
white_noise: [N, N] array
data: [N, N] array
noise_variance_field: [N, N] array
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
N = white_noise.shape[0]
grad_fd = jnp.zeros_like(white_noise)
for i in range(N):
for j in range(N):
white_noise_up = white_noise.at[i, j].add(+epsilon)
white_noise_down = white_noise.at[i, j].add(-epsilon)
f_up = log_likelihood(white_noise_up, data, noise_variance_field)
f_down = log_likelihood(white_noise_down, data, noise_variance_field)
grad_fd = grad_fd.at[i, j].set((f_up - f_down) / (2 * epsilon))
return grad_fd
def d_log_posterior_d_white_noise_fd(white_noise, data, noise_variance_field, epsilon=1e-3):
"""
Compute the gradient of `log_posterior` w.r.t. all elements of `white_noise` using finite differences.
Args:
white_noise: [N, N] array
data: [N, N] array
noise_variance_field: [N, N] array
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
N = white_noise.shape[0]
grad_fd = jnp.zeros_like(white_noise)
for i in range(N):
for j in range(N):
white_noise_up = white_noise.at[i, j].add(+epsilon)
white_noise_down = white_noise.at[i, j].add(-epsilon)
f_up = log_posterior(white_noise_up, data, noise_variance_field)
f_down = log_posterior(white_noise_down, data, noise_variance_field)
grad_fd = grad_fd.at[i, j].set((f_up - f_down) / (2 * epsilon))
return grad_fd
grad_prior = d_log_prior_d_white_noise(white_noise)
grad_prior_fd = d_log_prior_d_white_noise_fd(white_noise)
grad_lh = d_log_likelihood_d_white_noise_autodiff(white_noise, data, noise_variance_field)
grad_lh_fd = d_log_likelihood_d_white_noise_fd(white_noise, data, noise_variance_field)
grad_post = d_log_posterior_d_white_noise_autodiff(white_noise, data, noise_variance_field)
grad_post_fd = d_log_posterior_d_white_noise_fd(white_noise, data, noise_variance_field)
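Beyond the visual comparison plotted below, a direct quantitative check is easy (an illustrative snippet, not in the original; central differences should agree with autodiff up to $O(\epsilon^2)$ truncation error and float32 round-off):
print("max |grad_prior - grad_prior_fd| =", jnp.max(jnp.abs(grad_prior - grad_prior_fd)))
print("max |grad_lh - grad_lh_fd| =", jnp.max(jnp.abs(grad_lh - grad_lh_fd)))
print("max |grad_post - grad_post_fd| =", jnp.max(jnp.abs(grad_post - grad_post_fd)))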
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18,6))
ax0.plot(np.arange(0, N*N), jnp.reshape(grad_prior, (N*N)), 'o', label='Autodiff', markersize=2, zorder=3)
ax0.plot(np.arange(0, N*N), jnp.reshape(grad_prior_fd, (N*N)), 'o', label='Finite differencing', markersize=4, zorder=2)
ax0.set_xlabel('Index of white noise field $s$')
ax0.set_ylabel('Gradient')
ax0.set_title('Gradient of log-prior w.r.t. $s$')
ax0.legend()
ax1.plot(np.arange(0, N*N), jnp.reshape(grad_lh, (N*N)), 'o', label='Autodiff', markersize=2, zorder=3)
ax1.plot(np.arange(0, N*N), jnp.reshape(grad_lh_fd, (N*N)), 'o', label='Finite differencing', markersize=4, zorder=2)
ax1.set_xlabel('Index of white noise field $s$')
ax1.set_title('Gradient of log-likelihood w.r.t. $s$')
ax1.legend()
ax2.plot(np.arange(0, N*N), jnp.reshape(grad_post, (N*N)), 'o', label='Autodiff', markersize=2, zorder=3)
ax2.plot(np.arange(0, N*N), jnp.reshape(grad_post_fd, (N*N)), 'o', label='Finite differencing', markersize=4, zorder=2)
ax2.set_xlabel('Index of white noise field $s$')
ax2.set_title('Gradient of log-posterior w.r.t. $s$')
ax2.legend()
plt.savefig(dir+'gradient_test.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'gradient_test.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
import blackjax
from typing import Any, Tuple
def sample_white_noise_hmc(
white_noise_init: jnp.ndarray,
data: jnp.ndarray,
noise_variance_field: jnp.ndarray,
n_samples: int,
n_adapt: int = 1000, # Number of adaptation steps
rng_key: Any = None, # optional PRNG key; defaults to jax.random.PRNGKey(0)
) -> Tuple[jnp.ndarray, Any]:
if rng_key is None:
rng_key = jax.random.PRNGKey(0)
N = white_noise_init.shape[0]
def logprob(flat_white_noise):
white_noise = flat_white_noise.reshape((N, N))
return log_posterior(white_noise, data, noise_variance_field)
initial_position = jnp.ravel(white_noise_init)
# Set up the adaptation routine for HMC
adapt = blackjax.window_adaptation(
blackjax.hmc,
logprob,
num_steps=n_adapt, # Number of adaptation steps here
num_integration_steps=10, # Number of leapfrog steps
target_acceptance_rate=0.8, # Target acceptance rate
)
# Run the adaptation phase
state, kernel, adaptation_state = adapt.run(rng_key, initial_position)
# Sampling loop
def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, _ = kernel(rng_key, state)
return state, state
keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, keys)
return states
# Run the sampling with a fresh key, so that it does not reuse the adaptation key
rng_key, sample_key = jax.random.split(rng_key)
states = inference_loop(sample_key, kernel, state, n_samples)
# Reshape the samples to the original field shape
samples = states.position.reshape(n_samples, N, N)
return samples, states
n_samples = 10000 # Number of samples to draw
N_chains = 5 # Number of chains to run
initial_scaling = 0.001 # Initial scaling for the white noise field, start from an overdispersed state
try:
samples_chain = np.load('data/HMC_nonlinear_model/samples_chain.npy', allow_pickle=True).item()
except FileNotFoundError:
samples_chain = {}
for c in range(N_chains):
rng_key = jax.random.PRNGKey(42 + c)
field_init = jax.random.normal(rng_key, (N, N)) * initial_scaling
samples_chain[c], infos = sample_white_noise_hmc(
field_init,
data,
noise_variance_field,
n_samples=n_samples,
n_adapt=10,
rng_key=rng_key
)
os.makedirs('data/HMC_nonlinear_model/', exist_ok=True)
np.save('data/HMC_nonlinear_model/samples_chain.npy', samples_chain)
log_likelihoods_chain = {}
for c in range(N_chains):
log_likelihoods_chain[c] = jax.vmap(log_likelihood, in_axes=(0, None, None))(samples_chain[c], data, noise_variance_field)
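For a well-converged chain, the residuals should be statistically consistent with the noise, so $-\log \mathcal{L} = \tfrac{1}{2}\chi^2$ should fluctuate around half the number of data pixels, $N^2/2 = 512$ here, consistent with the ground-truth value of about 554 found above; this is what sets the y-range of the figure below.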
fig, ax = plt.subplots(figsize=(6, 5))
ax.set_xlim(0, n_samples)
ax.set_ylim(450, 700)
n_thin_plot = 50 # Plot every nth element for clarity
ax.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
for c in range(N_chains):
ax.plot(np.arange(0, n_samples, n_thin_plot), -log_likelihoods_chain[c][::n_thin_plot], label=f'Chain {c+1}')
Nburnin = 1001
ax.axvline(Nburnin, color='k', linestyle=':', label='Burnt-in')
ax.set_xlabel('Sample index')
ax.set_ylabel(r'$-\log \mathcal{L}$')
ax.set_title('Log-likelihood vs sample index')
ax.grid()
ax.legend()
plt.savefig(dir+'log_likelihood.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'log_likelihood.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(10,6), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel("Pixel (10,20)")
ax0.set_title("Trace plots for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_plot),samples_chain[c].T[10,20][::n_thin_plot],marker='.')
ax0.axhline(white_noise.T[10,20], color='black', linestyle='--', label='Groundtruth')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax0.legend()
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_xlabel("Sample index")
ax1.set_ylabel("Pixel (25,10)")
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_plot),samples_chain[c].T[25,10][::n_thin_plot],marker='.')
ax1.axhline(white_noise.T[25,10], color='black', linestyle='--', label='Groundtruth')
ax1.axvline(Nburnin,color='black',linestyle=':')
ax1.legend()
plt.savefig(dir+'trace_plot.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'trace_plot.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Define a function to compute the power spectrum of the sampled delta fields
from functools import partial
# nbins fixes array shapes and L is a plain Python float, so both must be static under jit
@partial(jax.jit, static_argnames=('nbins', 'L'))
def power_spectrum_2d_jitted(field, nbins=200, L=1.0):
# real‐space grid spacing
N = field.shape[0]
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2 * np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2 * np.pi # length N//2+1
kx, ky = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
kmag = k.flatten()
# Compute rfft2
ft = jnp.fft.rfft2(field)
power2d = (jnp.abs(ft) ** 2) * (dx ** 2) / (N ** 2)
power_flat = power2d.flatten()
# Bin edges (exclude k=0 when setting min)
k_nonzero = kmag[1:] if kmag.size > 1 else kmag # Remove k=0 for k_min
k_min = k_nonzero.min()
k_max = kmag.max()
bins = jnp.logspace(jnp.log10(k_min), jnp.log10(k_max), nbins + 1)
bin_idx = jnp.digitize(kmag, bins) - 1 # bin assignments
# Bin means
Pk_sum = jnp.bincount(bin_idx, weights=power_flat, length=nbins)
counts = jnp.bincount(bin_idx, length=nbins)
Pk = jnp.where(counts > 0, Pk_sum / counts, 0.0)
k_sum = jnp.bincount(bin_idx, weights=kmag, length=nbins)
k_bin = jnp.where(counts > 0, k_sum / counts, 0.0)
return k_bin, Pk, counts
def power_spectrum_2d(field, nbins=200, L=1.0):
"""
Compute the 2D power spectrum of a 2D field using FFT and binning.
Args:
field: [N, N] array, the 2D field
nbins: int, number of bins for the power spectrum
L: float, box size
Returns:
k_bin: array of binned wavenumbers (empty bins removed)
Pk: array of binned power spectrum values (empty bins removed)
"""
k_bin, Pk, counts = power_spectrum_2d_jitted(field, nbins=nbins, L=L)
mask = counts > 0
return k_bin[mask], Pk[mask]
k_vals, Pk_signal_gt = power_spectrum_2d(white_noise, nbins=200)
k_vals, Pk_delta_gt = power_spectrum_2d(delta, nbins=200)
# Compute the power spectrum of the signal (white noise)
Pk = np.zeros((Nburnin, len(k_vals)))
for i in range(0, Nburnin, 50):
k_vals, Pk[i] = power_spectrum_2d(samples_chain[0][i])
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(8, 6), sharex=True, gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.})
ax0.loglog(k_vals, Pk_signal_gt, label='Groundtruth $s$', color='black', ls='--', zorder=5)
cmap = plt.colormaps.get_cmap('winter')
norm = plt.Normalize(0, Nburnin-1)
for i in range(0, Nburnin, 50):
color = cmap(norm(i))
ax0.loglog(k_vals, Pk[i], color=color, alpha=0.6)
# Colorbar inside the top panel
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # Only needed for showing the colorbar
cax = ax0.inset_axes([0.05, 0.25, 0.4, 0.03]) # [x0, y0, width, height]
cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
cbar.set_label("Sample index", labelpad=6, loc='center')
cax.xaxis.set_ticks_position('bottom') # Ticks at the bottom (default for horizontal)
ax0.legend(loc=[0.05, 0.40])
ax0.set_ylabel(r"$P(k)$ [arbitrary units]")
ax0.set_title('Sequential posterior power spectrum of reconstructed signal')
ax0.grid()
ax1.loglog(k_vals, np.ones_like(k_vals), color='black', ls='--', zorder=5)
for i in range(0, Nburnin, 50):
color = cmap(norm(i))
ax1.loglog(k_vals, Pk[i]/Pk_signal_gt, color=color, alpha=0.6)
ax1.set_xlabel(r"$k$ [2$\pi$/L]")
ax1.set_ylabel("$P(k)/P_\\mathrm{truth}(k)$")
ax1.grid()
plt.savefig(dir+'power_spectrum_white_noise.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum_white_noise.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Reconstruct the delta field from the sampled white noise fields
delta_samples_chain1 = jax.vmap(data_model)(samples_chain[0])
# Compute the power spectrum of the delta field samples
Pk_delta = np.zeros((Nburnin, len(k_vals)))
for i in range(0, Nburnin, 50):
k_vals, Pk_delta[i] = power_spectrum_2d(delta_samples_chain1[i])
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(8, 6), sharex=True, gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.})
ax0.loglog(k_vals, Pk_delta_gt, label='Groundtruth $\\delta$', color='black', ls='--', zorder=5)
cmap = plt.colormaps.get_cmap('winter')
norm = plt.Normalize(0, Nburnin-1)
for i in range(0, Nburnin, 50):
color = cmap(norm(i))
ax0.loglog(k_vals, Pk_delta[i], color=color, alpha=0.6)
# Colorbar inside the top panel
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # Only needed for showing the colorbar
cax = ax0.inset_axes([0.05, 0.25, 0.4, 0.03]) # [x0, y0, width, height]
cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
cbar.set_label("Sample index", labelpad=6, loc='center')
cax.xaxis.set_ticks_position('bottom') # Ticks at the bottom (default for horizontal)
ax0.legend(loc=[0.05, 0.40])
ax0.set_ylabel("$P_\\delta(k)$ [arbitrary units]")
ax0.set_title('Sequential posterior power spectrum of $\\delta$ fields')
ax0.grid()
ax1.loglog(k_vals, np.ones_like(k_vals), color='black', ls='--', zorder=5)
for i in range(0, Nburnin, 50):
color = cmap(norm(i))
ax1.loglog(k_vals, Pk_delta[i]/Pk_delta_gt, color=color, alpha=0.6)
ax1.set_xlabel(r"$k$ [2$\pi$/L]")
ax1.set_ylabel("$P_\\delta(k)/P_{\\delta,\\mathrm{truth}}(k)$")
ax1.grid()
plt.savefig(dir+'power_spectrum_delta.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum_delta.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Based on python code from the emcee tutorials, https://emcee.readthedocs.io/en/stable/tutorials/autocorr/
def next_power_of_2(n: int) -> int:
"""Smallest power of two ≥ n."""
return 1 << (n - 1).bit_length() if n > 0 else 1
# 1D autocorrelation function
def autocorr_func_1d(x, norm=True):
"""
Compute the 1D autocorrelation via FFT in O(N log N).
If norm=True, normalize so acf[0] = 1.
"""
x = np.asarray(x, dtype=float)
n = x.size
nfft = 2 * next_power_of_2(n)
# real FFT
f = np.fft.rfft(x - np.mean(x), n=nfft)
ps = (f * f.conjugate()).real # power spectrum
acf = np.fft.irfft(ps, n=nfft)[:n]
acf /= 2 * nfft
# normalise
if norm:
if acf[0] <= 0:
return 0 # or raise ValueError("Autocorrelation function is zero or negative at lag 0.")
else:
acf /= acf[0]
return acf
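As a quick test of the estimator (an illustrative snippet using synthetic data), the ACF of i.i.d. noise should be 1 at lag 0 and scatter around 0 at all other lags:
acf_test = autocorr_func_1d(np.random.normal(size=4096))
print(acf_test[:5]) # expect approximately [1, 0, 0, 0, 0]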
# Automated windowing procedure following Sokal (1989)
def auto_window(taus, c):
"""
Return the first lag k for which k < c * tau_k fails.
If none fail, return len(taus)-1.
"""
k = np.arange(len(taus))
mask = k < c * taus
# find first index where mask is False
idx = np.argmax(~mask)
return idx if not mask[idx] else len(taus) - 1
# Following the suggestion from Goodman & Weare (2010)
def autocorr_gw2010(x, c: float = 5.0) -> float:
"""
Estimate the integrated autocorrelation time following
Goodman & Weare (2010), with window parameter c.
"""
acf = autocorr_func_1d(x, norm=True)
taus = 2.0 * np.cumsum(acf) - 1.0
window = auto_window(taus, c)
return taus[window]
def N_eff(x) -> float:
"""
Effective number of independent samples in x.
Accepts input as either a list or a 1D numpy array.
"""
x = np.asarray(x)
tau = autocorr_gw2010(x)
if tau <= 0:
return 0 # or raise ValueError("Autocorrelation time is zero or negative.")
return x.size / tau
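For example, applied to a single pixel of the first chain (illustrative usage; the loop below applies the same estimator to growing chain segments):
tau_10_20 = autocorr_gw2010(np.asarray(samples_chain[0][:, 10, 20]))
print(f"tau = {tau_10_20:.1f}, N_eff = {N_eff(samples_chain[0][:, 10, 20]):.0f}")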
n_thin_corr = 200
N_eff_10_20 = np.zeros((N_chains, n_samples // n_thin_corr), dtype=float)
N_eff_25_10 = np.zeros((N_chains, n_samples // n_thin_corr), dtype=float)
for c in range(N_chains):
for i, samp in enumerate(np.arange(n_thin_corr, n_samples + 1, n_thin_corr)):
# use the first samp samples of the chain, starting with at least n_thin_corr samples
N_eff_10_20[c, i] = N_eff(samples_chain[c][:samp, 10, 20])
N_eff_25_10[c, i] = N_eff(samples_chain[c][:samp, 25, 10])
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(10,6), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel('Pixel (10,20)')
ax0.set_title("ESS for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_corr), N_eff_10_20[c], label=f'Chain {c+1}', marker='.')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_xlabel("Sample index")
ax1.set_ylabel('Pixel (25,10)')
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_corr), N_eff_25_10[c], label=f'Chain {c+1}', marker='.')
ax1.axvline(Nburnin,color='black',linestyle=':')
ax1.legend(loc='best')
plt.savefig(dir+'ESS.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'ESS.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
Gelman et al., "Bayesian Data Analysis" (third edition), p. 284-285.
Parameters: $m$ chains of length $n$, with draws $\psi_{ij}$ of a scalar estimand $\psi$, for $i = 1, \dots, n$ and $j = 1, \dots, m$.
Definitions: the between-chains and within-chains variances are
$$B = \frac{n}{m-1} \sum_{j=1}^{m} \left(\bar{\psi}_{\cdot j} - \bar{\psi}_{\cdot \cdot}\right)^2, \qquad W = \frac{1}{m} \sum_{j=1}^{m} s_j^2, \quad \text{with } s_j^2 = \frac{1}{n-1} \sum_{i=1}^{n} \left(\psi_{ij} - \bar{\psi}_{\cdot j}\right)^2.$$
Estimators: Estimators of the marginal posterior variance of the estimand:
$$\widehat{\mathrm{var}}^{-} = W, \qquad \widehat{\mathrm{var}}^{+} = \frac{n-1}{n} W + \frac{1}{n} B.$$
Test: the potential scale reduction factor
$$\hat{R} = \sqrt{\widehat{\mathrm{var}}^{+} / \widehat{\mathrm{var}}^{-}}$$
approaches 1 from above as the chains converge.
def gelman_rubin(chain):
# between chains variance
Psi_dotj = np.mean(chain, axis=1)
Psi_dotdot = np.mean(Psi_dotj, axis=0)
m = chain.shape[0]
n = chain.shape[1]
B = n / (m - 1.) * np.sum((Psi_dotj - Psi_dotdot)**2, axis=0)
# within chains variance
sj2 = np.var(chain, axis=1, ddof=1)
W = np.mean(sj2, axis=0)
# estimators
var_minus = W
var_plus = (n - 1.) / n * W + 1. / n * B
R_hat = np.sqrt(var_plus / var_minus)
return R_hat
# The input array must have dimensions (nchains, nsamp, npars) = (m, n, npars).
chain = jnp.array([samples_chain[c] for c in range(N_chains)])
chain = chain.reshape((N_chains, n_samples, N*N)) # Reshape to (nchains, nsamp, npars)
Rstat = gelman_rubin(chain).reshape((N, N))
# Visualize the Gelman-Rubin statistic
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(Rstat, vmin=1, vmax=2, origin='lower', cmap='bone_r')
ax.set_title(r'Gelman-Rubin statistic $\hat{R}$')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax, extend='max')
plt.tight_layout()
plt.savefig(dir+'gelman_rubin_statistic.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'gelman_rubin_statistic.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Join the different chains into one array for the posterior samples
samples = jnp.concatenate([samples_chain[c][Nburnin:] for c in range(N_chains)], axis=0)
PhiL = phiL_from_real_noise(white_noise)
PhiL_samples = jax.vmap(phiL_from_real_noise)(samples)
PhiNL = PhiL + f_NL * PhiL**2
PhiNL_samples = PhiL_samples + f_NL * PhiL_samples**2
delta_samples = jax.vmap(data_model)(samples)
@jax.jit
def compute_summaries(samples):
empirical_mean = jnp.mean(samples, axis=0)
empirical_var = jnp.var(samples, axis=0)
return empirical_mean, empirical_var
empirical_mean, empirical_var = compute_summaries(samples)
PhiL_mean, PhiL_var = compute_summaries(PhiL_samples)
PhiNL_mean, PhiNL_var = compute_summaries(PhiNL_samples)
delta_mean, delta_var = compute_summaries(delta_samples)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# visualize the signal field
im0 = ax0.imshow(white_noise, vmin=-max(-white_noise.min(),white_noise.max()), vmax=max(-white_noise.min(),white_noise.max()), origin='lower', cmap=planck)
ax0.set_title('Signal')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the empirical mean of the posterior samples
im1 = ax1.imshow(empirical_mean, vmin=-max(-white_noise.min(),white_noise.max()), vmax=max(-white_noise.min(),white_noise.max()), origin='lower', cmap=planck)
ax1.set_title('Empirical mean of samples')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the empirical variance of the posterior samples
im2 = ax2.imshow(empirical_var,
norm=LogNorm(vmin=empirical_var.min(), vmax=empirical_var.max()), origin='lower', cmap="Greys")
ax2.set_title('Empirical variance of samples')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'signal_mean_variance_samples.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'signal_mean_variance_samples.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, ((ax0a, ax1a, ax2a), (ax0b, ax1b, ax2b), (ax0c, ax1c, ax2c)) = plt.subplots(3, 3, figsize=(18, 16))
plt.subplots_adjust(wspace=0.25)
# visualize the groundtruth PhiL field
vmax = max(-phi.min(), phi.max(), -phiNL.min(), phiNL.max())
vmin = -vmax
im0a = ax0a.imshow(PhiL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0a.set_title(r'Groundtruth $\Phi_\mathrm{L}$')
divider = make_axes_locatable(ax0a)
cax0a = divider.append_axes("right", size="5%", pad=0.1)
cbar0a = fig.colorbar(im0a, cax=cax0a)
# visualize the empirical mean of the posterior PhiL samples
im1a = ax1a.imshow(PhiL_mean, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1a.set_title(r'Empirical mean of $\Phi_\mathrm{L}$ samples')
divider = make_axes_locatable(ax1a)
cax1a = divider.append_axes("right", size="5%", pad=0.1)
cbar1a = fig.colorbar(im1a, cax=cax1a)
# visualize the empirical variance of the posterior PhiL samples
im2a = ax2a.imshow(PhiL_var,
norm=LogNorm(vmin=PhiL_var.min(), vmax=PhiL_var.max()), origin='lower', cmap="Greys")
ax2a.set_title(r'Empirical variance of $\Phi_\mathrm{L}$ samples')
divider = make_axes_locatable(ax2a)
cax2a = divider.append_axes("right", size="5%", pad=0.1)
cbar2a = fig.colorbar(im2a, cax=cax2a)
# visualize the groundtruth PhiNL field
im0b = ax0b.imshow(PhiNL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0b.set_title('Groundtruth $\\Phi_\\mathrm{NL}$')
divider = make_axes_locatable(ax0b)
cax0b = divider.append_axes("right", size="5%", pad=0.1)
cbar0b = fig.colorbar(im0b, cax=cax0b)
# visualize the empirical mean of the posterior PhiNL samples
im1b = ax1b.imshow(PhiNL_mean, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1b.set_title('Empirical mean of $\\Phi_\\mathrm{NL}$ samples')
divider = make_axes_locatable(ax1b)
cax1b = divider.append_axes("right", size="5%", pad=0.1)
cbar1b = fig.colorbar(im1b, cax=cax1b)
# visualize the empirical variance of the posterior PhiNL samples
im2b = ax2b.imshow(PhiNL_var,
norm=LogNorm(vmin=PhiNL_var.min(), vmax=PhiNL_var.max()), origin='lower', cmap="Greys")
ax2b.set_title('Empirical variance of $\\Phi_\\mathrm{NL}$ samples')
divider = make_axes_locatable(ax2b)
cax2b = divider.append_axes("right", size="5%", pad=0.1)
cbar2b = fig.colorbar(im2b, cax=cax2b)
# visualize the groundtruth delta field
im0c = ax0c.imshow(delta, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax0c.set_title('Groundtruth $\\delta$')
divider = make_axes_locatable(ax0c)
cax0c = divider.append_axes("right", size="5%", pad=0.1)
cbar0c = fig.colorbar(im0c, cax=cax0c)
# visualize the empirical mean of the posterior delta samples
im1c = ax1c.imshow(delta_mean, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax1c.set_title('Empirical mean of $\\delta$ samples')
divider = make_axes_locatable(ax1c)
cax1c = divider.append_axes("right", size="5%", pad=0.1)
cbar1c = fig.colorbar(im1c, cax=cax1c)
# visualize the empirical variance of the posterior delta samples
im2c = ax2c.imshow(delta_var,
norm=LogNorm(vmin=delta_var.min(), vmax=delta_var.max()), origin='lower', cmap="Greys")
ax2c.set_title('Empirical variance of $\\delta$ samples')
divider = make_axes_locatable(ax2c)
cax2c = divider.append_axes("right", size="5%", pad=0.1)
cbar2c = fig.colorbar(im2c, cax=cax2c)
plt.savefig(dir+'PhiL_PhiNL_delta_mean_variance_samples.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'PhiL_PhiNL_delta_mean_variance_samples.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
Generate an animation of constrained realizations.
def create_cr_animation(samples, fname=dir+"constrained_realizations.mp4", fps=5, n_thin_plot=100, frame_dir=dir+"frames/"):
import imageio.v2 as imageio
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Make output folder if doesn't exist
os.makedirs(frame_dir, exist_ok=True)
frame_files = []
thin_indices = np.arange(0, len(samples), n_thin_plot)
num_frames = len(thin_indices)
for frame_num, i in enumerate(thin_indices):
# Generate constrained realization
realization = samples[i]
frame_file = os.path.join(frame_dir, f"frame_{frame_num:03d}.png")
if not os.path.exists(frame_file):
fig, axs = plt.subplots(1, 3, figsize=(19, 6), dpi=320)
plt.subplots_adjust(left=0.02, right=0.97, top=0.97, bottom=0.03, wspace=0.30)
# visualize the signal field
signal = delta
im0 = axs[0].imshow(signal, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
axs[0].set_title('Groundtruth $\\delta$')
divider = make_axes_locatable(axs[0])
cax0 = divider.append_axes("right", size="5%", pad=0.05)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the data field
im1 = axs[1].imshow(data_v, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
axs[1].set_title('Data')
divider = make_axes_locatable(axs[1])
cax1 = divider.append_axes("right", size="5%", pad=0.05)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the constrained realization
im2 = axs[2].imshow(realization, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
axs[2].set_title('Constrained realization')
divider = make_axes_locatable(axs[2])
cax2 = divider.append_axes("right", size="5%", pad=0.05)
cbar2 = fig.colorbar(im2, cax=cax2, extend='both')
# Save frame
fig.savefig(frame_file,dpi=320)
plt.close(fig)
frame_files.append(frame_file)
# Write the animation
with imageio.get_writer(fname, fps=fps) as writer:
for filename in frame_files:
image = imageio.imread(filename)
writer.append_data(image)
print(f"Animation saved as {fname}")
create_cr_animation(delta_samples_chain1, fps=10, n_thin_plot=10)
Animation saved as ./plots/HMC_nonlinear_model/constrained_realizations.mp4