This example demonstrates the use of the ADMM Plug and Play Priors (PPP) algorithm, with the BM3D denoiser, for solving a raw image demosaicing problem.
# This scico project Jupyter notebook has been automatically modified
# to install the dependencies required for running it on Google Colab.
# If you encounter any problems in running it, please open an issue at
# https://github.com/lanl/scico-data/issues
!pip install 'scico[examples] @ git+https://github.com/lanl/scico'
import numpy as np
from bm3d import bm3d_rgb
from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007
import scico
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
from scico.data import kodim23
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
# Configure scico's plotting module for inline display in the notebook.
plot.config_notebook_plotting()
Read a ground truth image.
# Ground truth: a 256 x 256 crop (rows 160-415, columns 60-315) of the kodim23
# test image, loaded as floating point and converted to a scico array.
img = snp.array(kodim23(asfloat=True)[160:416, 60:316])
Define demosaicing forward operator and its transpose.
def Afn(x):
    """Sample an RGB image through a BGGR Bayer colour filter array.

    Every output pixel keeps exactly one colour channel of the matching
    input pixel, chosen by the parity of its (row, column) position:
    (even, even) -> channel 2, (even, odd) and (odd, even) -> channel 1,
    (odd, odd) -> channel 0.

    Args:
        x: RGB image indexed as ``x[row, col, channel]``.

    Returns:
        Single channel raw image of shape ``x.shape[0:2]``.
    """
    # (row offset, column offset, source channel) for each cell of the 2x2 CFA tile,
    # applied in the same order as the original element-by-element assignments.
    cfa_layout = ((1, 1, 0), (0, 1, 1), (1, 0, 1), (0, 0, 2))
    raw = snp.zeros(x.shape[0:2])
    for row0, col0, chan in cfa_layout:
        raw = raw.at[row0::2, col0::2].set(x[row0::2, col0::2, chan])
    return raw
def ATfn(x):
    """Adjoint of ``Afn``: scatter raw CFA samples back into an RGB image.

    Each raw pixel is written into the single channel that the BGGR colour
    filter array samples at that (row, column) parity. The other two
    channels at that location are left as zero.

    Args:
        x: Single channel raw image indexed as ``x[row, col]``.

    Returns:
        RGB image of shape ``x.shape + (3,)``.
    """
    rgb = snp.zeros(x.shape + (3,))
    # (row offset, column offset, target channel) for each cell of the 2x2 CFA tile.
    for row0, col0, chan in ((1, 1, 0), (0, 1, 1), (1, 0, 1), (0, 0, 2)):
        rgb = rgb.at[row0::2, col0::2, chan].set(x[row0::2, col0::2])
    return rgb
Define a baseline demosaicing function based on the demosaicing algorithm of Menon et al. (2007) from the package colour_demosaicing.
def demosaic(cfaimg):
    """Apply baseline demosaicing.

    Uses the Menon (2007) method from ``colour_demosaicing``. The BGGR
    pattern matches the layout sampled by ``Afn``.

    Args:
        cfaimg: Single channel raw (colour filter array) image.

    Returns:
        Demosaiced RGB image cast to ``np.float32``.
    """
    rgb = demosaicing_CFA_Bayer_Menon2007(cfaimg, pattern="BGGR")
    return rgb.astype(np.float32)
Create a test image by color filter array sampling and adding Gaussian white noise.
σ = 2e-2  # noise standard deviation
s = Afn(img)  # noise-free single channel CFA image
rgbshp = (*s.shape, 3)  # shape of reconstructed RGB image
# Gaussian white noise drawn from a fixed seed so that the example is reproducible.
noise, key = scico.random.randn(s.shape, seed=0)
sn = s + σ * noise  # noisy raw measurement
Compute a baseline demosaicing solution.
# Baseline solution: demosaic the noisy raw image, then denoise the RGB result
# with BM3D at noise level 3σ, and convert it to a float32 scico array. It is
# also used as the ADMM initializer.
imgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32))
Set up an ADMM solver object. Note the use of the baseline solution as an initializer. We use BM3D as the denoiser, via the code from the `bm3d` Python package.
# Forward operator: CFA sampling of an RGB image, with ATfn supplied as its adjoint.
A = linop.LinearOperator(input_shape=rgbshp, output_shape=s.shape, eval_fn=Afn, adj_fn=ATfn)
# Data fidelity term: squared ℓ2 loss between the noisy raw image sn and A(x).
f = loss.SquaredL2Loss(y=sn, A=A)
# The BM3D regularizer acts directly on the RGB image, so the splitting operator is the identity.
C = linop.Identity(input_shape=rgbshp)
ρ = 1.8e-1  # ADMM penalty parameter
# The BM3D prefactor is written as ρ times 6.1e-2 instead of repeating ρ as a
# literal (the value is unchanged). Presumably ADMM evaluates the prox of g with
# step 1/ρ, so ρ cancels and the effective BM3D noise level stays at 6.1e-2 when
# ρ is retuned. NOTE(review): confirm against the scico ADMM z-step.
g = ρ * 6.1e-2 * functional.BM3D(is_rgb=True)
maxiter = 12  # number of ADMM iterations
solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=imgb,  # initialize from the baseline demosaiced and denoised image
    maxiter=maxiter,
    # The x-update is a linear system, solved approximately by conjugate gradient.
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)
Run the solver.
# Report the compute device, run ADMM for maxiter iterations, then collect the
# per-iteration statistics (transposed so each statistic is a sequence over iterations).
print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  7.66e+00  5.788e+00  2.298e+00      1  1.971e-09
   1  1.39e+01  4.773e+00  8.708e-01      2  2.083e-09
   2  2.02e+01  3.597e+00  1.122e+00      2  9.364e-10
   3  2.59e+01  2.751e+00  1.480e+00      2  1.387e-09
   4  3.24e+01  2.231e+00  1.549e+00      2  1.406e-09
   5  3.86e+01  1.961e+00  1.283e+00      2  5.103e-10
   6  4.42e+01  1.758e+00  9.265e-01      2  6.186e-10
   7  5.01e+01  1.400e+00  4.759e-01      1  8.766e-04
   8  5.63e+01  1.216e+00  7.355e-01      2  3.677e-10
   9  6.24e+01  9.984e-01  7.076e-01      2  2.950e-10
  10  6.80e+01  8.883e-01  6.705e-01      2  3.903e-10
  11  7.34e+01  6.458e-01  4.391e-01      1  9.664e-04
Show reference and demosaiced images.
# Side-by-side comparison: reference, baseline and PPP reconstructions. Each
# reconstruction's panel title reports its PSNR (dB) against the reference image.
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(img, title="Reference", fig=fig, ax=ax[0])
plot.imview(imgb, title=f"Baseline demosaic: {metric.psnr(img, imgb):.2f} (dB)", fig=fig, ax=ax[1])
plot.imview(x, title=f"PPP demosaic: {metric.psnr(img, x):.2f} (dB)", fig=fig, ax=ax[2])
fig.show()
Plot convergence statistics.
# Primal and dual residuals per ADMM iteration on a log scale: stacking the two
# histories along axis 1 gives one column per residual type.
plot.plot(
    snp.stack((hist.Prml_Rsdl, hist.Dual_Rsdl), axis=1),
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    ptyp="semilogy",
)