This example demonstrates the solution of a non-negative sparse coding problem
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - D \mathbf{x} \|_2^2 + \lambda \| \mathbf{x} \|_1 + I(\mathbf{x} \geq 0) \;,$$
where $D$ is the dictionary, $\mathbf{y}$ is the signal to be represented, $\mathbf{x}$ is the sparse representation, $\lambda$ is the regularization parameter, and $I(\mathbf{x} \geq 0)$ is the non-negative indicator.
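As a concrete reading of the objective, the minimal NumPy sketch below evaluates it for a given $\mathbf{x}$ (returning infinity when the non-negativity constraint is violated) and gives the closed-form proximal operator of the combined regularizer $t \|\mathbf{x}\|_1 + I(\mathbf{x} \geq 0)$, namely $\max(\mathbf{v} - t, 0)$ elementwise. The function names `objective` and `prox_nonneg_l1` are hypothetical and not part of scico; the ADMM setup in this example handles the two regularization terms as separate splitting terms rather than through this combined prox.

```python
import numpy as np


def objective(x, D, y, lmbda):
    """(1/2)||y - D x||_2^2 + lmbda ||x||_1 + I(x >= 0)."""
    if np.any(x < 0):
        return np.inf  # the indicator is +infinity outside the non-negative orthant
    r = y - D @ x
    return 0.5 * np.dot(r, r) + lmbda * np.sum(np.abs(x))


def prox_nonneg_l1(v, t):
    """Proximal operator of t ||x||_1 + I(x >= 0) evaluated at v.

    For x >= 0 the l1 term is linear, so minimizing t*x + (1/2)(x - v)^2 over
    x >= 0 gives x = max(v - t, 0) elementwise (one-sided soft thresholding).
    """
    return np.maximum(v - t, 0.0)


# Tiny 2 x 3 dictionary for illustration.
D = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
y = np.array([1.0, 2.0])
x = np.array([1.0, 2.0, 0.0])  # D @ x == y exactly
print(objective(x, D, y, lmbda=0.1))  # only the l1 term remains, approximately 0.3
print(prox_nonneg_l1(np.array([-1.0, 0.05, 0.5]), t=0.1))  # [0.  0.  0.4]
```

Entries of $\mathbf{v}$ below the threshold $t$ (including all negative entries) are set to zero, which is how the combination of the $\ell_1$ term and the indicator promotes sparse, non-negative solutions.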
# This scico project Jupyter notebook has been automatically modified
# to install the dependencies required for running it on Google Colab.
# If you encounter any problems in running it, please open an issue at
# https://github.com/lanl/scico-data/issues
!pip install 'scico[examples] @ git+https://github.com/lanl/scico'
import numpy as np
import jax
from scico import functional, linop, loss, plot
from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()
Create a random dictionary, a random reference sparse representation, and a test signal consisting of the synthesis of the reference sparse representation with additive Gaussian noise.
m = 32 # signal size
n = 128 # dictionary size
s = 10 # sparsity level
np.random.seed(1)
D = np.random.randn(m, n)
D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary
xt = np.zeros(n) # true sparse representation
idx = np.random.randint(low=0, high=n, size=s) # support of xt (sampled with replacement)
xt[idx] = np.random.rand(s)
y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal
xt = jax.device_put(xt) # convert to jax array on the default device (GPU if available)
y = jax.device_put(y) # convert to jax array on the default device (GPU if available)
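Because `np.random.randint` samples with replacement, `idx` can contain repeated indices, in which case `xt` has fewer than `s` non-zero entries. If an exactly `s`-sparse reference is wanted, one option is to sample the support without replacement using NumPy's `Generator.choice`, as in the sketch below. This is an alternative to the code above, not the original example's method, and it changes the random draws, so the solver output shown further down would no longer match.

```python
import numpy as np

m, n, s = 32, 128, 10  # signal size, dictionary size, sparsity level (as above)
rng = np.random.default_rng(1)  # seeded Generator; draws differ from np.random.seed(1)

D = rng.standard_normal((m, n))
D /= np.linalg.norm(D, axis=0, keepdims=True)  # unit-norm dictionary columns

xt = np.zeros(n)  # reference sparse representation
idx = rng.choice(n, size=s, replace=False)  # s distinct support indices
xt[idx] = rng.random(s)  # non-negative coefficients in [0, 1)

y = D @ xt + 5e-2 * rng.standard_normal(m)  # noisy synthetic signal

print(len(np.unique(idx)))  # 10
```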
Set up the forward operator and ADMM solver object.
lmbda = 1e-1
A = linop.MatrixOperator(D)
f = loss.SquaredL2Loss(y=y, A=A)
g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]
C_list = [linop.Identity((n,)), linop.Identity((n,))]
rho_list = [1.0, 1.0]
maxiter = 100 # number of ADMM iterations
solver = ADMM(
    f=f,
    g_list=g_list,
    C_list=C_list,
    rho_list=rho_list,
    x0=A.adj(y),
    maxiter=maxiter,
    subproblem_solver=MatrixSubproblemSolver(),
    itstat_options={"display": True, "period": 10},
)
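To make the roles of `g_list`, `C_list`, and `rho_list` concrete, here is a minimal NumPy sketch of scaled-form ADMM for this splitting, with $f(\mathbf{x}) = (1/2)\|\mathbf{y} - D\mathbf{x}\|_2^2$, $g_1 = \lambda\|\cdot\|_1$, $g_2 = I(\cdot \geq 0)$, and $C_1 = C_2 = I$. Since $f$ is quadratic and both $C_i$ are identities, the $\mathbf{x}$ update reduces to solving $(D^T D + (\rho_1 + \rho_2) I)\mathbf{x} = D^T \mathbf{y} + \rho_1(\mathbf{z}_1 - \mathbf{u}_1) + \rho_2(\mathbf{z}_2 - \mathbf{u}_2)$, a linear system with a fixed matrix, which is presumably why a matrix-based subproblem solver suits this problem. The function name `admm_nn_bpdn` is hypothetical and the update ordering follows the standard textbook form; this is not scico's implementation, and it omits stopping criteria and iteration statistics.

```python
import numpy as np


def admm_nn_bpdn(D, y, lmbda, rho1=1.0, rho2=1.0, maxiter=100):
    """Sketch of scaled-form ADMM for
    (1/2)||y - D x||_2^2 + lmbda ||x||_1 + I(x >= 0),
    with constraints x = z1 (for the l1 term) and x = z2 (for the indicator)."""
    n = D.shape[1]
    Dty = D.T @ y
    x = Dty.copy()  # analogous to x0=A.adj(y) in the example
    z1, z2 = x.copy(), x.copy()
    u1, u2 = np.zeros(n), np.zeros(n)  # scaled dual variables
    # The x-update system matrix does not change between iterations.
    M = D.T @ D + (rho1 + rho2) * np.eye(n)
    for _ in range(maxiter):
        # x-update: minimize f(x) + sum_i (rho_i/2)||x - (z_i - u_i)||^2
        x = np.linalg.solve(M, Dty + rho1 * (z1 - u1) + rho2 * (z2 - u2))
        # z1-update: prox of (lmbda/rho1)||.||_1, i.e. soft thresholding
        v1 = x + u1
        z1 = np.sign(v1) * np.maximum(np.abs(v1) - lmbda / rho1, 0.0)
        # z2-update: prox of I(. >= 0), i.e. projection onto the non-negative orthant
        z2 = np.maximum(x + u2, 0.0)
        # dual updates accumulate the constraint violations x - z_i
        u1 += x - z1
        u2 += x - z2
    return x, z1, z2


# Usage on a problem generated like the one above (different random draws).
rng = np.random.default_rng(0)
D = rng.standard_normal((32, 128))
D /= np.linalg.norm(D, axis=0, keepdims=True)
xt = np.zeros(128)
xt[rng.choice(128, size=10, replace=False)] = rng.random(10)
y = D @ xt + 5e-2 * rng.standard_normal(32)

x, z1, z2 = admm_nn_bpdn(D, y, lmbda=1e-1)
print(np.linalg.norm(x - z1), np.linalg.norm(x - z2))  # primal residuals, small at convergence
```

Factoring `M` once (for example with a Cholesky decomposition) would avoid repeating the $O(n^3)$ solve each iteration; `np.linalg.solve` is used here only to keep the sketch short.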
Run the solver.
print(f"Solving on {device_info()}\n")
x = solver.solve()
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  1.47e+00  2.810e+00  1.435e+00  4.750e+00      7  5.959e-05
  10  2.48e+00  4.879e-01  3.590e-02  6.160e-02      6  3.352e-05
  20  2.75e+00  4.753e-01  1.017e-02  2.036e-02      5  3.832e-05
  30  3.02e+00  4.737e-01  3.005e-03  7.686e-03      4  4.805e-05
  40  3.28e+00  4.732e-01  1.731e-03  3.707e-03      3  5.594e-05
  50  3.53e+00  4.731e-01  5.812e-04  8.044e-04      2  7.877e-05
  60  3.75e+00  4.730e-01  9.734e-05  0.000e+00      0  7.698e-05
  70  4.00e+00  4.730e-01  5.895e-05  6.213e-05      1  4.818e-05
  80  4.19e+00  4.730e-01  3.793e-05  0.000e+00      0  8.337e-05
  90  4.42e+00  4.730e-01  2.943e-05  0.000e+00      0  5.663e-05
  99  4.62e+00  4.730e-01  6.007e-05  0.000e+00      0  5.957e-05
Plot the recovered coefficients and signal.
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
    np.vstack((xt, solver.x)).T,
    title="Coefficients",
    lgnd=("Ground Truth", "Recovered"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    np.vstack((D @ xt, y, D @ solver.x)).T,
    title="Signal",
    lgnd=("Ground Truth", "Noisy", "Recovered"),
    fig=fig,
    ax=ax[1],
)
fig.show()