Florent Leclercq,
Institut d'Astrophysique de Paris,
florent.leclercq@iap.fr
import os

import jax
import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg as la
from matplotlib.colors import LogNorm, SymLogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Global reproducibility and plotting configuration.
np.random.seed(123456)
plt.rcParams.update({
    'lines.linewidth': 2,
    'text.usetex': True,
    'text.latex.preamble': r"\usepackage{amsmath}\usepackage{upgreek}",
    'font.family': 'serif',
    'font.size': 15,
})
# Output directory for all figures (note: the name shadows the builtin `dir`,
# but it is used throughout the notebook, so it is kept).
dir="./plots/Wiener_filter_denoising/"
os.makedirs(dir, exist_ok=True)
# Download (once) and define the Planck color map
from matplotlib.colors import ListedColormap
import os
import urllib.request

PLANCK_CMAP_URL = "https://raw.githubusercontent.com/zonca/paperplots/master/data/Planck_Parchment_RGB.txt"
PLANCK_CMAP_FILE = os.path.join("data", "Planck_Parchment_RGB.txt")
if not os.path.isfile(PLANCK_CMAP_FILE):
    # Pure-Python replacement for the IPython-only `!wget` shell escape, so the
    # code also runs as a plain script; the target directory is created first.
    os.makedirs(os.path.dirname(PLANCK_CMAP_FILE), exist_ok=True)
    urllib.request.urlretrieve(PLANCK_CMAP_URL, PLANCK_CMAP_FILE)
# The /255. rescales the stored 0-255 RGB values to the [0, 1] range expected by ListedColormap.
planck = ListedColormap(np.loadtxt(PLANCK_CMAP_FILE)/255.)
planck.set_bad("C7") # color of missing pixels
N = 32
L = 1.0 # box size
A_s = 6e5 # power spectrum normalisation, arbitrary units
n_s = 0.96 # scalar spectral index
def P_of_k(k, A_s=A_s, n_s=n_s):
    """
    Power spectrum of a Gaussian random field (BBKS-style transfer function).

    Parameters
    ----------
    k : float or array_like
        Wavenumber(s). Must be strictly positive: k = 0 gives 0/0 in the
        transfer function and hence nan.
    A_s : float
        Amplitude of the power spectrum (arbitrary units).
    n_s : float
        Scalar spectral index.

    Returns
    -------
    P : float or ndarray
        A scalar if `k` is a scalar, otherwise an array with the shape of `k`.
    """
    # Record whether the caller passed a scalar *before* promoting k to an
    # array; checking afterwards would always see an array.
    scalar_input = np.ndim(k) == 0
    k = np.atleast_1d(k)
    # shape parameter, characterising the wavenumber at equality
    Omega_b = 0.049
    Omega_m = 0.315
    h = 0.674
    fb = Omega_b / Omega_m
    shape = Omega_m * h * np.exp(-Omega_b - np.sqrt(2.0*h) * fb)
    # power spectrum, BBKS style
    q = k / shape
    alpha = 2.34; beta = 3.89; gamma = 16.1; delta = 5.46; epsilon = 6.71
    aux = 1.0 + beta * q + (gamma * q)**2 + (delta * q)**3 + (epsilon * q)**4
    T = np.log(1.0 + alpha * q) / (alpha * q) * aux**-0.25
    P = A_s * q**n_s * T**2
    # if input was scalar, return scalar
    return P[0] if scalar_input else P
# Plot P(k) over four decades in wavenumber.
k_modes = np.logspace(-4, 0, 100)
# P_of_k is vectorised over k: evaluate all modes in one call rather than a Python loop.
Pk = P_of_k(k_modes)
plt.loglog(k_modes, Pk, color="C0")
plt.xlabel(r"$k$ [2$\pi$/L]")
plt.ylabel(r"$P(k)$ [arbitrary units]")
plt.grid()
plt.savefig(dir+'power_spectrum.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'power_spectrum.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
def signal_covariance_matrix(A_s, n_s, N=32, L=1.0):
    """
    Compute the real-space covariance matrix C_ij = <δ(x_i) δ(x_j)>
    for a 2D periodic GRF on an N×N grid of side L with power spectrum
    given by ``P_of_k``.

    The k = 0 mode is set to zero, so the uniform mode has zero variance and
    the returned matrix is singular. The normalisation is that of
    ``np.fft.irfft2`` (a 1/N² factor); no extra volume factor is applied.

    Parameters
    ----------
    A_s, n_s : float
        Power-spectrum amplitude and spectral index, forwarded to ``P_of_k``.
    N : int
        Number of pixels per side (even or odd).
    L : float
        Side length of the box.

    Returns
    -------
    S : ndarray, shape (N*N, N*N)
        Covariance matrix in pixel-index basis (row-major flattening: pixel
        (ix, iy) has index ix*N + iy).
    """
    # real-space grid spacing
    dx = L / N
    # build k-space grid for rfft2
    kx = np.fft.fftfreq(N, d=dx) * 2*np.pi # length N
    ky = np.fft.rfftfreq(N, d=dx) * 2*np.pi # length N//2+1
    KX, KY = np.meshgrid(kx, ky, indexing='ij') # shape (N, N//2+1)
    K = np.sqrt(KX**2 + KY**2)
    # build Pk safely (zero at K=0, where P_of_k would give nan)
    Pk = np.zeros_like(K)
    mask = (K > 0)
    Pk[mask] = P_of_k(K[mask], A_s=A_s, n_s=n_s)
    # compute the 2D covariance as inverse rfft:
    # cov_map[m,n] = (1/N²) ∑_k Pk e^{i k·(Δx,Δy)}
    # The output shape is passed explicitly: irfft2's default assumes an even
    # last axis and would return only N-1 columns for odd N.
    cov_map = np.fft.irfft2(Pk, s=(N, N)) # shape (N, N)
    # now build the full N²×N² matrix; flatten index → (ix, iy)
    idx = np.arange(N*N)
    ix = idx // N
    iy = idx % N
    # separation indices mod N (periodic boundary conditions)
    dx_idx = (ix[:,None] - ix[None,:]) % N # shape (N², N²)
    dy_idx = (iy[:,None] - iy[None,:]) % N # shape (N², N²)
    # lookup into cov_map: S depends only on the pixel separation
    S = cov_map[dx_idx, dy_idx] # shape (N², N²)
    return S
S = signal_covariance_matrix(A_s, n_s, N, L)
# Symmetric square root S^{1/2} via an eigendecomposition. S is symmetric
# positive semi-definite and exactly singular (its k=0 mode is zeroed), a case
# where the general-purpose la.sqrtm is unreliable. Symmetrising removes
# round-off asymmetry, and clipping the tiny negative eigenvalues produced by
# round-off keeps the square root real.
eigval_S, eigvec_S = np.linalg.eigh(0.5 * (S + S.T))
sqrtS = (eigvec_S * np.sqrt(np.clip(eigval_S, 0.0, None))) @ eigvec_S.T
# Show S on a symmetric-log colour scale (linear within ±linthresh), so that
# both the large diagonal and the small off-diagonal correlations are visible.
vmin, vmax = np.min(S), np.max(S)
S_absmax = max(-vmin, vmax)
linthresh, linscale = 1e-5, 1.0
norm = SymLogNorm(linthresh=linthresh, linscale=linscale,
                  vmin=-S_absmax, vmax=S_absmax, base=10)
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(S, cmap='coolwarm', norm=norm)
ax.set_title('Signal covariance matrix')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'signal_covariance_matrix.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'signal_covariance_matrix.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
def make_noise_field(N=32, bw=3, patch_center=(24, 24)):
    """
    Build an N×N array of per-pixel noise variances and the matching mask.

    Layout (row index i, column index j; in an ``imshow(..., origin='lower')``
    plot rows run upwards and columns to the right):

    - Border of width `bw` pixels everywhere: variance 1e0 (masked).
    - Central (N-2*bw)×(N-2*bw) block, split along the line j = i + 3:
        - medium noise (5e-5) where i - j < -3, i.e. the lower-right
          triangle in an origin='lower' plot;
        - low noise (1e-5) everywhere else in the block.
    - A 2×2 high-noise patch (1e0, masked) covering rows ci-1..ci and
      columns cj-1..cj, with (ci, cj) = `patch_center`. The default
      (24, 24) places it off-centre, in the upper-right part of the map.

    The defaults reproduce the 32×32 field with a 26×26 central block.

    Parameters
    ----------
    N : int
        Number of pixels per side.
    bw : int
        Border width in pixels.
    patch_center : tuple of int
        (ci, cj) defining the 2×2 high-noise patch.

    Returns
    -------
    field : ndarray, shape (N, N), float
        The noise-variance field.
    mask : ndarray, shape (N, N), bool
        True on the border and on the high-noise patch.
    """
    low = 1e-5
    med = 5e-5
    high = 1.0
    i, j = np.indices((N, N))
    interior = (i >= bw) & (i < N - bw) & (j >= bw) & (j < N - bw)
    # 1) high-noise border, 2) two-level central block split by the diagonal
    field = np.where(interior, np.where(i - j < -3, med, low), high)
    mask = ~interior
    # 3) 2×2 high-noise patch
    ci, cj = patch_center
    field[ci-1:ci+1, cj-1:cj+1] = high
    mask[ci-1:ci+1, cj-1:cj+1] = True
    return field, mask
noise_variance_field, mask = make_noise_field()
# Inverse noise covariance: diagonal, holding 1/variance per pixel. Entries with
# an inverse variance <= 1 (variance >= 1) are zeroed, i.e. given no weight.
inv_var_diag = 1.0 / noise_variance_field.ravel()
inv_var_diag[inv_var_diag <= 1] = 0.0
invN = np.diag(inv_var_diag)
# Show the noise variance field on a logarithmic colour scale.
variance_norm = LogNorm(vmin=noise_variance_field.min(), vmax=noise_variance_field.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(noise_variance_field, cmap='Reds', norm=variance_norm, origin='lower')
ax.set_title('Noise variance field')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax, format="%.0e")
plt.tight_layout()
plt.savefig(dir+'noise_variance_field.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'noise_variance_field.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
# Show N^{-1} on a logarithmic colour scale; LogNorm masks its non-positive
# (zero) entries, so only the non-zero diagonal is coloured.
invN_norm = LogNorm(vmin=1e-5, vmax=invN.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(invN, cmap='Reds', norm=invN_norm)
ax.set_title('Inverse of the noise covariance matrix')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax, format="%.0e")
plt.tight_layout()
plt.savefig(dir+'inverse_noise_covariance_matrix.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'inverse_noise_covariance_matrix.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
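Data model: $d = s + n$, where the signal $s$ is a Gaussian random field with covariance $S$ and the noise $n$ is Gaussian with diagonal covariance $N$ (the noise variance field above).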
# signal: one realisation, s = S^{1/2} ξ, with ξ a white-noise vector drawn with JAX
key = jax.random.PRNGKey(12)
white_noise = jax.random.normal(key, (N, N)).reshape(N*N)
# convert to NumPy explicitly so the product (and hence `signal`) is a NumPy array
signal = np.dot(sqrtS, np.asarray(white_noise)).reshape(N, N)
# Show the signal with a colour range symmetric about zero.
signal_absmax = max(-signal.min(), signal.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(signal, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax.set_title('Signal')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'signal.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'signal.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
# noise: one realisation, unit-variance Gaussian draws scaled by the per-pixel standard deviation
noise_std = np.sqrt(noise_variance_field)
noise = np.random.normal(size=(N, N)) * noise_std
# masked copy of the noise (border and patch hidden)
noise_v = np.ma.masked_where(mask, noise) # for visualization purposes only
# visualize the noise field on a symmetric-log scale (linear within ±linthresh)
vmin, vmax = np.min(noise_v), np.max(noise_v)
linthresh = 1e-5
linscale = 1.0
norm = SymLogNorm(linthresh=linthresh, linscale=linscale,
                  vmin=-max(-vmin,vmax), vmax=max(-vmin,vmax), base=10)
fig, ax = plt.subplots(figsize=(6,6))
# Work on a copy: on older Matplotlib versions, set_bad on the colormap returned
# by get_cmap modifies the globally registered 'PiYG' for every later plot.
cmap = plt.get_cmap('PiYG').copy()
cmap.set_bad('C7') # color of missing pixels
im = ax.imshow(noise_v, cmap=cmap, norm=norm, origin='lower')
ax.set_title('Noise')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'noise.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'noise.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
# data = signal + noise (full data, used for the inference below)
data = signal + noise
# the same with the masked noise
data_v = signal + noise_v # for visualization purposes only
# Show the data with a colour range symmetric about zero.
data_absmax = max(-data_v.min(), data_v.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(data_v, vmin=-data_absmax, vmax=data_absmax, origin='lower', cmap=planck)
ax.set_title('Data')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'data.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'data.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# visualize the signal field (colour range symmetric about zero)
signal_absmax = max(-signal.min(), signal.max())
im0 = ax0.imshow(signal, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax0.set_title('Signal')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# visualize the noise field (symmetric-log scale, linear within ±linthresh)
vmin, vmax = np.min(noise_v), np.max(noise_v)
linthresh = 1e-5
linscale = 1.0
norm = SymLogNorm(linthresh=linthresh, linscale=linscale,
                  vmin=-max(-vmin,vmax), vmax=max(-vmin,vmax), base=10)
# copy so that set_bad does not modify the globally registered colormap
cmap = plt.get_cmap('PiYG').copy()
cmap.set_bad('C7') # color of missing pixels
im1 = ax1.imshow(noise_v, cmap=cmap, norm=norm, origin='lower')
ax1.set_title('Noise')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# visualize the data field. The range must be symmetric about zero as in the
# standalone data plot; the minimum needs a minus sign, otherwise negative
# values are clipped.
data_absmax = max(-data_v.min(), data_v.max())
im2 = ax2.imshow(data_v, vmin=-data_absmax, vmax=data_absmax, origin='lower', cmap=planck)
ax2.set_title('Data')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'signal_noise_data.pdf',dpi=300,bbox_inches="tight")
plt.savefig(dir+'signal_noise_data.png',dpi=300,bbox_inches="tight",transparent=True)
plt.show()
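Covariance of the Wiener filter: \begin{equation} \mathrm{Cov}_\mathrm{WF} = (N^{-1}+S^{-1})^{-1} = S^{1/2}(I+S^{1/2}N^{-1}S^{1/2})^{-1}S^{1/2} \end{equation} The second form is the one used in practice. It requires neither $S^{-1}$ nor $N$, only $S^{1/2}$ and $N^{-1}$. This matters here because $S$ is singular (its $k=0$ mode is zero) and masked pixels have infinite noise, i.e. $N^{-1}=0$.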
# Cov_WF = S^{1/2} (I + S^{1/2} N^{-1} S^{1/2})^{-1} S^{1/2}
M = np.identity(N*N) + sqrtS.dot(invN).dot(sqrtS)
M = (M + M.T) / 2  # remove round-off asymmetry before the Cholesky-based solve
# M is identity plus a positive semi-definite matrix, hence symmetric positive
# definite: solve M X = S^{1/2} instead of forming inv(M) explicitly.
CovWF = sqrtS.dot(la.solve(M, sqrtS, assume_a='pos'))
CovWF = (CovWF + CovWF.T) / 2 # for numerical stability reasons
# Symmetric square root via an eigendecomposition. CovWF is positive
# semi-definite and singular (it inherits S's null k=0 mode), where la.sqrtm is
# unreliable; round-off negative eigenvalues are clipped to keep it real.
eigval_WF, eigvec_WF = np.linalg.eigh(CovWF)
sqrtCovWF = (eigvec_WF * np.sqrt(np.clip(eigval_WF, 0.0, None))) @ eigvec_WF.T
# Show Cov_WF on a linear colour scale symmetric about zero.
vmin, vmax = np.min(CovWF), np.max(CovWF)
cov_absmax = max(-vmin, vmax)
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(CovWF, cmap='BrBG', vmin=-cov_absmax, vmax=cov_absmax)
ax.set_title('Covariance matrix of the Wiener filter')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax, format="%.0e")
plt.tight_layout()
plt.savefig(dir+'covariance_matrix_wiener_filter.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'covariance_matrix_wiener_filter.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
Mean of the Wiener posterior: \begin{equation} s_\mathrm{WF} = \mathrm{Cov}_\mathrm{WF} N^{-1} d \end{equation}
# Posterior mean s_WF = Cov_WF N^{-1} d, computed on the flattened map and
# reshaped back to the N×N grid.
data_vec = data.reshape(N*N)
sWF = np.dot(np.dot(CovWF, invN), data_vec).reshape(N, N)
# absolute residual with respect to the true signal
residual = np.abs(signal - sWF)
residual_v = np.ma.masked_where(mask, residual) # for visualization purposes only
# Show the Wiener-filter posterior mean with a colour range symmetric about zero.
sWF_absmax = max(-sWF.min(), sWF.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(sWF, vmin=-sWF_absmax, vmax=sWF_absmax, origin='lower', cmap=planck)
ax.set_title('Signal reconstruction')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'signal_reconstruction.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'signal_reconstruction.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
# Show |signal - s_WF| on a log scale spanning the range of noise variances.
residual_norm = LogNorm(vmin=noise_variance_field.min(), vmax=noise_variance_field.max())
fig, ax = plt.subplots(figsize=(6,6))
im = ax.imshow(residual, norm=residual_norm, origin='lower', cmap="Greys")
ax.set_title('Residual')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
plt.tight_layout()
plt.savefig(dir+'residual.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'residual.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# Signal and reconstruction share the signal's symmetric colour range, so the
# two maps can be compared directly.
signal_absmax = max(-signal.min(), signal.max())
# left panel: true signal
im0 = ax0.imshow(signal, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax0.set_title('Signal')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# middle panel: Wiener-filter posterior mean
im1 = ax1.imshow(sWF, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax1.set_title('Signal reconstruction')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# right panel: absolute residual, log scale spanning the noise variances
residual_norm = LogNorm(vmin=noise_variance_field.min(), vmax=noise_variance_field.max())
im2 = ax2.imshow(residual, norm=residual_norm, origin='lower', cmap="Greys")
ax2.set_title('Residual')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'signal_reconstruction_residual.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'signal_reconstruction_residual.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
Samples of the Wiener posterior: \begin{equation} s = s_\mathrm{WF} + \mathrm{Cov}_\mathrm{WF}^{1/2} \, \xi, \qquad \xi \sim \mathcal{G}(0, I), \end{equation} so that $\left\langle s \right\rangle = s_\mathrm{WF}$ and $\mathrm{Cov}(s) = \mathrm{Cov}_\mathrm{WF}$.
def generate_constrained_realization(sWF, sqrtCovWF):
    """
    Draw one sample from the Wiener posterior, s = s_WF + Cov_WF^{1/2} ξ.

    ξ is a standard-normal vector drawn from NumPy's global random state, so
    results are reproducible via ``np.random.seed``.

    Parameters
    ----------
    sWF : ndarray
        Posterior mean map (e.g. N×N). Its shape sets the output shape.
    sqrtCovWF : ndarray, shape (sWF.size, sWF.size)
        Square root of the posterior covariance in the flattened (row-major)
        pixel basis.

    Returns
    -------
    ndarray with the same shape as `sWF`.
    """
    sWF = np.asarray(sWF)
    # Size and shape come from sWF rather than from the module-level grid size N.
    xi = np.random.normal(size=sWF.size)
    # np.real keeps the result real if a complex square root is passed in.
    return np.real(np.dot(sqrtCovWF, xi)).reshape(sWF.shape) + sWF
# Three independent constrained realisations, drawn in order from the global NumPy RNG.
cr1, cr2, cr3 = (generate_constrained_realization(sWF, sqrtCovWF) for _ in range(3))
fig, ((ax0, ax1, ax2), (ax3, ax4, ax5)) = plt.subplots(2, 3, figsize=(18, 12))
plt.subplots_adjust(hspace=0., wspace=0.25)
# One colour range, symmetric about zero, shared by all three realisations.
cr_absmax = max(max(-cr.min(), cr.max()) for cr in (cr1, cr2, cr3))
vmin, vmax = -cr_absmax, cr_absmax
# top row: the constrained realisations themselves
for ax, cr in ((ax0, cr1), (ax1, cr2), (ax2, cr3)):
    im = ax.imshow(cr, vmin=vmin, vmax=vmax, origin='lower', cmap=planck)
    ax.set_title('Constrained realization')
    divider = make_axes_locatable(ax)
    fig.colorbar(im, cax=divider.append_axes("right", size="5%", pad=0.1))
# bottom row: absolute residual with respect to the true signal (log scale)
for ax, cr in ((ax3, cr1), (ax4, cr2), (ax5, cr3)):
    im = ax.imshow(np.abs(signal - cr),
                   norm=LogNorm(vmin=noise_variance_field.min(), vmax=noise_variance_field.max()),
                   origin='lower', cmap="Greys")
    ax.set_title('Residual')
    divider = make_axes_locatable(ax)
    fig.colorbar(im, cax=divider.append_axes("right", size="5%", pad=0.1))
plt.savefig(dir+'constrained_realizations.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'constrained_realizations.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
def generate_constrained_realizations(sWF, sqrtCovWF, n):
    """Draw `n` independent posterior samples and stack them along a new first axis."""
    draws = []
    for _ in range(n):
        draws.append(generate_constrained_realization(sWF, sqrtCovWF))
    return np.array(draws)
# Monte-Carlo check of the sampler: per-pixel mean and variance over 1000 samples.
samples = generate_constrained_realizations(sWF, sqrtCovWF, n=1000)
empirical_mean = samples.mean(axis=0)
empirical_var = samples.var(axis=0)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(18, 6))
plt.subplots_adjust(wspace=0.25)
# Signal and empirical mean share the signal's symmetric colour range.
signal_absmax = max(-signal.min(), signal.max())
# left panel: true signal
im0 = ax0.imshow(signal, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax0.set_title('Signal')
divider = make_axes_locatable(ax0)
cax0 = divider.append_axes("right", size="5%", pad=0.1)
cbar0 = fig.colorbar(im0, cax=cax0)
# middle panel: empirical mean of the Wiener posterior samples
im1 = ax1.imshow(empirical_mean, vmin=-signal_absmax, vmax=signal_absmax, origin='lower', cmap=planck)
ax1.set_title('Empirical mean of samples')
divider = make_axes_locatable(ax1)
cax1 = divider.append_axes("right", size="5%", pad=0.1)
cbar1 = fig.colorbar(im1, cax=cax1)
# right panel: empirical per-pixel variance of the samples (log scale)
var_norm = LogNorm(vmin=empirical_var.min(), vmax=empirical_var.max())
im2 = ax2.imshow(empirical_var, norm=var_norm, origin='lower', cmap="Greys")
ax2.set_title('Empirical variance of samples')
divider = make_axes_locatable(ax2)
cax2 = divider.append_axes("right", size="5%", pad=0.1)
cbar2 = fig.colorbar(im2, cax=cax2)
plt.savefig(dir+'signal_mean_variance_samples.pdf', dpi=300, bbox_inches="tight")
plt.savefig(dir+'signal_mean_variance_samples.png', dpi=300, bbox_inches="tight", transparent=True)
plt.show()
Generate an animation of constrained realizations.
def create_cr_animation(samples, fname=dir+"constrained_realizations.mp4", fps=5, frame_dir=dir+"frames/", overwrite=False):
    """
    Render one PNG frame per constrained realization and assemble them into a movie.

    Each frame has three panels: the true signal, the masked data, and
    ``samples[i]``. Uses the module-level ``signal``, ``data_v`` and ``planck``.

    Parameters
    ----------
    samples : sequence of 2D arrays
        Constrained realizations, one per frame.
    fname : str
        Output movie path; imageio picks the writer from the file extension
        (``.mp4`` by default).
    fps : int
        Frames per second.
    frame_dir : str
        Directory where individual frames are written and cached.
    overwrite : bool
        If False (default), frames already present in `frame_dir` are reused
        as-is. Pass True after changing `samples`; otherwise the movie shows
        stale frames.
    """
    import imageio.v2 as imageio
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    # Make output folder if it doesn't exist
    os.makedirs(frame_dir, exist_ok=True)
    # Colour ranges are the same for every frame (symmetric about zero): compute them once.
    signal_range = max(-signal.min(), signal.max())
    data_range = max(-data_v.min(), data_v.max())
    frame_files = []
    for i, realization in enumerate(samples):
        frame_file = os.path.join(frame_dir, f"frame_{i:03d}.png")
        if overwrite or not os.path.exists(frame_file):
            fig, axs = plt.subplots(1, 3, figsize=(19, 6), dpi=320)
            plt.subplots_adjust(left=0.02, right=0.97, top=0.97, bottom=0.03, wspace=0.30)
            # visualize the signal field
            im0 = axs[0].imshow(signal, vmin=-signal_range, vmax=signal_range, origin='lower', cmap=planck)
            axs[0].set_title('Signal')
            divider = make_axes_locatable(axs[0])
            cax0 = divider.append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im0, cax=cax0)
            # visualize the data field
            im1 = axs[1].imshow(data_v, vmin=-data_range, vmax=data_range, origin='lower', cmap=planck)
            axs[1].set_title('Data')
            divider = make_axes_locatable(axs[1])
            cax1 = divider.append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im1, cax=cax1)
            # visualize the constrained realization (same colour range as the signal)
            im2 = axs[2].imshow(realization, vmin=-signal_range, vmax=signal_range, origin='lower', cmap=planck)
            axs[2].set_title('Constrained realizations')
            divider = make_axes_locatable(axs[2])
            cax2 = divider.append_axes("right", size="5%", pad=0.05)
            fig.colorbar(im2, cax=cax2, extend='both')
            # Save frame and free the figure
            fig.savefig(frame_file,dpi=320)
            plt.close(fig)
        frame_files.append(frame_file)
    # Write the movie, frames in sample order
    with imageio.get_writer(fname, fps=fps) as writer:
        for filename in frame_files:
            writer.append_data(imageio.imread(filename))
    print(f"Animation saved as {fname}")
# Render (or reuse cached) frames for all posterior samples and write the movie at 5 frames per second.
create_cr_animation(samples, fps=5)
/home/leclercq/.local/apps/anaconda/python3.10_2023.03/install/lib/python3.10/subprocess.py:1780: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. self.pid = _posixsubprocess.fork_exec(
Animation saved as ./plots/Wiener_filter_denoising/constrained_realizations.mp4