We compute the projection of a vector onto the intersection of simple sets (sets whose individual projections are easy to compute); see Appendix A of *A Convex Approach to Minimal Partitions* by Antonin Chambolle, Daniel Cremers and Thomas Pock.
$$ \mathrm{proj}_K(x), \qquad K = \bigcap_{1 \leq i_1 < i_2 \leq k} K_{i_1,i_2}, \qquad K_{i_1,i_2} = \{ x : |x_{i_2} - x_{i_1}| \leq \sigma_{i_1,i_2} \} $$
%load_ext autoreload
%autoreload 2
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt
from pyproximal.projection import *
from pyproximal.proximal import *
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
Create vector x
# Random vector of length k and a symmetric matrix of pairwise bounds sigma
# (sigma.T @ sigma is symmetric; only the entries with i1 < i2 are used below).
k = 3
xtrue = np.random.normal(0, 50, k)
x = xtrue.copy()   # input of the hand-written projection loop
x1 = xtrue.copy()  # input of PyProximal's IntersectionProj
sigma = np.array([[3, 2, 1], [2, 2, 1], [3, 4, 1]])
sigma = sigma.T @ sigma
print(sigma)
# Pairwise gaps before projection. np.abs matches the checks in the other
# cells; np.linalg.norm of a scalar difference gives the same value indirectly.
for i1 in range(k - 1):
    for i2 in range(i1 + 1, k):
        print(i1, i2, sigma[i1, i2], np.abs(x[i1] - x[i2]))
[[22 22 8] [22 24 8] [ 8 8 3]] 0 1 22 53.16650719444234 0 2 8 30.493562683547587 1 2 8 22.67294451089475
Projection of vector x
# Dykstra-style alternating projections onto the pairwise sets
# K_{i1,i2} = {x : |x[i2] - x[i1]| <= sigma[i1, i2]}.
# x12[i1, i2] stores, for each pair, how much the last projection removed
# from x[i2] - x[i1]. It is added back before the next projection so the
# iterates converge to the projection onto the intersection.
niter = 40
tol = 1e-20  # effectively: stop once an iteration leaves x unchanged
x12 = np.zeros((k, k))
for iiter in range(niter):
    xold = x.copy()
    for i1 in range(k - 1):
        for i2 in range(i1 + 1, k):
            # Pair difference with this pair's previous correction restored.
            xtilde = x[i2] - x[i1] + x12[i1, i2]
            # Soft-threshold: the part of xtilde exceeding sigma in magnitude.
            # np.sign avoids the 0/0 = NaN that xtilde / |xtilde| produces
            # when xtilde is exactly zero.
            xdtilde = np.sign(xtilde) * np.maximum(0, np.abs(xtilde) - sigma[i1, i2])
            # Split the change symmetrically between the two entries.
            x[i1] = x[i1] + 0.5 * (xdtilde - x12[i1, i2])
            x[i2] = x[i2] - 0.5 * (xdtilde - x12[i1, i2])
            x12[i1, i2] = xdtilde
    if np.max(np.abs(x - xold)) < tol:
        break
print(iiter)
33
Check that the projected vector satisfies the constraints
# Print each pair's gap |v[i1] - v[i2]| next to its bound sigma[i1, i2]:
# first for the original vector xtrue, then for the projected vector x.
pairs = [(i1, i2) for i1 in range(k - 1) for i2 in range(i1 + 1, k)]
for vec in (xtrue, x):
    for i1, i2 in pairs:
        print(i1, i2, sigma[i1, i2], np.abs(vec[i1] - vec[i2]))
0 1 22 53.16650719444234 0 2 8 30.493562683547587 1 2 8 22.67294451089475 0 1 22 16.0 0 2 8 8.0 1 2 8 8.0
# Same projection through PyProximal's IntersectionProj, with a single column (n=1).
ic = IntersectionProj(k, 1, sigma, niter, tol)
x1 = ic(x1)
# NOTE(review): expected to reproduce the hand-written result x computed from
# the same starting vector; print whether it actually does.
print(np.allclose(x1, x))
# Displacement the projection applied to the original vector.
x1 - xtrue
array([ 19.88668996, -17.27981724, -2.60687272])
# Same thing through PyProximal's Intersection operator. Calling the operator
# prints a boolean: False for the raw random vector, True for the prox output
# in the run shown.
ic = Intersection(k, 1, sigma, niter, tol)
inside_before = ic(xtrue)
print(inside_before)
# NOTE(review): if prox writes into its input in place, xtrue is modified
# here -- confirm against PyProximal, or pass xtrue.copy().
x = ic.prox(xtrue, 1)
inside_after = ic(x)
print(inside_after)
False True
Repeat the same, now with a matrix of n columns (the algorithm works on each column independently)
# Matrix case: a random (k, n) matrix whose columns are handled independently,
# with the same pairwise bounds sigma as the vector case.
k, n = 3, 5
xtrue = np.random.normal(0, 50, (k, n))
x, x1 = xtrue.copy(), xtrue.copy()
sigma = np.array([[3, 2, 1], [2, 2, 1], [3, 4, 1]])
sigma = sigma.T.dot(sigma)
print(sigma)
# Column-wise gaps of every pair (i1 < i2) before projection.
pairs = [(i1, i2) for i1 in range(k - 1) for i2 in range(i1 + 1, k)]
for i1, i2 in pairs:
    print(i1, i2, sigma[i1, i2], np.abs(x[i1] - x[i2]))
[[22 22 8] [22 24 8] [ 8 8 3]] 0 1 22 [23.59298472 77.50994045 24.97714966 26.83330122 76.43167362] 0 2 8 [20.48895437 44.75477282 19.23383174 50.63779612 31.27031417] 1 2 8 [ 3.10403035 32.75516763 44.2109814 23.80449489 107.70198779]
# Same pairwise projection loop as the vector case, applied to all n columns
# at once: each row x[i] has shape (n,), and x12[i1, i2] holds one correction
# per column for the pair (i1, i2).
niter = 50
tol = 1e-20  # effectively: stop once an iteration leaves x unchanged
x12 = np.zeros((k, k, n))
for iiter in range(niter):
    xold = x.copy()
    for i1 in range(k - 1):
        for i2 in range(i1 + 1, k):
            # Pair difference with this pair's previous correction restored.
            xtilde = x[i2] - x[i1] + x12[i1, i2]
            # Soft-threshold; np.sign avoids the 0/0 = NaN that
            # xtilde / |xtilde| produces in columns where xtilde is zero.
            xdtilde = np.sign(xtilde) * np.maximum(0, np.abs(xtilde) - sigma[i1, i2])
            # Split the change symmetrically between the two rows.
            x[i1] = x[i1] + 0.5 * (xdtilde - x12[i1, i2])
            x[i2] = x[i2] - 0.5 * (xdtilde - x12[i1, i2])
            x12[i1, i2] = xdtilde
    # Largest per-column total change across all columns.
    if np.max(np.sum(np.abs(x - xold), axis=0)) < tol:
        break
print(iiter)
40
# Column-wise gaps of every pair after projection, next to the bound sigma[i1, i2].
pairs = [(i1, i2) for i1 in range(k - 1) for i2 in range(i1 + 1, k)]
for i1, i2 in pairs:
    print(i1, i2, sigma[i1, i2], np.abs(x[i1] - x[i2]))
0 1 22 [16. 16. 6.87165896 5.51440317 16. ] 0 2 8 [8. 8. 1.12834104 8. 8. ] 1 2 8 [8. 8. 8. 2.48559683 8. ]
Same using the projection operator in PyProximal
# Same projection through PyProximal's IntersectionProj. The result is
# reshaped to (k, n) so rows can be indexed per pair, as in the loop above.
ic = IntersectionProj(k, n, sigma, niter, tol)
x1 = ic(x1).reshape(k, n)
pairs = [(i1, i2) for i1 in range(k - 1) for i2 in range(i1 + 1, k)]
for i1, i2 in pairs:
    print(i1, i2, sigma[i1, i2], np.abs(x1[i1] - x1[i2]))
0 1 22 [16. 16. 6.87165896 5.51440317 16. ] 0 2 8 [8. 8. 1.12834104 8. 8. ] 1 2 8 [8. 8. 8. 2.48559683 8. ]
And the same using the proximal operator in PyProximal
# Same thing through PyProximal's Intersection operator. The second argument
# of the call (1e-3) is presumably a feasibility tolerance -- confirm in the
# PyProximal docs.
ic = Intersection(k, n, sigma, niter, tol)
inside_before = ic(xtrue, 1e-3)
print(inside_before)
# NOTE(review): if prox writes into its input in place, xtrue is modified
# here -- confirm against PyProximal, or pass xtrue.copy().
x = ic.prox(xtrue, 1)
inside_after = ic(x, 1e-3)
print(inside_after)
x = x.reshape(k, n)
pairs = [(i1, i2) for i1 in range(k - 1) for i2 in range(i1 + 1, k)]
for i1, i2 in pairs:
    print(i1, i2, sigma[i1, i2], np.abs(x[i1] - x[i2]))
False True 0 1 22 [16. 16. 6.87165896 5.51440317 16. ] 0 2 8 [8. 8. 1.12834104 8. 8. ] 1 2 8 [8. 8. 8. 2.48559683 8. ]