# Inpainting of a randomly obscured synthetic image via L2 (smoothness)
# regularization.
#
# About half of the pixels are hidden by a random mask.  With the image
# flattened column-major into a vector x, the hidden pixels x_u are recovered
# by solving the least-squares problem
#
#     minimize_{x_u}  || D (Z_k x_k + Z_u x_u) ||_2^2
#
# where D = [Dx; Dy] stacks horizontal and vertical first-difference
# operators, Z_k / Z_u scatter the known / unknown pixel values into the full
# image vector, and x_k holds the known pixel values.

import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as ssp
import scipy.sparse.linalg as sla

# --- Synthetic test image ---------------------------------------------------
u = np.linspace(-2, 2, 60)
v = np.linspace(-2, 2, 60)
U, V = np.meshgrid(u, v)
X_original = np.sin(-U + V - 1) ** 2 * np.cos(U + V + 1) + 1
xmin = np.min(X_original)
xmax = np.max(X_original)
m, n = X_original.shape

# --- Random mask: True = pixel known, False = pixel obscured ----------------
# A seeded generator makes the mask, and therefore every result and figure
# below, reproducible from run to run.
_rng = np.random.default_rng(0)
B = _rng.random((m, n)) > 0.50
X_obscured = B * X_original

plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(X_original, cmap='gray')
plt.title('Original image')
plt.axis('off')
plt.subplot(122)
plt.imshow(X_obscured, cmap='gray')
plt.title('Obscured image')
plt.axis('off')
plt.show()

# --- Selection matrices -----------------------------------------------------
mn = m * n
p = np.sum(B)   # number of known pixels
q = mn - p      # number of unknown pixels
# Column-major ('F') flattening: pixel (r, c) maps to index r + c*m.
B_k = B.flatten('F')
B_u = 1 - B.flatten('F')
# Z_k (mn x p) / Z_u (mn x q) place the j-th known / unknown value at its
# pixel index; together they partition the identity.
Z_k = ssp.coo_matrix(
    (np.ones(p), (np.flatnonzero(B_k), np.arange(p))),
    shape=(mn, p), dtype=np.int8)
Z_u = ssp.coo_matrix(
    (np.ones(q), (np.flatnonzero(B_u), np.arange(q))),
    shape=(mn, q), dtype=np.int8)
x_k = Z_k.T @ X_original.flatten('F')

# --- First-difference operators ---------------------------------------------
# Horizontal differences X[r, c+1] - X[r, c]: in column-major order the
# right-hand neighbour of index k is k + m.  Built from sparse offset
# identities so memory stays O(mn) instead of materialising dense blocks.
Dx = ssp.eye(m * (n - 1), mn, k=m) - ssp.eye(m * (n - 1), mn)
# Vertical differences X[r+1, c] - X[r, c]: one (m-1) x m band per column.
Dy = ssp.kron(ssp.eye(n), ssp.eye(m - 1, m, k=1) - ssp.eye(m - 1, m))

plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.spy(Z_k, markersize=0.2)
plt.title(r'$Z_{k}$')
plt.xticks(np.arange(0, p, 500))
plt.subplot(122)
plt.spy(Z_u, markersize=0.2)
plt.title(r'$Z_{u}$')
plt.xticks(np.arange(0, q, 500))
plt.show()

plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.spy(Dx, markersize=0.2)
plt.title(r'$D_x$')
plt.subplot(122)
plt.spy(Dy, markersize=0.2)
plt.title(r'$D_y$')
plt.show()

# --- L2 reconstruction ------------------------------------------------------
# || DD Z_u x_u + DD Z_k x_k ||^2 is minimised by least squares on
# A_MATRIX x_u ~= -B_MATRIX.
DD = ssp.vstack([Dx, Dy])
A_MATRIX = DD @ Z_u
B_MATRIX = DD @ Z_k @ x_k
lsqr_result = sla.lsqr(A_MATRIX, -B_MATRIX)
x_u = lsqr_result[0]
if lsqr_result[1] == 7:  # istop == 7: iteration limit reached
    warnings.warn(
        "lsqr hit its iteration limit; the L2 reconstruction may be inaccurate",
        RuntimeWarning)
# Known pixels are copied through exactly; unknown ones come from x_u.
X_recon_L2 = (Z_k @ x_k + Z_u @ x_u).reshape(m, n, order='F')

plt.figure(figsize=(12, 12))
plt.subplot(221)
plt.imshow(X_original, cmap='gray')
plt.title('Original image')
plt.axis('off')
plt.subplot(222)
plt.imshow(X_obscured, cmap='gray')
plt.title('Obscured image')
plt.axis('off')
plt.subplot(223)
plt.imshow(X_recon_L2, cmap='gray')
plt.title('Reconstructed image')
plt.axis('off')
plt.subplot(224)
# Same grey scale as the original image, so the error magnitude is visible
# relative to the image's dynamic range.
plt.imshow(np.abs(X_recon_L2 - X_original), cmap='gray', vmin=xmin, vmax=xmax)
plt.title('Difference image')
plt.axis('off')
plt.show()