where $\mathbf{R}$ is an optional, generic linear operator (e.g., a restriction operator).
%load_ext autoreload
%autoreload 2
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from scipy.sparse.linalg import lsqr
from pylops.basicoperators import *
from pylops.utils.dottest import dottest
from pyproximal.proximal import *
from pyproximal import ProxOperator
from pyproximal.utils.bilinear import BilinearOperator
class LowRankFactorizedMatrix(BilinearOperator):
    r"""Low-rank factorized matrix operator.

    Bilinear operator for the model
    :math:`\mathbf{d} = \mathbf{R}\,\operatorname{vec}(\mathbf{X}\mathbf{Y})`,
    where :math:`\mathbf{R}` is ``Op`` (identity when ``Op=None``). The model
    is linear in each factor when the other one is kept fixed.

    Parameters
    ----------
    X : :obj:`numpy.ndarray`
        Left factor of size :math:`n \times k`
    Y : :obj:`numpy.ndarray`
        Right factor of size :math:`k \times m`
    d : :obj:`numpy.ndarray`
        Data vector
    Op : :obj:`pylops.LinearOperator`, optional
        Linear operator applied to the vectorized product

    """
    def __init__(self, X, Y, d, Op=None):
        self.n, self.k = X.shape
        self.m = Y.shape[1]
        self.x = X
        self.y = Y
        self.d = d
        self.Op = Op
        # NOTE(review): the row counts assume ``Op=None``; when ``Op`` is
        # given, the forward actually returns ``Op.shape[0]`` samples.
        self.shapex = (self.n * self.m, self.n * self.k)
        self.shapey = (self.n * self.m, self.m * self.k)

    def __call__(self, x, y=None):
        r"""Evaluate :math:`\frac{1}{2}\|\mathbf{d} - \mathbf{R}\,\operatorname{vec}(\mathbf{x}\mathbf{y})\|_2^2`.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Flattened left factor, or the concatenation of the flattened left
            and right factors when ``y`` is ``None``
        y : :obj:`numpy.ndarray`, optional
            Flattened right factor

        Returns
        -------
        f : scalar
            Half the squared L2 norm of the data residual

        """
        if y is None:
            x, y = x[:self.n * self.k], x[self.n * self.k:]
        # Temporarily replace the stored left factor so that ``_matvecy``
        # (which reads ``self.x``) uses the trial ``x``; assumes the inherited
        # ``updatex`` stores its argument in ``self.x``.
        xold = self.x.copy()
        self.updatex(x)
        try:
            res = self.d - self._matvecy(y)
        finally:
            # always restore the stored left factor, even if the forward fails
            self.updatex(xold)
        return np.linalg.norm(res) ** 2 / 2.

    def _matvecx(self, x):
        """Forward in the flattened left factor ``x`` (stored right factor fixed)."""
        X = x.reshape(self.n, self.k)
        X = X @ self.y.reshape(self.k, self.m)
        if self.Op is not None:
            X = self.Op @ X.ravel()
        return X.ravel()

    def _matvecy(self, y):
        """Forward in the flattened right factor ``y`` (stored left factor fixed)."""
        Y = y.reshape(self.k, self.m)
        X = self.x.reshape(self.n, self.k) @ Y
        if self.Op is not None:
            X = self.Op @ X.ravel()
        return X.ravel()

    def matvec(self, x, which=None):
        """Forward in either factor.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Flattened left (``which="x"``) or right (``which="y"``) factor
        which : :obj:`str`, optional
            Factor that ``x`` refers to. When ``None`` it is inferred from
            ``x.size``; this is ambiguous when ``n == m`` (the left factor is
            then assumed), so pass it explicitly in that case.

        """
        if which is None:
            which = "x" if x.size == self.shapex[1] else "y"
        if which == "x":
            return self._matvecx(x)
        return self._matvecy(x)

    def _rmatvecx(self, x):
        """Adjoint w.r.t. the left factor, ``R @ Y^H`` (right factor fixed)."""
        if self.Op is not None:
            x = self.Op.H @ x
        X = x.reshape(self.n, self.m)
        X = X @ np.conj(self.y.reshape(self.k, self.m).T)
        return X.ravel()

    def _rmatvecy(self, x):
        """Adjoint w.r.t. the right factor, ``X^H @ R`` (left factor fixed)."""
        if self.Op is not None:
            x = self.Op.H @ x
        R = x.reshape(self.n, self.m)
        # Conjugate only the left factor: the former ``(conj(R).T @ X).T``
        # evaluated to conj(X^H R), which is the adjoint only for real data.
        Y = np.conj(self.x.reshape(self.n, self.k)).T @ R
        return Y.ravel()

    def rmatvec(self, x, which="x"):
        """Adjoint w.r.t. the left (``which="x"``) or right (otherwise) factor."""
        if which == "x":
            return self._rmatvecx(x)
        return self._rmatvecy(x)
# Fix the random seed so that re-running the notebook draws the same
# sampling pattern and factors
np.random.seed(0)
# Restriction operator
n, m, k = 4, 5, 2
sub = 0.4
nsub = int(n*m*sub)
iava = np.random.permutation(np.arange(n*m))[:nsub]
Rop = Restriction(n*m, iava)
# model
U = np.random.normal(0., 1., (n, k))
V = np.random.normal(0., 1., (m, k))
X = U @ V.T
# data
y = Rop * X.ravel()
# Masked data
Y = (Rop.H * Rop * X.ravel()).reshape(n, m)
X = U @ V.T
# X = U V^T written as a linear operator acting on V (U fixed): Top
# transposes V from (m, k) to (k, m) and Uop left-multiplies by U, giving (n, m)
Uop = MatrixMult(U, otherdims=(m,))
Top = Transpose((m,k), (1,0))
Uop1 = Uop * Top
print(Uop, Top)
X1 = Uop1 * V.ravel()
X1 = X1.reshape(n,m)
print(X-X1)  # expected to be all zeros
# data
Ruop = Rop * Uop * Top
y1 = Ruop * V.ravel()
print(y-y1)  # expected to be all zeros
<20x10 MatrixMult with dtype=float64> <10x10 Transpose with dtype=float64> [[0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]] [0. 0. 0. 0. 0. 0. 0. 0.]
# Reference adjoint w.r.t. V (U fixed) applied to y1; checked against
# LOp._rmatvecy further below
v1 = Ruop.H @ y1
# Same product X = U V^T written as a linear operator acting on U (V fixed):
# Top transposes U from (n, k) to (k, n), Vop left-multiplies by V giving
# V U^T of shape (m, n), and T1op.T transposes it back to (n, m)
Vop = MatrixMult(V, otherdims=(n,))
Top = Transpose((n,k), (1,0))
T1op = Transpose((n,m), (1,0))
Vop1 = T1op.T * Vop * Top
X1 = Vop1 * U.ravel()
X1 = X1.reshape(n,m)
print(X-X1)  # expected to be all zeros
# data
# NOTE: Ruop is rebound here to the restricted operator acting on U
Ruop = Rop * T1op.T * Vop * Top
y1 = Ruop * U.ravel()
print(y-y1)  # expected to be all zeros
[[0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.] [0. 0. 0. 0. 0.]] [0. 0. 0. 0. 0. 0. 0. 0.]
# Reference adjoint w.r.t. U (V fixed) applied to y1; checked against
# LOp._rmatvecx below
u1 = Ruop.H @ y1
Let's now check that our `LowRankFactorizedMatrix` operator reproduces the same results.
# Wrap the true factors in the bilinear operator: both forwards should
# reproduce the observed data exactly (zero residuals)
LOp = LowRankFactorizedMatrix(X=U, Y=V.T, d=y, Op=Rop)
(y - LOp._matvecx(U.ravel()), y - LOp._matvecy(V.T.ravel()))
(array([0., 0., 0., 0., 0., 0., 0., 0.]), array([0., 0., 0., 0., 0., 0., 0., 0.]))
# adjoint w.r.t. the left factor vs. the pylops reference u1
u1 - np.reshape(LOp._rmatvecx(y), (n, k))
array([[0., 0.], [0., 0.], [0., 0.], [0., 0.]])
# adjoint w.r.t. the right factor vs. the pylops reference v1
v1.T - np.reshape(LOp._rmatvecy(y), (k, m))
array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]])
# dot-product (adjointness) test of the operator in the left factor
Fop = FunctionOperator(LOp._matvecx, LOp._rmatvecx, iava.size, n * k)
dottest(Fop)
True
# dot-product (adjointness) test of the operator in the right factor
Fop = FunctionOperator(LOp._matvecy, LOp._rmatvecy, iava.size, k * m)
dottest(Fop)
True
We now repeat the same checks for a stack of matrices, where for each matrix in the stack we have
$$ \mathbf{y}_i=\mathbf{U}_i\mathbf{V}_i^T = R_{u_i}(\mathbf{v}_i) $$ and
$$ \mathbf{y}=\mathbf{R} [\mathbf{y}_1^T, \mathbf{y}_2^T, ..., \mathbf{y}_N^T]^T $$class LowRankFactorizedStackMatrix(BilinearOperator):
r"""Low-Rank Factorized Stack of Matrix operator.
Parameters
----------
X : :obj:`numpy.ndarray`
Left-matrix of size :math:`r \times n \times k`
Y : :obj:`numpy.ndarray`
Right-matrix of size :math:`r \times k \times m`
d : :obj:`numpy.ndarray`
Data vector
Op : :obj:`pylops.LinearOperator`, optional
Linear operator
"""
def __init__(self, X, Y, d, Op=None):
self.r, self.n, self.k = X.shape
self.m = Y.shape[2]
self.x = X
self.y = Y
self.d = d
self.Op = Op
self.shapex = (self.r * self.n * self.m, self.r * self.n * self.k)
self.shapey = (self.r * self.n * self.m, self.r * self.m * self.k)
def __call__(self, x, y=None):
if y is None:
x, y = x[:self.r * self.n * self.k], x[self.r * self.n * self.k:]
xold = self.x.copy()
self.updatex(x)
res = self.d - self._matvecy(y)
self.updatex(xold)
return np.linalg.norm(res)**2 / 2.
def _matvecx(self, x):
X = x.reshape(self.r, self.n, self.k)
X = np.matmul(X, self.y.reshape(self.r, self.k, self.m))
if self.Op is not None:
X = self.Op @ X.ravel()
return X.ravel()
def _matvecy(self, y):
Y = y.reshape(self.r, self.k, self.m)
X = np.matmul(self.x.reshape(self.r, self.n, self.k), Y)
if self.Op is not None:
X = self.Op @ X.ravel()
return X.ravel()
def matvec(self, x):
if x.size == self.shapex[1]:
y = self._matvecx(x)
else:
y = self._matvecy(x)
return y
def _rmatvecx(self, x):
if self.Op is not None:
x = self.Op.H @ x
X = x.reshape(self.r, self.n, self.m)
X = X @ np.conj(self.y.reshape(self.r, self.k, self.m).transpose(0, 2, 1))
return X.ravel()
def _rmatvecy(self, x):
if self.Op is not None:
x = self.Op.H @ x
Y = x.reshape(self.r, self.n, self.m)
X = (np.conj(Y.transpose(0, 2, 1) @ self.x.reshape(self.r, self.n, self.k)) ).transpose(0, 2, 1)
return X.ravel()
def rmatvec(self, x, which="x"):
if which == "x":
y = self._rmatvecx(x)
else:
y = self._rmatvecy(x)
return y
# Restriction operator
r, n, m, k = 10, 4, 5, 2
nsub = int(r*n*m*sub)
iava = np.random.permutation(np.arange(r*n*m))[:nsub]
Rop = Restriction(r*n*m, iava)
# model: one pair of rank-k factors per matrix of the stack
U = np.random.normal(0., 1., (r, n, k))
V = np.random.normal(0., 1., (r, m, k))
# data: computed directly from the model (independently of the operator under
# test) and *before* building LOp, so that LOp.d holds the data of this stack
# problem rather than a leftover vector of a different size
y = Rop * np.matmul(U, V.transpose(0, 2, 1)).ravel()
LOp = LowRankFactorizedStackMatrix(U, V.transpose(0, 2, 1), y, Op=Rop)
LOp._matvecx(U.ravel()) - LOp._matvecy(V.transpose(0, 2, 1).ravel())
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
# adjoint w.r.t. the left factors, one (n, k) slice per matrix
np.reshape(LOp._rmatvecx(y), (r, n, k))
array([[[ -0.78054084, 2.8911588 ], [ 1.93563641, -0.72318821], [ 0.04292538, -0.28963407], [ -6.5718067 , 12.22588594]], [[ 0.1168023 , 0.06032668], [ 3.62884288, 4.30775936], [ 1.41275428, -4.99451938], [ -1.2500559 , 6.97063198]], [[ 0. , 0. ], [ -0.07611239, 0.26622162], [ 0.13927045, -0.15456169], [ 0.6466862 , -1.24214511]], [[ 1.198345 , -0.7805392 ], [ 4.75092671, -1.52422783], [ 1.47400335, -1.14077353], [ -3.1661599 , 2.05806897]], [[ -1.00902946, 4.92878232], [ 0.24899883, 0.32624332], [ -0.06286789, 0.04266047], [ -1.63890764, 1.39551526]], [[ 2.36845286, 6.60434453], [ -2.08853631, -6.76490457], [-14.68757653, -15.625856 ], [ -0.94519908, -0.99317364]], [[ -0.54407312, 1.28697607], [ -6.60844309, 4.67540914], [ 1.64741857, -2.02982978], [ 0.68914214, -1.63012914]], [[ 0.82424242, -1.02167535], [-11.59598402, -7.32947411], [ -0.81085562, 1.00508197], [ 0.78811781, -0.97689771]], [[ 0.8243525 , -2.10939496], [ 0.06842153, 1.06827078], [ -0.22793551, -0.69989879], [ 0.23410208, 0.57912629]], [[ 1.64003496, -0.93854556], [ 1.7113675 , -3.55769291], [ 0.26341859, -0.19187614], [ 0. , 0. ]]])
# adjoint w.r.t. the right factors, one (k, m) slice per matrix
np.reshape(LOp._rmatvecy(y), (r, k, m))
array([[[-1.53167819e+00, -5.41519764e-01, 1.72286170e+00, 3.52298628e+00, 2.09072694e+00], [ 3.04232231e+00, 1.13873913e+00, 3.11648146e+00, -4.93432359e+00, 9.35769112e-01]], [[ 1.21189421e+00, 7.41606280e+00, 3.42830783e-01, -2.43030485e-01, -2.09537060e-01], [ 2.04667300e+00, 6.80163112e+00, -4.33082797e-01, -5.23441190e+00, -1.71705541e-01]], [[-3.68759690e-01, 1.82084056e-01, 3.18172630e-03, 0.00000000e+00, 0.00000000e+00], [ 8.43423815e-01, -6.64181708e-01, -2.36685631e-02, 0.00000000e+00, 0.00000000e+00]], [[ 0.00000000e+00, -2.67766186e+00, 2.54528032e-01, 3.56676138e-02, 1.60006873e+00], [ 0.00000000e+00, 1.03960650e+00, 6.84012272e-02, -4.71108763e-02, -1.87663837e-01]], [[-1.20108514e-01, 5.31513126e-01, -9.37653478e-01, 2.90938798e-02, -1.19728385e+00], [-4.09950670e+00, 3.10669865e-01, 3.42816597e+00, -7.13296768e-03, 1.68613156e+00]], [[-2.16444553e+01, 7.02059677e+00, 4.57238230e+00, 5.88757160e-01, 2.18376847e+00], [-1.25470827e+01, 5.27350183e+00, 2.57870424e+00, 2.27289480e+00, 4.44479659e-01]], [[ 0.00000000e+00, 4.27009012e+00, 1.31059583e+00, -5.02808450e-01, -3.93352871e+00], [ 0.00000000e+00, -8.28073851e+00, -1.49855347e+00, 4.25243935e-01, 3.32673253e+00]], [[ 0.00000000e+00, 8.69892215e-01, 3.61350865e+00, -8.09336332e-01, -1.71329566e+00], [ 0.00000000e+00, 1.37063213e+00, 5.69356868e+00, -1.27521820e+00, 4.68299998e+00]], [[-4.38607107e-02, 6.48709824e-01, 4.43077157e-01, -3.21121611e-02, 0.00000000e+00], [-1.33536165e-02, 3.71728654e-01, 8.26053949e-01, -3.85739195e-02, 0.00000000e+00]], [[-3.41112674e-01, 2.20327313e+00, 1.08024599e+00, 2.01332704e+00, -1.40502167e+00], [-2.58495204e-01, -6.16489457e-01, 6.69923392e-02, -3.52710402e+00, 1.85899586e+00]]])
# dot-product (adjointness) test of the stack operator in the left factors
Fop = FunctionOperator(LOp._matvecx, LOp._rmatvecx, iava.size, r * n * k)
dottest(Fop)
True
# dot-product (adjointness) test of the stack operator in the right factors
Fop = FunctionOperator(LOp._matvecy, LOp._rmatvecy, iava.size, r * k * m)
dottest(Fop)
True