This example demonstrates the solution of a tomographic reconstruction problem using the Plug-and-Play Priors framework, with BM3D as the denoiser and SVMBIR for tomographic projection.
There are two versions of this example, solving the same problem in two
different ways. This version uses the data fidelity term as one of the
ADMM $g$ functionals, so that the optimization with respect to the data
fidelity can exploit the internal prox of the SVMBIRExtendedLoss and
SVMBIRSquaredL2Loss functionals. The other version solves the ADMM
subproblem corresponding to the data fidelity term via CG.
Two ways of exploiting the SVMBIR internal prox are explored in this example:
SVMBIRSquaredL2Loss together with the BM3D pseudo-functional and a
non-negative indicator function, and SVMBIRExtendedLoss, which includes a
non-negativity constraint, together with the BM3D pseudo-functional.
import numpy as np
import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
from xdesign import Foam, discrete_phantom
import scico.numpy as snp
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
from scico.linop.xray.svmbir import (
SVMBIRExtendedLoss,
SVMBIRSquaredL2Loss,
XRayTransform,
)
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
# Apply scico.plot's notebook plotting configuration (presumably sets up
# inline figure display when run as a notebook — confirm in scico.plot docs).
plot.config_notebook_plotting()
Generate a ground truth image.
N = 256  # image size
density = 0.025  # attenuation density of the image

# Seed NumPy's global RNG so the phantom and the Poisson noise drawn later
# via np.random.poisson are reproducible.
np.random.seed(1234)

# Rasterize a foam phantom 10 pixels smaller than the target size, rescale its
# peak to `density`, and zero out any negative values.
foam_phantom = Foam(size_range=[0.05, 0.02], gap=0.02, porosity=0.3)
x_gt = discrete_phantom(foam_phantom, size=N - 10)
x_gt = x_gt / np.max(x_gt) * density
x_gt[x_gt < 0] = 0

# A 5-pixel zero border on every side restores the full N x N size (assuming
# discrete_phantom returns a square size x size array). Padding after the
# clip is equivalent, since the padded zeros are never negative.
x_gt = np.pad(x_gt, 5)
Generate tomographic projector and sinogram.
# Half as many projection views as image rows, one detector channel per column.
num_angles = N // 2
num_channels = N

# Views are uniformly spaced over [0, pi); the endpoint pi is excluded.
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)

# SVMBIR-backed projector and the noise-free sinogram of the phantom.
A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
Impose Poisson noise on the sinogram. A higher max_intensity means less noise.
max_intensity = 2000  # incident count level; larger means less relative noise

# Expected detector counts after attenuation by the sinogram line integrals.
expected_counts = max_intensity * np.exp(-sino)

# Draw Poisson counts. Zero counts are raised to 1 so the logarithm below stays
# finite; the draws are non-negative integers, so max(count, 1) only changes zeros.
noisy_counts = np.maximum(np.random.poisson(expected_counts), 1).astype(np.float32)

# Convert the noisy counts back to (noisy) line-integral data.
y = -np.log(noisy_counts / max_intensity)
Reconstruct using the default (MRF) prior of SVMBIR.
weights = svmbir.calc_weights(y, weight_type="transmission")

# A singleton axis is inserted at position 1 of both the data and the weights
# (presumably the slice axis of svmbir's 3D sinogram layout — confirm). The
# first (only) slice of the returned reconstruction is kept.
x_mrf = svmbir.recon(
    np.array(np.expand_dims(y, 1)),
    np.array(angles),
    weights=np.expand_dims(weights, 1),
    num_rows=N,
    num_cols=N,
    positivity=True,
    verbose=0,
)[0]
Convert numpy arrays to jax arrays.
# Move the data, the MRF reconstruction (used as the ADMM starting point), and
# the weights into JAX arrays; x_mrf itself stays a NumPy array.
y, x0, weights = (snp.array(arr) for arr in (y, x_mrf, weights))
Set problem parameters and BM3D pseudo-functional.
# ADMM penalty parameter, shared by every g term below.
ρ = 10
# BM3D denoiser noise level, set relative to the phantom's peak attenuation.
σ = density * 0.26

# BM3D pseudo-functional scaled by σ·ρ (presumably so that its prox with step
# 1/ρ inside ADMM denoises at level σ — confirm against scico's conventions).
g0 = σ * ρ * BM3D()
Set up the problem using SVMBIRSquaredL2Loss and NonNegativeIndicator.
# Weighted squared-L2 data fidelity, whose prox is computed internally by SVMBIR
# (limited to 5 iterations with ctol=0.0).
f_l2loss = SVMBIRSquaredL2Loss(
    y=y,
    A=A,
    W=Diagonal(weights),
    scale=0.5,
    prox_kwargs={"maxiter": 5, "ctol": 0.0},
)
# Separate indicator term enforcing non-negativity.
g1 = NonNegativeIndicator()

# There is no f term: data fidelity, BM3D prior and non-negativity all enter as
# g_i terms with C_i = I (one distinct Identity each) and a common penalty ρ.
solver_l2loss = ADMM(
    f=None,
    g_list=[f_l2loss, g0, g1],
    C_list=[Identity(x_mrf.shape) for _ in range(3)],
    rho_list=[ρ] * 3,
    x0=x0,
    maxiter=20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True, "period": 5},
)
Run the ADMM solver.
# Report the compute device, run ADMM, and keep the iteration statistics.
# transpose=True yields per-statistic series (e.g. hist_l2loss.Prml_Rsdl, as
# used in the convergence plots).
print(f"Solving on {device_info()}\n")
x_l2loss = solver_l2loss.solve()
hist_l2loss = solver_l2loss.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  3.67e+00  5.375e-01  5.375e-01      0  0.000e+00
   1  7.05e+00  3.687e-01  3.230e-01      1  3.913e-09
   2  1.07e+01  2.416e-01  2.150e-01      1  2.928e-09
   3  1.32e+01  1.573e-01  1.827e-01      1  2.418e-09
   4  1.66e+01  1.189e-01  1.367e-01      1  1.673e-09
   5  2.02e+01  1.015e-01  9.357e-02      1  2.187e-09
   6  2.40e+01  8.631e-02  6.378e-02      1  6.341e-10
   7  2.77e+01  6.855e-02  5.047e-02      1  5.304e-10
   8  3.09e+01  5.318e-02  4.421e-02      1  5.568e-10
   9  3.40e+01  4.378e-02  3.729e-02      1  3.857e-10
  10  3.75e+01  3.905e-02  2.961e-02      1  3.466e-10
  11  4.10e+01  3.554e-02  2.390e-02      1  2.640e-10
  12  4.45e+01  3.225e-02  2.117e-02      1  4.761e-10
  13  4.84e+01  2.893e-02  1.961e-02      1  1.834e-10
  14  5.15e+01  2.672e-02  1.795e-02      1  1.972e-10
  15  5.53e+01  2.457e-02  1.615e-02      1  3.801e-10
  16  5.83e+01  2.318e-02  1.501e-02      1  1.520e-10
  17  6.19e+01  2.172e-02  1.407e-02      1  1.273e-10
  18  6.52e+01  2.067e-02  1.363e-02      1  1.397e-10
  19  6.89e+01  1.981e-02  1.360e-02      1  1.330e-10
Set up the problem using SVMBIRExtendedLoss, without the need for NonNegativeIndicator.
# Extended loss: same weighted data fidelity as above, but positivity=True folds
# the non-negativity constraint into SVMBIR's internal prox, so no separate
# indicator term is needed.
f_extloss = SVMBIRExtendedLoss(
    y=y,
    A=A,
    W=Diagonal(weights),
    scale=0.5,
    positivity=True,
    prox_kwargs={"maxiter": 5, "ctol": 0.0},
)

# Two g_i terms (data fidelity + BM3D), each with its own Identity C_i and
# the shared penalty ρ.
solver_extloss = ADMM(
    f=None,
    g_list=[f_extloss, g0],
    C_list=[Identity(x_mrf.shape) for _ in range(2)],
    rho_list=[ρ] * 2,
    x0=x0,
    maxiter=20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True, "period": 5},
)
Run the ADMM solver.
# Blank line to separate this solver's iteration log from the previous one.
print()
# Run ADMM and keep per-statistic iteration histories (transpose=True).
x_extloss = solver_extloss.solve()
hist_extloss = solver_extloss.itstat_object.history(transpose=True)
Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  3.01e+00  4.720e-01  4.720e-01      0  0.000e+00
   1  6.68e+00  3.471e-01  2.464e-01      1  5.184e-09
   2  9.61e+00  1.689e-01  2.116e-01      1  9.142e-10
   3  1.36e+01  1.023e-01  1.473e-01      1  4.484e-10
   4  1.73e+01  8.577e-02  8.309e-02      1  2.231e-10
   5  2.10e+01  6.649e-02  4.819e-02      1  1.205e-10
   6  2.46e+01  4.644e-02  3.784e-02      1  6.679e-10
   7  2.74e+01  3.438e-02  3.069e-02      1  8.120e-11
   8  3.09e+01  2.915e-02  2.295e-02      1  3.865e-10
   9  3.37e+01  2.598e-02  1.785e-02      1  4.350e-11
  10  3.71e+01  2.280e-02  1.480e-02      1  2.567e-10
  11  4.03e+01  2.039e-02  1.298e-02      1  2.823e-11
  12  4.36e+01  1.890e-02  1.154e-02      1  1.974e-10
  13  4.71e+01  1.743e-02  1.021e-02      1  1.803e-10
  14  5.01e+01  1.644e-02  9.463e-03      1  1.599e-10
  15  5.33e+01  1.556e-02  8.859e-03      1  1.540e-10
  16  5.61e+01  1.488e-02  8.616e-03      1  1.457e-10
  17  5.93e+01  1.425e-02  8.511e-03      1  1.452e-10
  18  6.21e+01  1.373e-02  8.126e-03      1  1.990e-11
  19  6.53e+01  1.316e-02  7.722e-03      1  1.384e-10
Show the recovered images.
# All four panels share one colour scale so they can be compared directly.
norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)
fig, ax = plt.subplots(2, 2, figsize=(15, 15))

# (image, title) per panel, in row-major subplot order to match ax.flat.
# PSNR values are measured against the ground truth x_gt.
for axis, (img, title) in zip(
    ax.flat,
    (
        (x_gt, "Ground Truth Image"),
        (x_mrf, f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)"),
        (
            x_l2loss,
            f"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)",
        ),
        (x_extloss, f"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)"),
    ),
):
    plot.imview(img=img, title=title, cbar=True, fig=fig, ax=axis, norm=norm)

fig.show()
Plot convergence statistics.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# One panel per solver: primal and dual residuals on a log scale, with a common
# y-range and integer-only iteration ticks.
for axis, hist, title in zip(
    ax,
    (hist_l2loss, hist_extloss),
    ("Residuals (SquaredL2Loss + non-negativity)", "Residuals (ExtendedLoss)"),
):
    plot.plot(
        snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
        ptyp="semilogy",
        title=title,
        xlbl="Iteration",
        lgnd=("Primal", "Dual"),
        fig=fig,
        ax=axis,
    )
    axis.set_ylim([5e-3, 5e0])
    axis.xaxis.set_major_locator(MaxNLocator(integer=True))

fig.show()