Florent Leclercq,
Institut d'Astrophysique de Paris,
florent.leclercq@iap.fr
import numpy as np
import os
import jax
import jax.numpy as jnp
import scipy.linalg as la
from cycler import cycler
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LogNorm, SymLogNorm
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from mpl_toolkits.axes_grid1 import make_axes_locatable
np.random.seed(123456)
plt.rcParams.update({'lines.linewidth': 2})
plt.rcParams.update({'text.usetex': True})
plt.rcParams.update({'text.latex.preamble': r"\usepackage{amsmath}\usepackage{upgreek}"})
plt.rcParams.update({'font.family': 'serif'})
plt.rcParams.update({'font.size': 15})
dir="./plots/BHM_field_parameter_sampling/"
os.makedirs(dir, exist_ok=True)
# Plotting utilities
colorsDict = {
# Match pygtc up to v0.2.4
'blues_old' : ('#4c72b0','#7fa5e3','#b2d8ff'),
'greens_old' : ('#55a868','#88db9b','#bbffce'),
'yellows_old' : ('#f5964f','#ffc982','#fffcb5'),
'reds_old' : ('#c44e52','#f78185','#ffb4b8'),
'purples_old' : ('#8172b2','#b4a5e5','#37d8ff'),
# New color scheme, dark colors match matplotlib v2
'blues' : ('#1f77b4','#52aae7','#85ddff'),
'oranges' : ('#ff7f0e','#ffb241','#ffe574'),
'greens' : ('#2ca02c','#5fd35f','#92ff92'),
'reds' : ('#d62728','#ff5a5b','#ff8d8e'),
'purples' : ('#9467bd','#c79af0','#facdff'),
'browns' : ('#8c564b','#bf897e','#f2bcb1'),
'pinks' : ('#e377c2','#ffaaf5','#ffddff'),
'grays' : ('#7f7f7f','#b2b2b2','#e5e5e5'),
'yellows' : ('#bcbd22','#eff055','#ffff88'),
'cyans' : ('#17becf','#4af1ff','#7dffff'),
}
defaultColorsOrder = ['blues', 'oranges','greens', 'reds', 'purples',
'browns', 'pinks', 'grays', 'yellows', 'cyans']
colors = [colorsDict[cs] for cs in defaultColorsOrder]
def get_contours(Z, nBins=30, confLevels=(.3173, .0455, .0027)):
Z /= Z.sum()
nContourLevels = len(confLevels)
chainLevels = np.ones(nContourLevels+1)
histOrdered = np.sort(Z.flat)
histCumulative = np.cumsum(histOrdered)
nBinsFlat = np.linspace(0., nBins**2, nBins**2)
for l in range(nContourLevels):
# Find location of contour level in 1d histCumulative
temp = np.interp(confLevels[l], histCumulative, nBinsFlat)
# Find "height" of contour level
chainLevels[nContourLevels-1-l] = np.interp(temp, nBinsFlat, histOrdered)
return chainLevels
def get_contours_from_samples(samples_x, samples_y, weights=None, nBins=30, confLevels=(.3173, .0455, .0027), smoothingKernel=1):
from scipy.ndimage import gaussian_filter
nContourLevels = len(confLevels)
chainLevels = np.ones(nContourLevels+1)
extents = np.empty(4)
# These are needed to compute the contour levels
nBinsFlat = np.linspace(0., nBins**2, nBins**2)
# Create 2d histogram
if weights is None:
weights = np.ones_like(samples_x)
hist2d, xedges, yedges = np.histogram2d(
samples_x, samples_y, weights=weights, bins=nBins)
# image extent, needed below for contour lines
extents = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
# Normalize
hist2d = hist2d/np.sum(hist2d)
# Cumulative 1d distribution
histOrdered = np.sort(hist2d.flat)
histCumulative = np.cumsum(histOrdered)
# Compute contour levels (from low to high for technical reasons)
for l in range(nContourLevels):
# Find location of contour level in 1d histCumulative
temp = np.interp(confLevels[l], histCumulative, nBinsFlat)
# Find "height" of contour level
chainLevels[nContourLevels-1-l] = np.interp(temp, nBinsFlat, histOrdered)
# Apply Gaussian smoothing
contours = gaussian_filter(hist2d.T, sigma=smoothingKernel)
xbins = (xedges[1:]+xedges[:-1])/2
ybins = (yedges[1:]+yedges[:-1])/2
return xbins, ybins, contours, chainLevels
# Download and define the Planck color map
from matplotlib.colors import ListedColormap
import os
if not os.path.isfile("data/Planck_Parchment_RGB.txt"):
!wget https://raw.githubusercontent.com/zonca/paperplots/master/data/Planck_Parchment_RGB.txt --directory-prefix=data/
planck = ListedColormap(np.loadtxt("data/Planck_Parchment_RGB.txt")/255.)
planck.set_bad("C7") # color of missing pixels
N = 32
L = 1.0 # box size
A_s = 6e-9 # power spectrum normalisation, arbitrary units
n_s = 0.96 # scalar spectral index
f_NL = 2000.0 # non-linear coupling parameter
D1 = 1.732e7 # growth factor of fluctuations at z=0, arbitrary units
def build_power_spectrum(N, A_s=1.0, n_s=0.96):
# real‐space grid spacing
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2*np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2*np.pi # length N//2+1
kx, kx = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
k = jnp.where(k == 0, 1.0, k) # prevent division by zero
# build Pk safely (zero at k=0)
Pk = A_s * k**(n_s - 1.0)
Pk = jnp.where(k == 0, 0, Pk) # Set mean mode to zero power
return Pk
@jax.jit
def phiL_from_real_noise(white_noise, A_s, n_s):
"""
Given real-space white noise, construct a Gaussian random field with power spectrum
P(k) = A_s * k^(n_s - 1).
Args:
white_noise: real array of shape [N, N] (from jax.random.normal(key, (N, N)))
A_s, n_s: parameters of power spectrum
Returns:
field: real array [N, N]
"""
N = white_noise.shape[0]
# Build the power spectrum
Pkgrid = build_power_spectrum(N, A_s=A_s, n_s=n_s)
# FFT the white noise to k-space
noise_k = jnp.fft.rfft2(white_noise)
# Multiply by sqrt of power spectrum
filtered_k = noise_k * jnp.sqrt(Pkgrid)
# IFFT back to real space
field = jnp.fft.irfft2(filtered_k).real
return field
def build_transfer_function(N):
"""Transfer function from primordial potential to density contrast."""
# real‐space grid spacing
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2*np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2*np.pi # length N//2+1
kx, ky = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
k = jnp.where(k == 0, 1.0, k) # prevent division by zero
# build Tk safely (zero at k=0)
T = np.zeros_like(k)
# shape parameter, characterising the wavenumber at equality
Omega_b = 0.049
Omega_m = 0.315
h = 0.674
fb = Omega_b / Omega_m
shape = Omega_m * h * np.exp(-Omega_b - np.sqrt(2.0*h) * fb)
# power spectrum, BBKS style
q = k / shape
alpha = 2.34; beta = 3.89; gamma = 16.1; delta = 5.46; epsilon = 6.71
aux = 1.0 + beta * q + (gamma * q)**2 + (delta * q)**3 + (epsilon * q)**4
T = (jnp.log(1.0 + alpha * q) / (alpha * q) * aux**-0.25);
T *= q**(1./2.) # correct scaling for delta(k) = D1 * T(k) * Phi_NL(k) since we want P_delta(k) \propto k^n_s T(k)**2 \propto k * T(k)**2 * P_phi(k)
T = jnp.where(k == 0, 0, T) # Set mean mode to zero transfer function
return T
Tgrid = build_transfer_function(N)
@jax.jit
def delta_from_phi(phi, f_NL, D1=D1):
"""
Given real-space Phi, compute Phi_NL and then delta such that
delta(k) = D1 * T(k)**2 * Phi_NL(k), and inverse FFT to real-space delta.
Args:
phi: [N, N] real array, the field Phi in real space
f_NL: scalar
Returns:
delta_real: [N, N] real array, delta in real space
"""
N = phi.shape[0]
# 1. Construct Phi_NL
Phi_NL = phi + f_NL * phi**2
# 2. FFT to Fourier space
PhiNL_k = jnp.fft.rfft2(Phi_NL)
# 3. Multiply by transfer function T(k)
delta_k = D1 * Tgrid * PhiNL_k
delta_k = delta_k.at[0, 0].set(0.0) # Explicitly zero the mean mode
# 4. IFFT back to real space
delta_real = jnp.fft.irfft2(delta_k).real
return delta_real
@jax.jit
def data_model(white_noise, A_s, n_s, f_NL, D1=D1):
"""
Given real-space white noise, construct a Gaussian random field with power spectrum
P(k) = A_s * k^(n_s - 1), and compute the corresponding delta field.
Args:
white_noise: real array of shape [N, N] (from jax.random.normal(key, (N, N)))
f_NL: scalar
D1: scalar
Returns:
delta_real: [N, N] real array, delta in real space
"""
phi = phiL_from_real_noise(white_noise, A_s, n_s)
delta = delta_from_phi(phi, f_NL, D1=D1)
return delta
key = jax.random.PRNGKey(12)
white_noise = jax.random.normal(key, (N, N))
phi = phiL_from_real_noise(white_noise, A_s, n_s)
delta = delta_from_phi(phi, f_NL)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.3)
phiNL=phi+f_NL*phi**2
# visualize the phi field
vmin = -max(-phi.min(), phi.max(), phiNL.min(), phiNL.max())
vmax = max(-phi.min(), phi.max(), phiNL.min(), phiNL.max())
im0 = ax0.imshow(phi, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0.set_title('$\\Phi_\mathrm{L}$')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the phiNL field
im1 = ax1.imshow(phiNL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1.set_title('$\\Phi_\mathrm{NL}=\\Phi_\mathrm{L}+f_\mathrm{NL}\\Phi_\mathrm{L}^2$')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the density contrast field
im2 = ax2.imshow(delta, vmin=-max(delta.min(),delta.max()), vmax=max(delta.min(),delta.max()), origin='lower', cmap=planck)
ax2.set_title('Density contrast $\\delta$')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'phi_phiNL_delta.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'phi_phiNL_delta.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
def make_noise_field():
"""
Build a 32×32 array “field” of noise variances:
- Border of width 3 pixels everywhere: value = 1e0
- Central 26×26 block: two regions of different noise variances
- A low‐noise region (1e-5)
- A medium‐noise region (5e-5) in the lower left corner
- A 2×2 high‐noise patch (1e0) centered in the array
Returns
-------
field : ndarray, shape (32,32)
The resulting noise‐variance field.
"""
N = 32
field = np.zeros((N, N), dtype=float)
mask = np.zeros((N, N), dtype=bool)
# 1) High‐noise border
bw = 3
high = 1
field[:bw, :] = high
field[-bw:, :] = high
field[:, :bw] = high
field[:, -bw:] = high
mask[:bw, :] = True
mask[-bw:, :] = True
mask[:, :bw] = True
mask[:, -bw:] = True
# 2) Central block (size 26×26)
low = 1e-5
med = 5e-5
for i in range(bw,N-bw):
for j in range(bw,N-bw):
if i-j<-3:
field[i,j] = med
else:
field[i,j] = low
# 3) Sprinkle a 2×2 high‐noise patch
ci = 24
cj = 24
field[ci-1:ci+1, cj-1:cj+1] = high
mask[ci-1:ci+1, cj-1:cj+1] = True
return field, mask
noise_variance_field, mask = make_noise_field()
invN = np.diag(1.0 / noise_variance_field.flatten())
invN[np.where(invN <= 1)] = 0.0
Data model: $d=\delta(s)+n$
# noise: one realisation
noise = np.random.normal(size=(N, N)) * np.sqrt(noise_variance_field)
noise_v = np.ma.masked_where(mask, noise) # for visualization purposes only
# data = delta + noise
data = delta + noise
data_v = np.ma.masked_where(mask, data) # for visualization purposes only
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# visualize the delta field
im0 = ax0.imshow(delta, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax0.set_title('Groundtruth $\\delta$')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the noise field
vmin, vmax = np.min(noise_v), np.max(noise_v)
linthresh = 1e-5
linscale = 1.0
norm = SymLogNorm(linthresh=linthresh, linscale=linscale,
vmin=-max(-vmin,vmax), vmax=max(-vmin,vmax), base=10)
cmap = plt.get_cmap('PiYG')
cmap.set_bad('C7') # color of missing pixels
im1 = ax1.imshow(noise_v, cmap=cmap, norm=norm, origin='lower')
ax1.set_title('Noise')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the data field
im2 = ax2.imshow(data_v, vmin=-max(data_v.min(),data_v.max()), vmax=max(data_v.min(),data_v.max()), origin='lower', cmap=planck)
ax2.set_title('Data')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'delta_noise_data.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'delta_noise_data.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Define hyperparameters for the prior
A_s_scale = 6e-9
f_NL_scale = 2e3
mu_A_s = 0.9 * A_s_scale
sigma_A_s = 0.1 * A_s_scale
mu_n_s = 0.95
sigma_n_s = 0.01
mu_f_NL = 0.75 * f_NL_scale
sigma_f_NL = 0.25 * f_NL_scale
def unpack_theta(theta, N):
"""Helper for unpacking the flat theta into params and field for debugging etc."""
A_s = theta[0]*A_s_scale
n_s = theta[1]
f_NL = theta[2]*f_NL_scale
field = theta[3:].reshape((N, N))
return A_s, n_s, f_NL, field
def pack_theta(A_s, n_s, f_NL, field):
"""Helper for packing the params and field into a flat theta."""
theta = jnp.concatenate((jnp.array([A_s / A_s_scale, n_s, f_NL / f_NL_scale]), field.flatten()))
return theta
def A_s_from_theta(theta):
"""Extract A_s from the theta vector."""
return theta[0] * A_s_scale
def n_s_from_theta(theta):
"""Extract n_s from the theta vector."""
return theta[1]
def f_NL_from_theta(theta):
"""Extract f_NL from the theta vector."""
return theta[2] * f_NL_scale
def field_from_theta(theta):
"""Extract the field from the theta vector."""
return theta[3:].reshape((N, N))
def log_prior(theta, mu_A_s=mu_A_s, sigma_A_s=sigma_A_s, mu_n_s=mu_n_s, sigma_n_s=sigma_n_s, mu_f_NL=mu_f_NL, sigma_f_NL=sigma_f_NL):
"""
Separable Gaussian prior for (A_s, n_s, f_NL) and white noise.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
mu_A_s, sigma_A: mean and standard deviation for A_s
mu_n, sigma_n: mean and standard deviation for n_s
mu_f, sigma_f: mean and standard deviation for f_NL
Returns:
log_prior: scalar, the log-prior value
"""
from jax.scipy.stats import norm
A_s, n_s, f_NL, signal = unpack_theta(theta, N)
# Compute the prior value for A_s, n_s, f_NL
logp_A_s = norm.logpdf(A_s, loc=mu_A_s, scale=sigma_A_s)
logp_n_s = norm.logpdf(n_s, loc=mu_n_s, scale=sigma_n_s)
logp_f_NL = norm.logpdf(f_NL, loc=mu_f_NL, scale=sigma_f_NL)
# Assuming a Gaussian prior with zero mean and unit variance
log_prior_signal = -0.5 * jnp.sum(signal**2)
return log_prior_signal + logp_A_s + logp_n_s + logp_f_NL
def log_likelihood(theta, data, noise_variance_field):
"""
Compute the (unnormalised) log-likelihood of the data given the signal and noise variance field.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
data: [N, N] array, the observed data
noise_variance_field: [N, N] array, the noise variance field
Returns:
log_likelihood: scalar, the log-likelihood value
"""
A_s, n_s, f_NL, signal = unpack_theta(theta, N)
delta = data_model(signal, A_s=A_s, n_s=n_s, f_NL=f_NL)
residual = data - delta
log_likelihood = -0.5 * jnp.sum(residual**2 / noise_variance_field)
return log_likelihood
def log_posterior(theta, data, noise_variance_field):
"""
Compute the (unnormalised) log-posterior of the signal given the data and noise variance field.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
signal: [N, N] array, the signal (white noise field)
noise_variance_field: [N, N] array, the noise variance field
Returns:
log_posterior: scalar, the log-posterior value
"""
log_likelihood_value = log_likelihood(theta, data, noise_variance_field)
log_prior_value = log_prior(theta)
return log_likelihood_value + log_prior_value
# Test the functions with the groundtruth white noise field
theta = pack_theta(A_s=A_s, n_s=n_s, f_NL=f_NL, field=white_noise)
log_prior(theta), log_likelihood(theta, data, noise_variance_field), log_posterior(theta, data, noise_variance_field)
(Array(-451.44205, dtype=float32), Array(-554.5358, dtype=float32), Array(-1005.9779, dtype=float32))
# Generate a new white noise field and compute the log-likelihood, log-prior, and log-posterior
white_noise_2 = jax.random.normal(jax.random.PRNGKey(2), (N, N)) # generate a new white noise field
theta_2 = pack_theta(A_s=A_s, n_s=n_s, f_NL=f_NL, field=white_noise_2)
log_likelihood(theta_2, data, noise_variance_field), log_prior(theta_2), log_posterior(theta_2, data, noise_variance_field)
(Array(-1487899.6, dtype=float32), Array(-494.03027, dtype=float32), Array(-1488393.6, dtype=float32))
# Change cosmological parameters and compute the log-likelihood, log-prior, and log-posterior
A_s_new = 7e-9
n_s_new = 0.98
f_NL_new = 2500.0
theta_3 = pack_theta(A_s=A_s_new, n_s=n_s_new, f_NL=f_NL_new, field=white_noise_2)
log_likelihood(theta_3, data, noise_variance_field), log_prior(theta_3), log_posterior(theta_3, data, noise_variance_field)
(Array(-1752963.9, dtype=float32), Array(-502.58582, dtype=float32), Array(-1753466.5, dtype=float32))
# Compute the gradients of log_prior, log_likelihood, and log_posterior w.r.t. white noise using JAX autodiff
def d_log_prior_d_params(theta):
"""
Compute the gradient of `log_prior` w.r.t. all parameters.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
Returns:
grad: [N, N] array, the gradient of log_prior w.r.t. white_noise
"""
A_s, n_s, f_NL, white_noise = unpack_theta(theta, N)
grad_A_s = (A_s - mu_A_s) / sigma_A_s**2
grad_n_s = (n_s - mu_n_s) / sigma_n_s**2
grad_f_NL = (f_NL - mu_f_NL) / sigma_f_NL**2
grad_field = -white_noise # Gradient of Gaussian prior with zero mean
grad = pack_theta(grad_A_s, grad_n_s, grad_f_NL, grad_field)
return grad
def d_log_likelihood_d_params_autodiff(theta, data, noise_variance_field):
"""
Compute the gradient of `log_likelihood` w.r.t. all parameter using JAX autodiff.
This function uses JAX's automatic differentiation to compute the gradient efficiently.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
data: [N, N] array
noise_variance_field: [N, N] array
Returns:
grad: [N, N] array, the gradient of log_likelihood w.r.t. white_noise
"""
grad = jax.grad(log_likelihood)(theta, data, noise_variance_field) # shape (N,N)
return grad
def d_log_posterior_d_params_autodiff(theta, data, noise_variance_field):
"""
Compute the gradient of `log_posterior` w.r.t. all parameters using JAX autodiff.
This function uses JAX's automatic differentiation to compute the gradient efficiently.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
data: [N, N] array
noise_variance_field: [N, N] array
Returns:
grad: [N, N] array, the gradient of log_posterior w.r.t. white_noise
"""
grad = jax.grad(log_posterior)(theta, data, noise_variance_field) # shape (N,N)
return grad
# Compute the gradients of log_prior, log_likelihood, and log_posterior w.r.t. white noise using finite differences
def d_log_prior_d_params_fd(theta, epsilon=1e-3):
"""
Compute the gradient of `log_prior` w.r.t. all parameters using finite differences.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
grad_fd = jnp.zeros(3+N*N) # 3 parameters + N*N white noise field
# Parameter grads
for k in range(3):
theta_up = theta.at[k].add(+epsilon)
theta_down = theta.at[k].add(-epsilon)
f_up = log_prior(theta_up)
f_down = log_prior(theta_down)
grad_fd = grad_fd.at[k].set((f_up - f_down) / (2 * epsilon))
# White noise field grads
for i in range(N):
for j in range(N):
idx = 3 + i*N + j
white_noise_up = theta[3:].at[i*N + j].add(+epsilon)
white_noise_down = theta[3:].at[i*N + j].add(-epsilon)
theta_up = pack_theta(A_s_from_theta(theta), n_s_from_theta(theta), f_NL_from_theta(theta), white_noise_up)
theta_down = pack_theta(A_s_from_theta(theta), n_s_from_theta(theta), f_NL_from_theta(theta), white_noise_down)
f_up = log_prior(theta_up)
f_down = log_prior(theta_down)
grad_fd = grad_fd.at[idx].set((f_up - f_down) / (2 * epsilon))
return grad_fd
def d_log_likelihood_d_params_fd(theta, data, noise_variance_field, epsilon=1e-3):
"""
Compute the gradient of `log_likelihood` w.r.t. all parameters using finite differences.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
data: [N, N] array
noise_variance_field: [N, N] array
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
grad_fd = jnp.zeros(3+N*N) # 3 parameters + N*N white noise field
# Parameter grads
for k in range(3):
theta_up = theta.at[k].add(+epsilon)
theta_down = theta.at[k].add(-epsilon)
f_up = log_likelihood(theta_up, data, noise_variance_field)
f_down = log_likelihood(theta_down, data, noise_variance_field)
grad_fd = grad_fd.at[k].set((f_up - f_down) / (2 * epsilon))
# White noise field grads
for i in range(N):
for j in range(N):
idx = 3 + i*N + j
white_noise_up = theta[3:].at[i*N + j].add(+epsilon)
white_noise_down = theta[3:].at[i*N + j].add(-epsilon)
theta_up = pack_theta(A_s_from_theta(theta), n_s_from_theta(theta), f_NL_from_theta(theta), white_noise_up)
theta_down = pack_theta(A_s_from_theta(theta), n_s_from_theta(theta), f_NL_from_theta(theta), white_noise_down)
f_up = log_likelihood(theta_up, data, noise_variance_field)
f_down = log_likelihood(theta_down, data, noise_variance_field)
grad_fd = grad_fd.at[idx].set((f_up - f_down) / (2 * epsilon))
return grad_fd
def d_log_posterior_d_params_fd(theta, data, noise_variance_field, epsilon=1e-3):
"""
Compute the gradient of `log_posterior` w.r.t. all parameters using finite differences.
Args:
theta: [A_s, n_s, f_NL, signal...] array, where signal is the white noise field
data: [N, N] array
noise_variance_field: [N, N] array
epsilon: finite difference step size
Returns:
grad_fd: [N, N] array, finite difference gradient
"""
grad_fd = jnp.zeros(3+N*N) # 3 parameters + N*N white noise field
# Parameter grads
for k in range(3):
theta_up = theta.at[k].add(+epsilon)
theta_down = theta.at[k].add(-epsilon)
f_up = log_posterior(theta_up, data, noise_variance_field)
f_down = log_posterior(theta_down, data, noise_variance_field)
grad_fd = grad_fd.at[k].set((f_up - f_down) / (2 * epsilon))
# White noise field grads
for i in range(N):
for j in range(N):
idx = 3 + i*N + j
theta_up = theta.at[idx].add(+epsilon)
theta_down = theta.at[idx].add(-epsilon)
f_up = log_posterior(theta_up, data, noise_variance_field)
f_down = log_posterior(theta_down, data, noise_variance_field)
grad_fd = grad_fd.at[idx].set((f_up - f_down) / (2 * epsilon))
return grad_fd
theta = pack_theta(A_s=A_s, n_s=n_s, f_NL=f_NL, field=white_noise)
grad_prior = d_log_prior_d_params(theta)
grad_prior_A_s, grad_prior_n_s, grad_prior_f_NL, grad_prior_wn = unpack_theta(grad_prior, N)
grad_prior_fd = d_log_prior_d_params_fd(theta)
grad_prior_fd_A_s, grad_prior_fd_n_s, grad_prior_fd_f_NL, grad_prior_fd_wn = unpack_theta(grad_prior_fd, N)
grad_lh = d_log_likelihood_d_params_autodiff(theta, data, noise_variance_field)
grad_lh_A_s, grad_lh_n_s, grad_lh_f_NL, grad_lh_wn = unpack_theta(grad_lh, N)
grad_lh_fd = d_log_likelihood_d_params_fd(theta, data, noise_variance_field)
grad_lh_fd_A_s, grad_lh_fd_n_s, grad_lh_fd_f_NL, grad_lh_fd_wn = unpack_theta(grad_lh_fd, N)
grad_post = d_log_posterior_d_params_autodiff(theta, data, noise_variance_field)
grad_post_A_s, grad_post_n_s, grad_post_f_NL, grad_post_wn = unpack_theta(grad_post, N)
grad_post_fd = d_log_posterior_d_params_fd(theta, data, noise_variance_field)
grad_post_fd_A_s, grad_post_fd_n_s, grad_post_fd_f_NL, grad_post_fd_wn = unpack_theta(grad_post_fd, N)
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12,6))
ax0.plot(np.arange(3), [grad_post_A_s, grad_post_n_s, grad_post_f_NL], 'o', label='Autodiff', markersize=2, zorder=3)
ax0.plot(np.arange(3), [grad_post_fd_A_s, grad_post_fd_n_s, grad_post_fd_f_NL], 'o', label='Finite differencing', markersize=4, zorder=2)
ax0.set_xticks(np.arange(3))
ax0.set_xticklabels(['$A_\mathrm{s}$', '$n_s$', '$f_\mathrm{NL}$'])
ax0.set_xlabel('Cosmological parameters')
ax0.set_title('Gradient of log-posterior w.r.t. parameters')
ax0.legend()
ax1.plot(np.arange(0, N*N), jnp.reshape(grad_post_wn, (N*N)), 'o', label='Autodiff', markersize=2, zorder=3)
ax1.plot(np.arange(0, N*N), jnp.reshape(grad_post_fd_wn, (N*N)), 'o', label='Finite differencing', markersize=4, zorder=2)
ax1.set_xlabel('Index of white noise field $s$')
ax1.set_title('Gradient of log-posterior w.r.t. $s$')
ax1.legend()
plt.savefig(dir+'gradient_test.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'gradient_test.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
import blackjax
from typing import Any, Tuple
def sample_theta_hmc(
theta_init: jnp.ndarray,
data: jnp.ndarray,
noise_variance_field: jnp.ndarray,
n_samples: int,
n_adapt: int = 1000, # Number of adaptation steps
rng_key: jax.random.PRNGKey = None,
) -> Tuple[jnp.ndarray, Any]:
if rng_key is None:
rng_key = jax.random.PRNGKey(0)
N = data.shape[0]
def logprob(flat_theta):
return log_posterior(flat_theta, data, noise_variance_field)
initial_position = theta_init
# Adaptation
adapt = blackjax.window_adaptation(
blackjax.nuts,
logprob,
num_steps=n_adapt,
target_acceptance_rate=0.8,
)
state, kernel, adaptation_state = adapt.run(rng_key, initial_position)
# Sampling loop
def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, _ = kernel(rng_key, state)
return state, state
keys = jax.random.split(rng_key, num_samples)
_, states = jax.lax.scan(one_step, initial_state, keys)
return states
states = inference_loop(rng_key, kernel, state, n_samples)
samples = states.position
return samples, states
n_samples = 10000 # Number of samples to draw
N_chains = 5 # Number of chains to run
initial_scaling = 0.001 # Initial scaling for the white noise field, start from an overdispersed state
def draw_cosmo_theta0_from_prior(rng_key):
from scipy.stats import norm
# Draw A_s, n_s, f_NL from corresponding normals
key1, key2, key3 = jax.random.split(rng_key, 3)
A_s = norm.rvs(loc=mu_A_s, scale=sigma_A_s, random_state=int(key1[0]))
n_s = norm.rvs(loc=mu_n_s, scale=sigma_n_s, random_state=int(key2[0]))
f_NL = norm.rvs(loc=mu_f_NL, scale=sigma_f_NL, random_state=int(key3[0]))
return [A_s, n_s, f_NL]
try:
samples_chain = np.load('data/BHM_field_parameter_sampling/samples_chain.npy', allow_pickle=True).item()
except FileNotFoundError:
samples_chain = {}
for c in range(N_chains):
rng_key = jax.random.PRNGKey(42 + c)
subkey1, subkey2 = jax.random.split(rng_key)
theta0_cosmo = draw_cosmo_theta0_from_prior(subkey1)
field_init = jax.random.normal(subkey2, (N, N)) * initial_scaling
theta_init = pack_theta(theta0_cosmo[0], theta0_cosmo[1], theta0_cosmo[2], field_init)
samples_chain[c], infos = sample_theta_hmc(
theta_init,
data,
noise_variance_field,
n_samples=n_samples,
n_adapt=2000,
rng_key=rng_key
)
os.makedirs('data/BHM_field_parameter_sampling/', exist_ok=True)
np.save('data/BHM_field_parameter_sampling/samples_chain.npy', samples_chain)
log_likelihoods_chain = {}
for c in range(N_chains):
log_likelihoods_chain[c] = jax.vmap(log_likelihood, in_axes=(0, None, None))(samples_chain[c], data, noise_variance_field)
Nburnin = 100
n_thin_plot = 50 # Plot every nth element for clarity
fig, ax = plt.subplots(figsize=(6, 5))
ax.set_xlim(0, n_samples)
ax.set_ylim(450, 600)
ax.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
for c in range(N_chains):
ax.plot(np.arange(0, n_samples, n_thin_plot), -log_likelihoods_chain[c][::n_thin_plot], label=f'Chain {c+1}')
ax.axvline(Nburnin, color='k', linestyle=':', label='Burnt-in')
ax.set_xlabel('Sample index')
ax.set_ylabel('$-\\log \mathcal{L}$')
ax.set_title('Log-likelihood vs sample index')
ax.grid()
ax.legend()
plt.savefig(dir+'log_likelihood.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'log_likelihood.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(10,6), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel("Pixel (10,20)")
ax0.set_title("Trace plots for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_plot),jax.vmap(field_from_theta)(samples_chain[c])[::n_thin_plot,10,20],marker='.')
ax0.axhline(white_noise[10,20],color='black',linestyle='--')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_xlabel("Sample index")
ax1.set_ylabel("Pixel (25,10)")
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_plot),jax.vmap(field_from_theta)(samples_chain[c])[::n_thin_plot,25,10],marker='.')
ax1.axhline(white_noise[25,10],color='black',linestyle='--')
ax1.axvline(Nburnin,color='black',linestyle=':')
plt.savefig(dir+'trace_plot_pixels.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'trace_plot_pixels.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, (ax0, ax1, ax2) = plt.subplots(3, 1, figsize=(10,9), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel("$A_\\mathrm{s}$")
ax0.set_title("Trace plots for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_plot),jax.vmap(A_s_from_theta)(samples_chain[c])[::n_thin_plot],marker='.')
ax0.axhline(mu_A_s,color='black',linestyle=':')
ax0.fill_between(np.arange(0, n_samples), np.ones(n_samples)*(mu_A_s + 2*sigma_A_s), np.ones(n_samples)*(mu_A_s - 2*sigma_A_s), color='black', alpha=0.1, label='Prior $2\\sigma$')
ax0.axhline(A_s,color='black',linestyle='--',label='Groundtruth')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax0.legend(loc='lower left')
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_ylabel("$n_\\mathrm{s}$")
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_plot),jax.vmap(n_s_from_theta)(samples_chain[c])[::n_thin_plot],marker='.')
ax1.axhline(mu_n_s,color='black',linestyle=':')
ax1.fill_between(np.arange(0, n_samples), np.ones(n_samples)*(mu_n_s + 2*sigma_n_s), np.ones(n_samples)*(mu_n_s - 2*sigma_n_s), color='black', alpha=0.1, label='Prior $2\\sigma$')
ax1.axhline(n_s,color='black',linestyle='--',label='Groundtruth')
ax1.axvline(Nburnin,color='black',linestyle=':')
ax2.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax2.set_xlim(0,n_samples)
ax2.set_ylabel("$f_\\mathrm{NL}$")
for c in range(N_chains):
ax2.plot(np.arange(0, n_samples, n_thin_plot),jax.vmap(f_NL_from_theta)(samples_chain[c])[::n_thin_plot],marker='.')
ax2.axhline(mu_f_NL,color='black',linestyle=':')
ax2.fill_between(np.arange(0, n_samples), np.ones(n_samples)*(mu_f_NL + 2*sigma_f_NL), np.ones(n_samples)*(mu_f_NL - 2*sigma_f_NL), color='black', alpha=0.1, label='Prior $2\\sigma$')
ax2.axhline(f_NL,color='black',linestyle='--',label='Groundtruth')
ax2.axvline(Nburnin,color='black',linestyle=':')
ax2.set_xlabel("Sample index")
plt.savefig(dir+'trace_plot_cosmology.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'trace_plot_cosmology.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Define a function to compute the power spectrum of the sampled delta fields
@jax.jit
def power_spectrum_2d_jitted(field, nbins=200, L=1.0):
# real‐space grid spacing
N = field.shape[0]
dx = L / N
# build k‐space grid for rfft2
kx = jnp.fft.fftfreq(N, d=dx) * 2 * np.pi # length N
ky = jnp.fft.rfftfreq(N, d=dx) * 2 * np.pi # length N//2+1
kx, ky = jnp.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
k = jnp.sqrt(kx**2 + ky**2)
kmag = k.flatten()
# Compute rfft2
ft = jnp.fft.rfft2(field)
power2d = (jnp.abs(ft) ** 2) * (dx ** 2) / (N ** 2)
power_flat = power2d.flatten()
# Bin edges (exclude k=0 when setting min)
k_nonzero = kmag[1:] if kmag.size > 1 else kmag # Remove k=0 for k_min
k_min = k_nonzero.min()
k_max = kmag.max()
bins = jnp.logspace(jnp.log10(k_min), jnp.log10(k_max), nbins + 1)
bin_idx = jnp.digitize(kmag, bins) - 1 # bin assignments
# Bin means
Pk_sum = jnp.bincount(bin_idx, weights=power_flat, length=nbins)
counts = jnp.bincount(bin_idx, length=nbins)
Pk = jnp.where(counts > 0, Pk_sum / counts, 0.0)
k_sum = jnp.bincount(bin_idx, weights=kmag, length=nbins)
k_bin = jnp.where(counts > 0, k_sum / counts, 0.0)
mask = counts > 0
return k_bin, Pk, counts
def power_spectrum_2d(field, nbins=200, L=1.0):
"""
Compute the 2D power spectrum of a 2D field using FFT and binning.
Args:
field: [N, N] array, the 2D field
nbins: int, number of bins for the power spectrum
L: float, box size
Returns:
k_bin: [nbins] array, the binned wavenumbers
Pk: [nbins] array, the binned power spectrum values
counts: [nbins] array, number of modes in each bin
"""
k_bin, Pk, counts = power_spectrum_2d_jitted(field)
mask = counts > 0
return k_bin[mask], Pk[mask]
k_vals, Pk_signal_gt = power_spectrum_2d(white_noise, nbins=200)
k_vals, Pk_delta_gt = power_spectrum_2d(delta, nbins=200)
import blackjax
from typing import Any, Tuple
# Helper for univariate slice updates
def univariate_slice_sampler(logprob, x_init, rng_key, w=1.0, m=10, max_steps=50):
logprob0 = logprob(x_init)
e_key, l_key, r_key, i_key, _ = jax.random.split(rng_key, 5)
y = logprob0 - jnp.abs(jax.random.exponential(e_key))
u = jax.random.uniform(l_key)
left = x_init - u * w
right = left + w
def step_out_left_fn(val):
_left, _i = val
return (logprob(_left) > y) & (_i < m)
def step_out_left_body(val):
_left, _i = val
return (_left - w, _i + 1)
left, _ = jax.lax.while_loop(step_out_left_fn, step_out_left_body, (left, 0))
def step_out_right_fn(val):
_right, _i = val
return (logprob(_right) > y) & (_i < m)
def step_out_right_body(val):
_right, _i = val
return (_right + w, _i + 1)
right, _ = jax.lax.while_loop(step_out_right_fn, step_out_right_body, (right, 0))
def body_fn(val):
left, right, x, i = val
new_key = jax.random.fold_in(i_key, i)
x_new = jax.random.uniform(new_key) * (right - left) + left
lp = logprob(x_new)
cond = lp >= y
left = jnp.where((~cond) & (x_new < x_init), x_new, left)
right = jnp.where((~cond) & (x_new >= x_init), x_new, right)
x = jnp.where(cond, x_new, x)
i += 1
return (left, right, x, i)
def cond_fn(val):
left, right, x, i = val
return (logprob(x) < y) & (i < max_steps)
left, right, x, i = jax.lax.while_loop(cond_fn, body_fn, (left, right, x_init, 0))
return x
def hmc_white_noise(theta, data, noise_variance_field, n_adapt, rng_key):
Npix = (len(theta) - 3)
N = int(jnp.sqrt(Npix))
# logprob for field (fixed params)
def logprob(flat_field):
thet = jnp.concatenate([theta[:3], flat_field])
return log_posterior(thet, data, noise_variance_field)
initial_position = theta[3:]
adapt = blackjax.window_adaptation(
blackjax.hmc,
logprob,
num_steps=n_adapt, # adaptation steps
num_integration_steps=10,
target_acceptance_rate=0.8,
)
state, kernel, adaptation_state = adapt.run(rng_key, initial_position)
# 1 HMC draw, return flat field
key, subkey = jax.random.split(rng_key)
state, _ = kernel(subkey, state)
return state.position
def gibbs_sampler(
theta_init,
data,
noise_variance_field,
n_samples,
n_adapt=20, # Number of adaptation steps
rng_key=jax.random.PRNGKey(0)
):
Npix = len(theta_init) - 3
def logprob_white_noise(field_flat, cosmo_params):
theta = jnp.concatenate([cosmo_params, field_flat])
return log_posterior(theta, data, noise_variance_field)
# Adapt HMC step-size _once_
field_init = theta_init[3:]
cosmo_init = theta_init[:3]
def logprob(field_flat):
return logprob_white_noise(field_flat, cosmo_init)
hmc_adapt = blackjax.window_adaptation(
blackjax.hmc,
logprob,
num_steps=n_adapt,
num_integration_steps=10,
target_acceptance_rate=0.8,
)
state, kernel, adaptation_state = hmc_adapt.run(rng_key, field_init)
hmc_kernel = kernel # Use this kernel for all field updates!
# Helper: update each cosmo param by slice sampling
def slice_update_var(idx, theta, rng_key):
def logprob1d(x):
return log_posterior(theta.at[idx].set(x), data, noise_variance_field)
x_new = univariate_slice_sampler(logprob1d, theta[idx], rng_key)
return theta.at[idx].set(x_new)
# One step of Gibbs (scan body function)
def gibbs_step(theta, rng_key):
keys = jax.random.split(rng_key, 5)
# Slice for each cosmological param
theta = slice_update_var(0, theta, keys[0])
theta = slice_update_var(1, theta, keys[1])
theta = slice_update_var(2, theta, keys[2])
# HMC for the field (using fixed hmc_kernel)
field_flat = theta[3:]
state = blackjax.hmc.init(field_flat, lambda x: logprob_white_noise(x, theta[:3]))
state, _ = hmc_kernel(keys[3], state)
theta = theta.at[3:].set(state.position)
return theta, theta
# Run scan
key_seq = jax.random.split(rng_key, n_samples+1)
init_theta = theta_init
_, samples = jax.lax.scan(gibbs_step, init_theta, key_seq[:-1])
return samples # shape: [n_samples, num_params+N**2]
# Compute the power spectrum of the signal (white noise)
Nburnin = 1001
n_thin_PS = 50 # Plot every nth element for clarity
Pk = np.zeros((Nburnin, len(k_vals)))
samples_chain1 = jax.vmap(field_from_theta)(samples_chain[0])
for i in range(0, Nburnin, n_thin_PS):
k_vals, Pk[i] = power_spectrum_2d(samples_chain1[i])
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(8, 6), sharex=True, gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.})
ax0.loglog(k_vals, Pk_signal_gt, label='Groundtruth $s$', color='black', ls='--', zorder=5)
cmap = plt.colormaps.get_cmap('winter')
norm = plt.Normalize(0, Nburnin-1)
for i in range(0, Nburnin, n_thin_PS):
color = cmap(norm(i))
ax0.loglog(k_vals, Pk[i], color=color, alpha=0.6)
# Colorbar inside the top panel
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # Only needed for showing the colorbar
cax = ax0.inset_axes([0.05, 0.25, 0.4, 0.03]) # [x0, y0, width, height]
cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
cbar.set_label("Sample index", labelpad=6, loc='center')
cax.xaxis.set_ticks_position('bottom') # Ticks at the bottom (default for horizontal)
ax0.legend(loc=[0.05, 0.40])
ax0.set_ylabel(r"$P(k)$ [arbitrary units]")
ax0.set_title('Sequential posterior power spectrum of reconstructed signal')
ax0.grid()
ax1.loglog(k_vals, np.ones_like(k_vals), color='black', ls='--', zorder=5)
for i in range(0, Nburnin, n_thin_PS):
color = cmap(norm(i))
ax1.loglog(k_vals, Pk[i]/Pk_signal_gt, color=color, alpha=0.6)
ax1.set_xlabel(r"$k$ [2$\pi$/L]")
ax1.set_ylabel("$P(k)/P_\\mathrm{truth}(k)$")
ax1.grid()
plt.savefig(dir+'power_spectrum_white_noise.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum_white_noise.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Reconstruct the delta field from the sampled white noise fields
def data_model_from_theta(theta):
A_s, n_s, f_NL, signal = unpack_theta(theta, N)
return data_model(signal, A_s, n_s, f_NL)
delta_samples_chain1 = jax.vmap(data_model_from_theta)(samples_chain[0])
# Compute the power spectrum of the delta field samples
Pk_delta = np.zeros((Nburnin, len(k_vals)))
for i in range(0, Nburnin, n_thin_PS):
k_vals, Pk_delta[i] = power_spectrum_2d(delta_samples_chain1[i])
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(8, 6), sharex=True, gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.})
ax0.loglog(k_vals, Pk_delta_gt, label='Groundtruth $\\delta$', color='black', ls='--', zorder=5)
cmap = plt.colormaps.get_cmap('winter')
norm = plt.Normalize(0, Nburnin-1)
Pk = np.zeros((Nburnin, len(k_vals)))
for i in range(0, Nburnin, n_thin_PS):
color = cmap(norm(i))
ax0.loglog(k_vals, Pk_delta[i], color=color, alpha=0.6)
# Colorbar inside the top panel
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # Only needed for showing the colorbar
cax = ax0.inset_axes([0.05, 0.25, 0.4, 0.03]) # [x0, y0, width, height]
cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
cbar.set_label("Sample index", labelpad=6, loc='center')
cax.xaxis.set_ticks_position('bottom') # Ticks at the bottom (default for horizontal)
ax0.legend(loc=[0.05, 0.40])
ax0.set_ylabel("$P_\\delta(k)$ [arbitrary units]")
ax0.set_title('Sequential posterior power spectrum of $\\delta$ fields')
ax0.grid()
ax1.loglog(k_vals, np.ones_like(k_vals), color='black', ls='--', zorder=5)
for i in range(0, Nburnin, n_thin_PS):
color = cmap(norm(i))
ax1.loglog(k_vals, Pk_delta[i]/Pk_delta_gt, color=color, alpha=0.6)
ax1.set_xlabel(r"$k$ [2$\pi$/L]")
ax1.set_ylabel("$P_\\delta(k)/P_{\\delta,\\mathrm{truth}}(k)$")
ax1.grid()
plt.savefig(dir+'power_spectrum_delta.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum_delta.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Based on python code from the emcee tutorials, https://emcee.readthedocs.io/en/stable/tutorials/autocorr/
def next_power_of_2(n: int) -> int:
"""Smallest power of two ≥ n."""
return 1 << (n - 1).bit_length() if n > 0 else 1
# 1D autocorrelation function
def autocorr_func_1d(x, norm=True):
"""
Compute the 1D autocorrelation via FFT in O(N log N).
If norm=True, normalize so acf[0] = 1.
"""
x = np.asarray(x, dtype=float)
n = x.size
nfft = 2 * next_power_of_2(n)
# real FFT
f = np.fft.rfft(x - np.mean(x), n=nfft)
ps = (f * f.conjugate()).real # power spectrum
acf = np.fft.irfft(ps, n=nfft)[:n]
acf /= 2 * nfft
# normalise
if norm:
if acf[0] <= 0:
return 0 # or raise ValueError("Autocorrelation function is zero or negative at lag 0.")
else:
acf /= acf[0]
return acf
# Automated windowing procedure following Sokal (1989)
def auto_window(taus, c):
"""
Return the first lag k for which k < c * tau_k fails.
If none fail, return len(taus)-1.
"""
k = np.arange(len(taus))
mask = k < c * taus
# find first index where mask is False
idx = np.argmax(~mask)
return idx if mask[idx] == False else len(taus) - 1
# Following the suggestion from Goodman & Weare (2010)
def autocorr_gw2010(x, c: float = 5.0) -> float:
"""
Estimate the integrated autocorrelation time following
Goodman & Weare (2010), with window parameter c.
"""
acf = autocorr_func_1d(x, norm=True)
taus = 2.0 * np.cumsum(acf) - 1.0
window = auto_window(taus, c)
return taus[window]
def N_eff(x) -> float:
"""
Effective number of independent samples in x.
Accepts input as either a list or a 1D numpy array.
"""
x = np.asarray(x)
tau = autocorr_gw2010(x)
if tau <= 0:
return 0 # or raise ValueError("Autocorrelation time is zero or negative.")
return x.size / tau
n_thin_corr = 200
n_steps = ((n_samples - 1) // n_thin_corr) + 1
# Preallocate arrays for Neff results
N_eff_10_20 = np.zeros((N_chains, n_steps))
N_eff_25_10 = np.zeros((N_chains, n_steps))
N_eff_A_s = np.zeros((N_chains, n_steps))
N_eff_n_s = np.zeros((N_chains, n_steps))
N_eff_f_NL = np.zeros((N_chains, n_steps))
# Vectorized extractors
field_10_20_vmap = jax.vmap(lambda theta: field_from_theta(theta)[10, 20])
field_25_10_vmap = jax.vmap(lambda theta: field_from_theta(theta)[25, 10])
A_s_vmap = jax.vmap(A_s_from_theta)
n_s_vmap = jax.vmap(n_s_from_theta)
f_NL_vmap = jax.vmap(f_NL_from_theta)
# For each chain, process all thinning windows
for c in range(N_chains):
cur_chain = samples_chain[c]
for i, samp in enumerate(np.arange(n_thin_corr, n_samples + 1, n_thin_corr)):
this_slice = cur_chain[:samp]
N_eff_10_20[c, i] = N_eff(field_10_20_vmap(this_slice))
N_eff_25_10[c, i] = N_eff(field_25_10_vmap(this_slice))
N_eff_A_s[c, i] = N_eff(A_s_vmap(this_slice))
N_eff_n_s[c, i] = N_eff(n_s_vmap(this_slice))
N_eff_f_NL[c, i] = N_eff(f_NL_vmap(this_slice))
fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(10,6), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel('Pixel (10,20)')
ax0.set_title("ESS for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_corr), N_eff_10_20[c], label=f'Chain {c+1}', marker='.')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_xlabel("Sample index")
ax1.set_ylabel('Pixel (25,10)')
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_corr), N_eff_25_10[c], label=f'Chain {c+1}', marker='.')
ax1.axvline(Nburnin,color='black',linestyle=':')
ax1.legend(loc='best')
plt.savefig(dir+'ESS_pixels.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'ESS_pixels.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, (ax0, ax1, ax2) = plt.subplots(3, 1, figsize=(10,9), sharex=True)
ax0.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax0.set_ylabel('$A_\\mathrm{s}$')
ax0.set_title("ESS for different chains")
for c in range(N_chains):
ax0.plot(np.arange(0, n_samples, n_thin_corr), N_eff_A_s[c], label=f'Chain {c+1}', marker='.')
ax0.axvline(Nburnin,color='black',linestyle=':')
ax1.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax1.set_xlim(0,n_samples)
ax1.set_ylabel('$n_\\mathrm{s}$')
for c in range(N_chains):
ax1.plot(np.arange(0, n_samples, n_thin_corr), N_eff_n_s[c], label=f'Chain {c+1}', marker='.')
ax1.axvline(Nburnin,color='black',linestyle=':')
ax2.set_prop_cycle(cycler('color', [plt.cm.Set2(i) for i in np.linspace(0, 1, 8)]))
ax2.set_xlim(0,n_samples)
ax2.set_ylabel('$f_\\mathrm{NL}$')
for c in range(N_chains):
ax2.plot(np.arange(0, n_samples, n_thin_corr), N_eff_f_NL[c], label=f'Chain {c+1}', marker='.')
ax2.axvline(Nburnin,color='black',linestyle=':')
ax2.set_xlabel("Sample index")
ax2.legend(loc='best')
plt.savefig(dir+'ESS_cosmology.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'ESS_cosmology.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
Gelman et al., "Bayesian Data Analysis" (third edition), p. 284-285
Parameters:
Definitions:
Estimators: Estimators of the marginal posterior variance of the estimand:
Test:
def gelman_rubin(chain):
# between chains variance
Psi_dotj = np.mean(chain, axis=1)
Psi_dotdot = np.mean(Psi_dotj, axis=0)
m = chain.shape[0]
n = chain.shape[1]
B = n / (m - 1.) * np.sum((Psi_dotj - Psi_dotdot)**2, axis=0)
# within chains variance
sj2 = np.var(chain, axis=1, ddof=1)
W = np.mean(sj2, axis=0)
# estimators
var_minus = W
var_plus = (n - 1.) / n * W + 1. / n * B
R_hat = np.sqrt(var_plus / var_minus)
return R_hat
# The input array must have dimensions (nchains, nsamp, npars) = (m, n, npars).
# The Gelman-Rubin function expects (n_chains, n_samples, n_variates)
Rhat = gelman_rubin(jnp.array([samples_chain[c] for c in range(N_chains)]))
# Run Gelman-Rubin
Rhat_A_s = Rhat[0]
Rhat_n_s = Rhat[1]
Rhat_fNL = Rhat[2]
Rhat_field_flat = Rhat[3:] # shape: (N*N,)
Rhat_field = Rhat_field_flat.reshape((N, N))
print("Gelman-Rubin stats:")
print("A_s:", Rhat_A_s)
print("n_s:", Rhat_n_s)
print("fNL:", Rhat_fNL)
Gelman-Rubin stats: A_s: 1.0006064 n_s: 1.0000048 fNL: 1.000764
# Visualize the Gelman-Rubin statistic
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(Rhat_field, vmin=1, vmax=Rhat_field.max(), origin='lower', cmap='bone_r')
ax.set_title('Gelman-Rubin statistic $\hat{R}$')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'gelman_rubin_statistic.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'gelman_rubin_statistic.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# Join the different chains into one array for the posterior samples
samples = jnp.concatenate([samples_chain[c][Nburnin:] for c in range(N_chains)], axis=0)
white_noise_samples = jax.vmap(field_from_theta)(samples)
A_s_samples = jax.vmap(A_s_from_theta)(samples)
n_s_samples = jax.vmap(n_s_from_theta)(samples)
f_NL_samples = jax.vmap(f_NL_from_theta)(samples)
PhiL = phiL_from_real_noise(white_noise, A_s, n_s)
PhiL_samples = jax.vmap(phiL_from_real_noise)(white_noise_samples, A_s_samples, n_s_samples)
PhiNL = PhiL + f_NL * PhiL**2
PhiNL_samples = PhiL_samples + f_NL * PhiL_samples**2
delta_samples = jax.vmap(data_model)(white_noise_samples, A_s_samples, n_s_samples, f_NL_samples)
@jax.jit
def compute_summaries(samples):
empirical_mean = jnp.mean(samples, axis=0)
empirical_var = jnp.var(samples, axis=0)
return empirical_mean, empirical_var
empirical_mean, empirical_var = compute_summaries(white_noise_samples)
PhiL_mean, PhiL_var = compute_summaries(PhiL_samples)
PhiNL_mean, PhiNL_var = compute_summaries(PhiNL_samples)
delta_mean, delta_var = compute_summaries(delta_samples)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# visualize the signal field
im0 = ax0.imshow(white_noise, vmin=-max(-white_noise.min(),white_noise.max()), vmax=max(-white_noise.min(),white_noise.max()), origin='lower', cmap=planck)
ax0.set_title('Signal')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the empirical mean of Wiener filter samples
im1 = ax1.imshow(empirical_mean, vmin=-max(-white_noise.min(),white_noise.max()), vmax=max(-white_noise.min(),white_noise.max()), origin='lower', cmap=planck)
ax1.set_title('Empirical mean of samples')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the empirical variance of Wiener filter samples
im2 = ax2.imshow(empirical_var,
norm=LogNorm(vmin=empirical_var.min(), vmax=empirical_var.max()), origin='lower', cmap="Greys")
ax2.set_title('Empirical variance of samples')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'signal_mean_variance_samples.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'signal_mean_variance_samples.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
fig, ((ax0a, ax1a, ax2a), (ax0b, ax1b, ax2b), (ax0c, ax1c, ax2c)) = plt.subplots(3, 3, figsize=(18, 16))
plt.subplots_adjust(wspace=0.25)
# visualize the groundtruth PhiL field
vmin = -max(-phi.min(), phi.max(), phiNL.min(), phiNL.max())
vmax = max(-phi.min(), phi.max(), phiNL.min(), phiNL.max())
im0a = ax0a.imshow(PhiL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0a.set_title('Groundtruth $\\Phi_\mathrm{L}$')
divider = make_axes_locatable(ax0a)
cax0a = divider.append_axes("right", size="5%", pad=0.1)
cbar0a = fig.colorbar(im0a, cax=cax0a)
# visualize the empirical mean of Wiener filter PhiL samples
im1a = ax1a.imshow(PhiL_mean, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1a.set_title('Empirical mean of $\\Phi_\mathrm{L}$ samples')
divider = make_axes_locatable(ax1a)
cax1a = divider.append_axes("right", size="5%", pad=0.1)
cbar1a = fig.colorbar(im1a, cax=cax1a)
# visualize the empirical variance of Wiener filter PhiL samples
im2a = ax2a.imshow(PhiL_var,
norm=LogNorm(vmin=PhiL_var.min(), vmax=PhiL_var.max()), origin='lower', cmap="Greys")
ax2a.set_title('Empirical variance of $\\Phi_\mathrm{L}$ samples')
divider = make_axes_locatable(ax2a)
cax2a = divider.append_axes("right", size="5%", pad=0.1)
cbar2a = fig.colorbar(im2a, cax=cax2a)
# visualize the groundtruth PhiNL field
im0b = ax0b.imshow(PhiNL, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax0b.set_title('Groundtruth $\\Phi_\\mathrm{NL}$')
divider = make_axes_locatable(ax0b)
cax0b = divider.append_axes("right", size="5%", pad=0.1)
cbar0b = fig.colorbar(im0b, cax=cax0b)
# visualize the empirical mean of Wiener filter PhiNL samples
im1b = ax1b.imshow(PhiNL_mean, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
ax1b.set_title('Empirical mean of $\\Phi_\\mathrm{NL}$ samples')
divider = make_axes_locatable(ax1b)
cax1b = divider.append_axes("right", size="5%", pad=0.1)
cbar1b = fig.colorbar(im1b, cax=cax1b)
# visualize the empirical variance of Wiener filter PhiNL samples
im2b = ax2b.imshow(PhiNL_var,
norm=LogNorm(vmin=PhiNL_var.min(), vmax=PhiNL_var.max()), origin='lower', cmap="Greys")
ax2b.set_title('Empirical variance of $\\Phi_\\mathrm{NL}$ samples')
divider = make_axes_locatable(ax2b)
cax2b = divider.append_axes("right", size="5%", pad=0.1)
cbar2b = fig.colorbar(im2b, cax=cax2b)
# visualize the groundtruth delta field
im0c = ax0c.imshow(delta, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax0c.set_title('Groundtruth $\\delta$')
divider = make_axes_locatable(ax0c)
cax0c = divider.append_axes("right", size="5%", pad=0.1)
cbar0c = fig.colorbar(im0c, cax=cax0c)
# visualize the empirical mean of Wiener filter delta samples
im1c = ax1c.imshow(delta_mean, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
ax1c.set_title('Empirical mean of $\\delta$ samples')
divider = make_axes_locatable(ax1c)
cax1c = divider.append_axes("right", size="5%", pad=0.1)
cbar1c = fig.colorbar(im1c, cax=cax1c)
# visualize the empirical variance of Wiener filter delta samples
im2c = ax2c.imshow(delta_var,
norm=LogNorm(vmin=delta_var.min(), vmax=delta_var.max()), origin='lower', cmap="Greys")
ax2c.set_title('Empirical variance of $\\delta$ samples')
divider = make_axes_locatable(ax2c)
cax2c = divider.append_axes("right", size="5%", pad=0.1)
cbar2c = fig.colorbar(im2c, cax=cax2c)
plt.savefig(dir+'PhiL_PhiNL_delta_mean_variance_samples.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'PhiL_PhiNL_delta_mean_variance_samples.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
def gaussian2d_grid(x, y, mu_x, mu_y, sigma_x, sigma_y, rho=0.0):
"""
Returns a normalized Gaussian density grid evaluated on (x, y).
x, y = meshgrid arrays.
"""
X = x - mu_x
Y = y - mu_y
z = (X**2 / sigma_x**2 + Y**2 / sigma_y**2 - 2*rho*X*Y/(sigma_x*sigma_y)) / (2*(1 - rho**2))
norm = 1.0 / (2*np.pi*sigma_x*sigma_y*np.sqrt(1 - rho**2))
return norm * np.exp(-z)
nBins = 50
A_s_grid = np.linspace(mu_A_s - 4*sigma_A_s, mu_A_s + 4*sigma_A_s, nBins)
n_s_grid = np.linspace(mu_n_s - 4*sigma_n_s, mu_n_s + 4*sigma_n_s, nBins)
f_NL_grid = np.linspace(mu_f_NL - 4*sigma_f_NL, mu_f_NL + 4*sigma_f_NL, nBins)
A_s_mesh, n_s_mesh = np.meshgrid(A_s_grid, n_s_grid, indexing='ij')
prior_density_A_s_n_s = gaussian2d_grid(A_s_mesh, n_s_mesh, mu_A_s, mu_n_s, sigma_A_s, sigma_n_s)
prior_chainLevels_A_s_n_s = get_contours(prior_density_A_s_n_s, nBins=nBins)
A_s_mesh, f_NL_mesh = np.meshgrid(A_s_grid, f_NL_grid, indexing='ij')
prior_density_A_s_f_NL = gaussian2d_grid(A_s_mesh, f_NL_grid, mu_A_s, mu_f_NL, sigma_A_s, sigma_f_NL)
prior_chainLevels_A_s_f_NL = get_contours(prior_density_A_s_f_NL, nBins=nBins)
n_s_mesh, f_NL_mesh = np.meshgrid(n_s_grid, f_NL_grid, indexing='ij')
prior_density_n_s_f_NL = gaussian2d_grid(n_s_mesh, f_NL_grid, mu_n_s, mu_f_NL, sigma_n_s, sigma_f_NL)
prior_chainLevels_n_s_f_NL = get_contours(prior_density_n_s_f_NL, nBins=nBins)
fig = plt.figure(figsize=(10,10))
plt.subplots_adjust(hspace=0.25, wspace=0.25)
n_thin_scatter = 5
ax0 = fig.add_subplot(2,2,1)
xbins, ybins, contours, chainLevels = get_contours_from_samples(A_s_samples, n_s_samples)
nContourLevels = len(chainLevels)
ax0.contourf(A_s_grid, n_s_grid, prior_density_A_s_n_s, levels=prior_chainLevels_A_s_n_s,
colors=colors[0][:nContourLevels][::-1])
ax0.contour(A_s_grid, n_s_grid, prior_density_A_s_n_s, levels=prior_chainLevels_A_s_n_s,
colors=colors[0][:nContourLevels][::-1])
ax0.scatter(A_s_samples[::n_thin_scatter], n_s_samples[::n_thin_scatter], color="black", s=2)
ax0.contourf(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1], alpha=0.5)
ax0.contour(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1])
ax0.axhline(n_s, color='black', ls='--', lw=1)
ax0.axvline(A_s, color='black', ls='--', lw=1)
ax0.set_xlabel("$A_\\mathrm{s}$")
ax0.set_ylabel("$n_\\mathrm{s}$")
prior_color = colors[0][1]
post_color = colors[2][1]
legend_elements = [
mlines.Line2D([], [], color=prior_color, linestyle='solid', linewidth=2, label='Prior'),
mlines.Line2D([], [], color=post_color, linestyle='solid', linewidth=2, label='Posterior'),
mlines.Line2D([], [], color="black", marker="o", linestyle='None', markersize=2, label='Samples'),
mlines.Line2D([], [], color="black", linestyle='--', linewidth=1, label='Groundtruth')
]
ax0.legend(handles=legend_elements, bbox_to_anchor=(1.12, 0.36), loc="upper left", borderaxespad=0)
ax1 = fig.add_subplot(2,2,3)
xbins, ybins, contours, chainLevels = get_contours_from_samples(A_s_samples, f_NL_samples)
nContourLevels = len(chainLevels)
ax1.contourf(A_s_grid, f_NL_grid, prior_density_A_s_f_NL, levels=prior_chainLevels_A_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax1.contour(A_s_grid, f_NL_grid, prior_density_A_s_f_NL, levels=prior_chainLevels_A_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax1.scatter(A_s_samples, f_NL_samples, color="black", s=2)
ax1.contourf(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1], alpha=0.5)
ax1.contour(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1])
ax1.axhline(f_NL, color='black', ls='--', lw=1)
ax1.axvline(A_s, color='black', ls='--', lw=1)
ax1.set_xlabel("$A_\\mathrm{s}$")
ax1.set_ylabel("$f_\\mathrm{NL}$")
ax2 = fig.add_subplot(2,2,4)
xbins, ybins, contours, chainLevels = get_contours_from_samples(n_s_samples, f_NL_samples)
nContourLevels = len(chainLevels)
ax2.contourf(n_s_grid, f_NL_grid, prior_density_n_s_f_NL, levels=prior_chainLevels_n_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax2.contour(n_s_grid, f_NL_grid, prior_density_n_s_f_NL, levels=prior_chainLevels_n_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax2.scatter(n_s_samples, f_NL_samples, color="black", s=2)
ax2.contourf(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1], alpha=0.5)
ax2.contour(xbins, ybins, contours, levels=chainLevels,
colors=colors[2][:nContourLevels][::-1])
ax2.axhline(f_NL, color='black', ls='--', lw=1)
ax2.axvline(n_s, color='black', ls='--', lw=1)
ax2.set_xlabel("$n_\\mathrm{s}$")
ax2.set_ylabel("$f_\\mathrm{NL}$")
plt.savefig(dir+'triangle_plot_cosmology.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'triangle_plot_cosmology.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
Generate an animation of samples (constrained realizations and cosmological parameters).
A_s_samples_chain1 = jax.vmap(A_s_from_theta)(samples_chain[0])
n_s_samples_chain1 = jax.vmap(n_s_from_theta)(samples_chain[0])
f_NL_samples_chain1 = jax.vmap(f_NL_from_theta)(samples_chain[0])
def create_cr_animation(A_s_samples_chain1, n_s_samples_chain1, f_NL_samples_chain1, delta_samples_chain1, fname=dir+"constrained_realizations.mp4", fps=5, n_thin_plot=100, frame_dir=dir+"frames/"):
import imageio.v2 as imageio
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Make output folder if doesn't exist
os.makedirs(frame_dir, exist_ok=True)
frame_files = []
thin_indices = np.arange(0, len(delta_samples_chain1), n_thin_plot)
num_frames = len(thin_indices)
for frame_num, frame in enumerate(thin_indices):
frame_file = os.path.join(frame_dir, f"frame_{frame_num:03d}.png")
if not os.path.exists(frame_file):
# Plot the frame
fig, ((ax0, axf), (ax1, ax2)) = plt.subplots(2, 2, figsize=(10, 10), dpi=320)
plt.subplots_adjust(left=0.09, right=0.93, top=0.95, bottom=0.09, wspace=0.30)
current_A_s = A_s_samples[frame]
current_n_s = n_s_samples[frame]
current_f_NL = f_NL_samples[frame]
realization = delta_samples[frame]
ax0.contourf(A_s_grid, n_s_grid, prior_density_A_s_n_s, levels=prior_chainLevels_A_s_n_s,
colors=colors[0][:nContourLevels][::-1])
ax0.contour(A_s_grid, n_s_grid, prior_density_A_s_n_s, levels=prior_chainLevels_A_s_n_s,
colors=colors[0][:nContourLevels][::-1])
ax0.scatter(A_s_samples[:frame: n_thin_plot], n_s_samples[:frame: n_thin_plot], color="black", s=2, zorder=4)
ax0.scatter([current_A_s], [current_n_s], color="C3", s=10, zorder=5)
ax0.axhline(n_s, color='black', ls='--', lw=1)
ax0.axvline(A_s, color='black', ls='--', lw=1)
ax0.set_xlabel("$A_\\mathrm{s}$")
ax0.set_ylabel("$n_\\mathrm{s}$")
# # visualize the constrained realization
im2 = axf.imshow(realization, vmin=-max(-delta.min(),delta.max()), vmax=max(-delta.min(),delta.max()), origin='lower', cmap=planck)
axf.set_title('Constrained realizations')
divider = make_axes_locatable(axf)
cax2 = divider.append_axes("right", size="5%", pad=0.05)
cbar2 = fig.colorbar(im2, cax=cax2, extend='both')
ax1.contourf(A_s_grid, f_NL_grid, prior_density_A_s_f_NL, levels=prior_chainLevels_A_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax1.contour(A_s_grid, f_NL_grid, prior_density_A_s_f_NL, levels=prior_chainLevels_A_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax1.scatter(A_s_samples[:frame: n_thin_plot], f_NL_samples[:frame: n_thin_plot], color="black", s=2, zorder=4)
ax1.scatter([current_A_s], [current_f_NL], color="C3", s=10, zorder=5)
ax1.axhline(f_NL, color='black', ls='--', lw=1)
ax1.axvline(A_s, color='black', ls='--', lw=1)
ax1.set_xlabel("$A_\\mathrm{s}$")
ax1.set_ylabel("$f_\\mathrm{NL}$")
ax2.contourf(n_s_grid, f_NL_grid, prior_density_n_s_f_NL, levels=prior_chainLevels_n_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax2.contour(n_s_grid, f_NL_grid, prior_density_n_s_f_NL, levels=prior_chainLevels_n_s_f_NL,
colors=colors[0][:nContourLevels][::-1])
ax2.scatter(n_s_samples[:frame: n_thin_plot], f_NL_samples[:frame: n_thin_plot], color="black", s=2, zorder=4)
ax2.scatter([current_n_s], [current_f_NL], color="C3", s=10, zorder=5)
ax2.axhline(f_NL, color='black', ls='--', lw=1)
ax2.axvline(n_s, color='black', ls='--', lw=1)
ax2.set_xlabel("$n_\\mathrm{s}$")
ax2.set_ylabel("$f_\\mathrm{NL}$")
# Save frame
fig.savefig(frame_file,dpi=320)
plt.close(fig)
plt.clf()
frame_files.append(frame_file)
# Write animated gif
with imageio.get_writer(fname, fps=fps) as writer:
for filename in frame_files:
image = imageio.imread(filename)
writer.append_data(image)
print(f"Animation saved as {fname}")
create_cr_animation(A_s_samples_chain1, n_s_samples_chain1, f_NL_samples_chain1, delta_samples_chain1, fps=5, n_thin_plot=10)
/home/leclercq/.local/apps/anaconda/python3.10_2023.03/install/lib/python3.10/subprocess.py:1780: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = _posixsubprocess.fork_exec(
Animation saved as ./plots/BHM_field_parameter_sampling/constrained_realizations.mp4
<Figure size 640x480 with 0 Axes>