Plug-and-play (PnP) is an optimization framework that integrates modern denoising priors into gradient or proximal descent schemes.
The goal of this notebook is to use PnP methods to solve image restoration problems (interpolation of missing pixels, denoising, or deblurring for instance). Typically, these inverse problems can be written $$y = Mx+\eta$$ where $x$ is an unknown image that we wish to reconstruct, $y$ is the observed image, $M$ is a linear degradation operator and $\eta$ is a realization of a known noise distribution (for instance Gaussian). The discrete images $x$ and $y$ can be seen as vectors of size $n$ and $m$ (for instance by reading them columnwise), and $M$ is an $m\times n$ matrix. For instance, for an inverse problem where pixel values are missing, $M$ is a diagonal matrix with $0$ (hidden pixels) and $1$ (known pixels) on the diagonal.
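For intuition, such a masking operator $M$ never needs to be stored as an explicit matrix; it can be applied pixelwise. A minimal sketch of the degradation model (the array names are ours, not from the notebook):

import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 4))                    # unknown image
mask = rng.random((4, 4)) > 0.5           # diagonal of M: 1 = known pixel, 0 = hidden pixel
eta = 0.01 * rng.standard_normal((4, 4))  # realization of the Gaussian noise
y = mask * x + eta                        # y = Mx + eta, applied without forming M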
To estimate $x$, when the noise is Gaussian, a classical approach consists in minimizing an energy of the form
$$ \frac{1}{2\sigma^2}\|Mx - y \|^2 + U(x)$$
where the second term is a regularization term on $x$.
In a Bayesian framework, this energy can be interpreted as
$$-\log p(x|y) = -\log p(y|x)-\log p(x) + C,$$ where $C$ does not depend on $x$, $p(y|x) \propto e^{-\frac{1}{2\sigma^2}\|Mx - y \|^2}$ is the likelihood of $y$ knowing $x$ and a prior distribution $p(x) \propto e^{-U(x)}$ is assumed on the unknown $x$.
Typical optimization schemes used to minimize the previous expression involve the gradient or proximal operator of $U$. The principle of PnP approaches is to derive an approximation of $\nabla U$ or $\mathrm{prox}_U$ using an image denoiser, for instance a denoising deep neural network. This approximation can then be used with any scheme using gradient or proximal descent on $U$ for optimization or sampling.
This practical session explains and shows inpainting and deblurring experiments based on the code of the following paper:
References:
- E. K. Ryu, J. Liu, S. Wang, X. Chen, Z. Wang, W. Yin, Plug-and-Play Methods Provably Converge with Properly Trained Denoisers, ICML 2019.
Authors of the notebook:
Below is a list of packages needed. The PyTorch version used to run this notebook is 1.11.0+cu113 (to check the installed version, use torch.__version__).
- numpy
- matplotlib.pyplot (display of images and graphics)
- torch (use cuda with PyTorch)
- time (measure time)
- os (interactions with the operating system)

import numpy as np
import torch
import time
import matplotlib.pyplot as plt
import os
%matplotlib inline
To import the solutions, execute the following cell. If you are using a Windows system, comment out the wget
line, download the file by hand, and place it in the same folder as the notebook.
#!wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/Restoration/Solutions/PnP.py
#from PnP import *
Next, we need to choose which device to run the algorithm on. Running the algorithm on large images takes longer and goes much faster on a GPU. We can use torch.cuda.is_available()
to detect whether a GPU is available, and then set the torch.device
accordingly. The .to(device)
method moves tensors or modules to the desired device; we will use it in the next sections.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is", device)
Device is cuda
In this notebook, for the sake of simplicity (and computing time!), we will use a pre-trained denoising network. The load_denoiser
function loads the available model:
In AuxiliarFunctions_PnP
we define the network, taken from the original code.
#@title
!wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/Restoration/AuxiliarFunctions/AuxiliarFunctions_PnP.py
!wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/Restoration/AuxiliarFunctions/network_unet.py
!wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/Restoration/AuxiliarFunctions/basicblock.py
from AuxiliarFunctions_PnP import *
model_type = "DnCNN"
model_name = "RealSN_DnCNN_noise5.pth"
model = load_denoiser(model_type, model_name, device)
In the following, we will compare several algorithms, in terms of computing time and performance. We will use the PSNR (defined below) to measure restoration performance. The higher the value, the better the result.
def PSNR(image_u, image_denoised, peak=1):
return 10*np.log10(peak**2/np.mean((image_u-image_denoised)**2))
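As a quick sanity check of the formula (illustrative, not part of the original notebook): for Gaussian noise of standard deviation 5/255 on a [0,1] image, the PSNR should be close to $10\log_{10}(1/(5/255)^2) \approx 34$ dB.

u_test = np.zeros((64, 64))
u_noisy = u_test + (5/255)*np.random.randn(64, 64)
print(PSNR(u_test, u_noisy))  # about 34.15 dB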
In the next cell we will define some helper functions for displaying and creating degraded images. We will also define the Denoiser function to be used by all PnP implementations.
#@title
# Function to display two images
def printImages(a, b, title_a, title_b, size1, size2):
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(size1, size2))
axes[0].set_title(title_a)
axes[0].imshow(a,cmap='gray',vmin=0,vmax=1)
axes[1].set_title(title_b)
axes[1].imshow(b,cmap='gray',vmin=0,vmax=1)
fig.tight_layout()
plt.show()
# Display time, PSNR and images
def displayResults(image_denoised, v, title_a, title_b, size1, size2, elapsed, original):
    print("time: ", elapsed)
    print("PSNR: ", PSNR(original, image_denoised))
    printImages(image_denoised, v, title_a, title_b, size1, size2)
# Create blurred image with periodic kernel
def CreateBlurredImage(u_orig, sigma,s):
n_Rows, n_Cols = u_orig.shape
b = sigma*np.random.randn(n_Rows,n_Cols)
    # Definition of the uniform blurring kernel h
k_uniform = np.zeros((n_Rows,n_Cols))
k_uniform[0:2*s+1,0:2*s+1] = np.ones((2*s+1, 2*s+1))/(2*s+1)**2
image = np.real(np.fft.ifft2(np.fft.fft2(u_orig)*np.fft.fft2(k_uniform))) + b
return image, k_uniform
# Create image with a proportion p of missing pixels and additive Gaussian noise
def CreateImageWithMissingPixels(u_orig, sigma, p):
nrow,ncol = u_orig.shape
# Create noise
b = sigma*np.random.randn(nrow,ncol)
# Create mask
mask = np.random.rand(nrow,ncol)>p
image = mask*u_orig + b
return image, mask
As explained above, we wish to solve $$ \arg\min_x f(x) + \alpha U(x)$$
where $f(x) = \frac{1}{2\sigma^2}\|Mx - y\|^2$ is the data-fidelity term, $U$ is the regularizer, and $\alpha > 0$ balances the two terms.
For convenience we can rewrite the minimization problem as: $$\arg\min_x \frac{1}{\alpha} f(x) + U(x).$$
Recall the alternating direction method of multipliers (ADMM) iterations used to solve the previous problem:
$$ x^{k+1} = \mathrm{Prox}_{\frac{\epsilon}{\alpha} f}(y^k - u^k) \\ y^{k+1} = \mathrm{Prox}_{\epsilon U} (x^{k+1} + u^k) \\ u^{k+1} = u^k + x^{k+1} - y^{k+1} $$
PnP-ADMM simply replaces the proximal operator $\mathrm{Prox}_{\epsilon U}$ with the denoiser $D_{\epsilon}$, where $\epsilon$ is the noise parameter (variance) of the denoiser:
$$ x^{k+1} = \mathrm{Prox}_{\frac{\epsilon}{\alpha} f}(y^k - u^k) \\ y^{k+1} = D_{\epsilon} (x^{k+1} + u^k) \\ u^{k+1} = u^k + x^{k+1} - y^{k+1} $$
Write a function pnp_admm
to implement the algorithm, with the following parameters:
def pnp_admm(noisy, denoiser, proximal_step, **opts):
"""
Parameters:
:noisy - the noisy observation.
:denoiser - the Gaussian denoiser used in Plug-and-Play ADMM.
:proximal_step - the function which implements the proximal step of the ADMM algorithm.
:opts - the kwargs for hyperparameters in Plug-and-Play ADMM.
"""
The forward-backward splitting (FBS) optimization scheme can be written:
$ x^{k+1} = \mathrm{Prox}_{\epsilon U}\left(I - \frac{\epsilon}{\alpha} \nabla f\right)(x^k) $
PnP-FBS simply replaces the proximal operator $\mathrm{Prox}_{\epsilon U}$ with the denoiser $D_{\epsilon}$:
$ x^{k+1} = D_{\epsilon}\left(I - \frac{\epsilon}{\alpha} \nabla f\right)(x^k) $
Write a function pnp_fbs
to implement the algorithm, with the following parameters:
def pnp_fbs(noisy, denoiser, gradient_step, **opts):
"""
Parameters:
:noisy - the noisy observation.
:denoiser - the Gaussian denoiser used in Plug-and-Play FBS.
    :gradient_step - the function which implements the gradient step (I - (epsilon/alpha) grad f)(x).
:opts - the kwargs for hyperparameters in Plug-and-Play FBS.
"""
A fully proximal version of the algorithm, called backward-backward splitting (BBS), reads:
$ x^{k+1} = \mathrm{Prox}_{\epsilon U}\left(\mathrm{Prox}_{\frac{\epsilon}{\alpha} f}(x^k)\right) $
PnP-BBS replaces the proximal operator $\mathrm{Prox}_{\epsilon U}$ with the denoiser $D_{\epsilon}$:
$ x^{k+1} = D_{\epsilon}\left(\mathrm{Prox}_{\frac{\epsilon}{\alpha} f}(x^k)\right) $
Write a function pnp_bbs
to implement the algorithm, with the following parameters:
def pnp_bbs(noisy, denoiser, proximal_step, **opts):
"""
Parameters:
:noisy - the noisy observation.
:denoiser - the Gaussian denoiser used in Plug-and-Play BBS.
    :proximal_step - the function which implements the proximal step Prox_{(epsilon/alpha) f} of the data term.
:opts - the kwargs for hyperparameters in Plug-and-Play BBS.
"""
Remark: it can be shown that PnP-FBS and PnP-ADMM are different methods for finding the same set of fixed points. PnP-FBS is sometimes easier to implement, since it only requires the computation of $\nabla f$ rather than $\mathrm{Prox}_{f}$. On the other hand, PnP-ADMM is generally more robust, with better convergence properties. BBS aims at solving a slightly modified version of the original problem, in which $f$ is replaced by its Moreau envelope.
We will now use the previous algorithms to solve two inverse problems in imaging: deblurring and inpainting.
In this section, experiments will be carried out to recover images degraded by blur and additive noise. In the next cell we define the algorithm's hyperparameters, load the denoiser model, and create a blurred image. In all cases we use the denoising model developed by the authors (RealSN_DnCNN) with a noise level of 5 (on the 0-255 scale).
# Read image
os.system("wget -nc https://raw.githubusercontent.com/storimaging/Images/main/img/simpson_nb512.png")
u_orig = plt.imread("simpson_nb512.png")
u_orig = u_orig[128:256,128:256]
# start by working with a small image to experiment, before using the whole image
# Define hyperparameters
sigma = 1/255 # standard deviation of the additive Gaussian noise
# Create blurred image
blurredImage, H = CreateBlurredImage(u_orig, sigma,s=4)
# Display the degraded image
printImages(u_orig, blurredImage, 'original image', 'blurred image', 10, 10)
To use the previous algorithms, we need to define the gradient and proximal steps for the data term of the deblurring problem. The data term can be written $$\frac 1 \alpha f(x) = \frac{\| Mx - y \|^2}{2 \alpha\sigma^2},$$ where $M$ is the convolution with the blur kernel $h$ and $y$ is the blurred, noisy observation.
Clearly, $$\nabla f(x) = M^*\frac{Mx -y}{\sigma^2}$$ and $$\mathrm{prox}_{\frac \epsilon \alpha f}(x) = (Id+\frac \epsilon {\alpha\sigma^2} M^* M)^{-1}\left(x +\frac \epsilon {\alpha\sigma^2} M^* y\right).$$
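Indeed, writing $\tau = \frac{\epsilon}{\alpha\sigma^2}$, the point $z = \mathrm{prox}_{\frac{\epsilon}{\alpha} f}(x)$ minimizes $z \mapsto \frac{\tau}{2}\|Mz - y\|^2 + \frac{1}{2}\|z - x\|^2$, and setting its gradient to zero gives
$$ \tau M^*(Mz - y) + z - x = 0 \iff (Id + \tau M^*M)\,z = x + \tau M^* y. $$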
The linear operator $M$ represents a convolution by a kernel $h$. It can be applied directly in the Fourier domain as a multiplication by $\hat h$. The adjoint $M^*$ corresponds to multiplication by the conjugate $\overline{\hat h}$ in the Fourier domain.
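As a quick numerical check of this Fourier implementation, $M$ and $M^*$ should satisfy $\langle Mx, z\rangle = \langle x, M^*z\rangle$. A small test (H is the kernel created above; the other names are ours):

# Sanity check: FFT convolution and its adjoint form an adjoint pair
h_fft = np.fft.fft2(H)
xt = np.random.randn(*H.shape)
zt = np.random.randn(*H.shape)
Mx  = np.real(np.fft.ifft2(h_fft * np.fft.fft2(xt)))           # M x
Msz = np.real(np.fft.ifft2(np.conj(h_fft) * np.fft.fft2(zt)))  # M* z
print(np.sum(Mx * zt), np.sum(xt * Msz))  # the two inner products should match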
# Define gradient step
def grad_deblurring(x, im_b, **opts):
"""
Gradient Operator for Gaussian deblurring:
f(x) = || A.x - im_b ||^2 / (2 sigma^2)
with A.x = h*x
Parameters:
:x - the argument to the proximal operator.
:im_b - the noisy observation.
:h - blurring kernel.
:sigma - the standard deviation of the gaussian noise in im_b.
:alpha - the regularization parameter.
"""
# Process Parameters
alpha = opts.get('alpha', 1.)
sigma = opts.get('sigma', 1/255)
sigma_model = opts.get('sigma_model', 5)
h = opts.get('h', np.ones_like(im_b))
n_row = opts.get('n_row', 1)
n_col = opts.get('n_col', 1)
epsilon = (sigma_model/255)**2
# Reshapes
x_shapeOrig = np.reshape((x) , (n_row, n_col))
im_b_shapeOrig = np.reshape((im_b) , (n_row, n_col))
# Computes gradient step
a = epsilon/(alpha*sigma**2)
h_fft = np.fft.fft2(h)
hc_fft = np.conj(h_fft)
x_fft = np.fft.fft2(x_shapeOrig)
grad = x_shapeOrig -a*np.real(np.fft.ifft2(hc_fft *(h_fft*x_fft - np.fft.fft2(im_b_shapeOrig))))
# Reshape to flat to return
return np.reshape(grad, -1)
def prox_deblurring(x, im_b, **opts):
"""
Proximal Operator for Gaussian deblurring:
f(x) = || A.x - im_b ||^2 / (2 sigma^2)
avec A.x = h*x
prox_{alpha f} (x[i]) = 1/(1+alpha/sigma^2*h_fft*hc_fft)*(alpha/sigma^2 A.T im_b[i] +x[i])
Parameters:
:x - the argument to the proximal operator.
:im_b - the noisy observation.
:h - blurring kernel.
:sigma - the standard deviation of the gaussian noise in im_b.
:alpha - the regularization parameter.
"""
# Process parameters
alpha = opts.get('alpha', 1.)
sigma = opts.get('sigma', 1/255)
sigma_model = opts.get('sigma_model', 5)
h = opts.get('h', np.ones_like(im_b))
n_row = opts.get('n_row', 1)
n_col = opts.get('n_col', 1)
epsilon = (sigma_model/255)**2
# Reshapes
x_shapeOrig = np.reshape((x) , (n_row, n_col))
im_b_shapeOrig = np.reshape((im_b) , (n_row, n_col))
# Computes proximal step
a = epsilon/(alpha*sigma**2)
h_fft = np.fft.fft2(h)
hc_fft = np.conj(h_fft)
X = a*np.real(np.fft.ifft2(hc_fft*np.fft.fft2(im_b_shapeOrig))) + x_shapeOrig
prox = np.real(np.fft.ifft2(np.fft.fft2(X)/(a*h_fft*hc_fft + 1)))
# Reshape to flat to return
return np.reshape(prox, -1)
In the following cell we will define the function that implements the proximal step needed by PnP ADMM. Next, we will run the algorithm and display the result.
# Define hyperparameters
alpha = 0.6
maxitr = 100
sigma_model = 5 #(corresponds to sqrt(epsilon)*255)
# Kwargs for PnP algorithm
opts = dict(alpha=alpha, maxitr=maxitr, sigma_model=sigma_model, sigma=sigma, h=H, n_row=u_orig.shape[0], n_col=u_orig.shape[1], noise_level_map = (model_type == "DRUnet"), device = device)
# Run the algorithm
with torch.no_grad():
start = time.time()
out = pnp_admm(blurredImage, model, prox_deblurring, **opts)
end = time.time()
# Display results
displayResults(out, blurredImage, 'restored image with PnP-ADMM', 'blurred image', 10, 10, end - start, u_orig)
time:  6.403290033340454
PSNR:  33.85985801464729
# Run the algorithm
with torch.no_grad():
start = time.time()
out = pnp_fbs(blurredImage, model, prox_deblurring, **opts)
end = time.time()
# Display results
displayResults(out, blurredImage, 'restored image with PnP-FBS', 'blurred image', 10, 10, end - start, u_orig)
time:  0.6459488868713379
PSNR:  33.52611018590379
# Run the algorithm
with torch.no_grad():
start = time.time()
out = pnp_bbs(blurredImage, model, prox_deblurring, **opts)
end = time.time()
# Display results
displayResults(out, blurredImage, 'restored image with PnP-BBS', 'blurred image', 10, 10, end - start, u_orig)
time:  0.6535851955413818
PSNR:  33.52611018590379
Everything works fine in the deblurring experiments above, and all algorithms converge in a few dozen iterations. We will now deal with an inverse problem where pixels are missing, and we will see that the same algorithms require many more iterations to converge.
We create the degraded image.
#Create degraded image with 80% of missing pixels
sigma = 1/255 # noise level of the considered restoration problem
degraded, mask = CreateImageWithMissingPixels(u_orig, sigma,0.8)
# Display the degraded image
printImages(u_orig, degraded, 'original image', 'degraded image', 10, 10)
The following cell contains the gradient and proximal operators for the inpainting problem.
#@title Gradient and proximal operators for inpainting
# Gradient step for inpainting
def grad_inpainting(x, im_b, **opts):
"""
- Gradient Operator for Gaussian inpainting:
f(x) = || M*x - im_b ||^2 / (2 sigma^2)
Parameters:
:x - the argument to the proximal operator.
:im_b - the noisy observation.
:mask - binary image of the same size as x
:sigma - the standard deviation of the gaussian noise in im_b.
:alpha - the regularization parameter.
"""
alpha = opts.get('alpha', 1)
sigma = opts.get('sigma', 1/255)
mask_orig = opts.get('mask', np.ones_like(im_b))
mask = np.reshape(mask_orig, -1)
sigma_model = opts.get('sigma_model', 5)
epsilon = (sigma_model/255)**2
a = epsilon/(alpha*sigma**2)
    return x - a*(mask == 1)*(x - im_b)
# Proximal step for inpainting
def prox_inpainting(x, im_b, **opts):
"""
Proximal Operator for Gaussian inpainting:
f(x) = || M*x - im_b ||^2 / (2 sigma^2)
prox_{alpha f} (x[i]) = (x[i] + im_b[i]*alpha/sigma^2)/(1+alpha/sigma^2) if M[i]==1
= x[i] if M[i]==0
Parameters:
:x - the argument to the proximal operator.
:im_b - the noisy observation.
:mask - binary image of the same size as x
:sigma - the standard deviation of the gaussian noise in im_b.
:alpha - the regularization parameter.
"""
alpha = opts.get('alpha', 1)
sigma = opts.get('sigma', 1/255)
mask_orig = opts.get('mask', np.ones_like(im_b))
mask = np.reshape(mask_orig, -1)
sigma_model = opts.get('sigma_model', 5)
epsilon = (sigma_model/255)**2
    if sigma != 0:
        a = epsilon/(alpha*sigma**2)
        out = (a*im_b + x)/(a+1)*mask + (1-mask)*x
    else:
        # Noiseless case: exact projection, known pixels are reset to their observed values
        out = mask*im_b + (1-mask)*x
    return np.copy(out)
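As a small illustration (toy arrays, ours): when $\sigma = 0$ the proximal step is an exact projection that resets the known pixels to their observed values.

x_test = np.zeros(4)
b_test = np.ones(4)
out = prox_inpainting(x_test, b_test, mask=np.array([1, 1, 0, 0]), sigma=0)
print(out)  # [1. 1. 0. 0.]: observed pixels restored, hidden pixels untouched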
# Define the hyperparameters
alpha = 1 # weight of the data term
maxitr = 1500 # number of iterations
# Kwargs for PnP algorithm
opts = dict(alpha=alpha, maxitr=maxitr, sigma_model=sigma_model, sigma=sigma, mask=mask, noise_level_map = (model_type == "DRUnet"), device = device)
3. Display the graph of the PSNR values along the iterations for each algorithm and comment (a sketch is given below).
4. How can we deal with the case $\sigma = 0$?
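For question 3, one option is to re-run a scheme while recording the PSNR of the current iterate. A minimal sketch for the ADMM case on the inpainting problem (it reuses the hypothetical apply_denoiser helper from the sketches above):

# Track PSNR along the PnP-ADMM iterations
shape = degraded.shape
obs = np.reshape(degraded, -1)
y = np.copy(obs)
u = np.zeros_like(obs)
psnrs = []
with torch.no_grad():
    for k in range(maxitr):
        x = prox_inpainting(y - u, obs, **opts)
        y = apply_denoiser(model, x + u, shape, device)
        u = u + x - y
        psnrs.append(PSNR(u_orig, np.reshape(y, shape)))
plt.plot(psnrs)
plt.xlabel('iteration')
plt.ylabel('PSNR (dB)')
plt.show()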
# Run the ADMM algorithm
with torch.no_grad():
start = time.time()
out = pnp_admm(degraded, model, prox_inpainting, **opts)
end = time.time()
# Display results
displayResults(out, degraded, 'restored image with PnP-ADMM', 'degraded image', 10, 10, end-start, u_orig)
time:  7.5262439250946045
PSNR:  25.616163385248832
# Run the algorithm
with torch.no_grad():
start = time.time()
out = pnp_fbs(degraded, model, prox_inpainting, **opts)
end = time.time()
# Display results
displayResults(out, degraded, 'restored image with PnP-FBS', 'degraded image', 10, 10, end-start, u_orig)
time:  7.54533052444458
PSNR:  26.745591088677774
# Run the algorithm
with torch.no_grad():
start = time.time()
out = pnp_bbs(degraded, model, prox_inpainting, **opts)
end = time.time()
# Display results
displayResults(out, degraded, 'restored image with PnP-BBS', 'degraded image', 10, 10, end-start, u_orig)
time:  7.617080450057983
PSNR:  26.745591088677774