import torch
from tqdm import tqdm
from torch_scatter import scatter
from utilities import *
# Prefer the GPU when one is available.
dev = "cuda" if torch.cuda.is_available() else "cpu"

# Load the noisy 3D point-cloud graph and move it to the chosen device.
graph = torch.load("./data/3d_signal.pt")
graph = graph.to(dev)

# Visualize the NOISY point cloud if required. Open3D visualization works only locally!
displaySur(position=graph.x.cpu().numpy(), texture=graph.tex.cpu().numpy())
We shall implement the $\mathcal{P}_{\tau B}(f)$ operator and the iterative scheme here, leaving $\mathcal{P}_{\beta A^{*}}(g)$ to be defined by the subclasses below.
class PD:
    """
    Primal-dual algorithm for denoising a signal defined on a graph.

    The base class provides the weighted graph gradient, its adjoint, the
    primal proximal step for a squared-L2 data-fidelity term, and the
    iteration loop. Subclasses must implement ``proj`` (the dual proximal
    step) and may override ``prox``.
    """

    def __init__(self, graph, **kwargs):
        self.edge_index = graph.edge_index  # (2, E) source/target node indices
        self.edge_attr = graph.edge_attr    # per-edge weights w_{ij}
        self.dim_size = kwargs["dim_size"]  # number of nodes |V|
        self.f0 = kwargs["noisy_sig"]       # noisy input signal f
        self.lamb = kwargs["lamb"]          # data-fidelity weight lambda
        self.tau = kwargs["tau"]            # primal step size
        self.beta = kwargs["beta"]          # dual step size
        self.theta = kwargs["theta"]        # over-relaxation parameter
        self.dev = kwargs["dev"]
        # Hoisted out of grad/grad_conj: sqrt of the edge weights is
        # loop-invariant, so compute it once instead of on every iteration.
        self._sqrt_w = torch.sqrt(self.edge_attr)

    def grad(self, Xb):
        """Weighted graph gradient: sqrt(w_ij) * (x_i - x_j) per edge."""
        return self._sqrt_w * (Xb[self.edge_index[0]] - Xb[self.edge_index[1]])

    def proj(self, T):
        """Dual proximal step; must be provided by a subclass."""
        # The old body was `pass`, which silently returned None and caused a
        # cryptic failure inside run(); fail loudly instead.
        raise NotImplementedError("proj must be implemented by a subclass")

    def prox(self, U):
        """Primal proximal step for the squared-L2 data-fidelity term."""
        return (self.lamb * U + self.tau * self.f0) / (self.lamb + self.tau)

    def grad_conj(self, P):  # grad_conj = -divergence
        """Adjoint of grad: scatter the weighted dual variable back to nodes."""
        weighted = self._sqrt_w * P
        out1 = scatter(weighted, self.edge_index[1], dim=0, dim_size=self.dim_size, reduce="add")
        out2 = scatter(weighted, self.edge_index[0], dim=0, dim_size=self.dim_size, reduce="add")
        return out2 - out1

    def run(self, itr, X):
        """Run `itr` primal-dual iterations starting from X; return the result."""
        Xb = X
        P = self.grad(X)
        for _ in tqdm(range(itr)):
            P = self.proj(P + self.beta * self.grad(Xb))
            tmp_X = X.clone()
            X = self.prox(X - self.tau * self.grad_conj(P))
            Xb = X + self.theta * (X - tmp_X)
        return X
The proximal operator $\mathcal{P}_{\beta A^{*}}(g)$ in this case is given as: $$ \mathcal{P}_{\beta A^{*}}(g)_{i,j,k} = \mathcal{Proj}_{\mathcal{B}_{\infty,\infty}}(g)_{i,j,k} = \frac{g_{i,j,k}}{\max\{1,|g_{i,j,k}|\}}$$ Recall that $g \in \mathbb{R}^{|V|\times|V|\times 3}$.
class PD_TV1(PD):
    """Anisotropic TV: project the dual variable onto the B_{inf,inf} unit ball."""

    def proj(self, T):
        # Divide each entry by max(1, |entry|) so every component ends in [-1, 1].
        denom = torch.clamp(torch.abs(T), 1)
        return T / denom
# Hyper-parameters for the anisotropic-TV run.
lamb = 0.005
itr = 5000
theta = 1

# Calculates k' = \max_{i \in V}\sum_{j\in N(i)}\omega_{i,j} (largest weighted degree).
weighted_degree = scatter(graph.edge_attr, graph.edge_index[1], dim=0, dim_size=graph.x.shape[0], reduce="add")
k_prime = weighted_degree.max().item()

# Step sizes derived from k' to keep the primal-dual iteration stable.
beta = 1 / (2 * k_prime)
tau = 1 / (8 * k_prime)

pd_tv1 = PD_TV1(graph, dim_size=graph.x.shape[0], noisy_sig=graph.x, tau=tau, lamb=lamb, beta=beta, theta=theta, dev=dev)
new_sig_tv1 = pd_tv1.run(itr, graph.x)
100%|██████████| 5000/5000 [00:01<00:00, 2994.38it/s]
# Displays the processed pointcld. Fix: texture was passed as a torch tensor
# (`.cpu().detach()`); every other displaySur call passes a NumPy array.
displaySur(position=new_sig_tv1.cpu().numpy(), texture=graph.tex.cpu().numpy())
The proximal operator $\mathcal{P}_{\beta A^{*}}(g)$ in this case is given as: $$ \mathcal{P}_{\beta A^{*}}(g)_{i,j,k} = \mathcal{Proj}_{\mathcal{B}_{\infty,2}}(g)_{i,j,k} = \frac{g_{i,j,k}}{\max\{1,\|g_{i,.}\|_{2}\}}$$ Recall that $g \in \mathbb{R}^{|V|\times|V|\times 3}$; here $\|g_{i,.}\|_{2}$ is the $L_2$ norm corresponding to the $i^{th}$ row tensor of size $|V|\times 3$.
class PD_TV2(PD):
    """Isotropic TV: project the dual variable onto the B_{inf,2} ball row-wise."""

    def proj(self, T):
        # Per-node sum of squared dual entries, gathered at each edge's target.
        sq_sum = scatter(torch.pow(T, 2), self.edge_index[1], dim=0, dim_size=self.dim_size, reduce="add")
        root = torch.sqrt(sq_sum)
        # ||g_{i,.}||_2 for every node i, as a column vector.
        row_norm = torch.norm(root, dim=1).view(-1, 1)
        # Scale each edge's dual by max(1, ||g_{i,.}||_2) of its target node.
        return T / torch.clamp(row_norm[self.edge_index[1]], 1)
# Hyper-parameters for the isotropic-TV run.
lamb = 0.05
itr = 5000
theta = 1

# Calculates k' = \max_{i \in V}\sum_{j\in N(i)}\omega_{i,j} (largest weighted degree).
degree_sums = scatter(graph.edge_attr, graph.edge_index[1], dim=0, dim_size=graph.x.shape[0], reduce="add")
k_prime = degree_sums.max().item()

# Step sizes derived from k' to keep the primal-dual iteration stable.
beta = 1 / (2 * k_prime)
tau = 1 / (8 * k_prime)

pd_tv2 = PD_TV2(graph, dim_size=graph.x.shape[0], noisy_sig=graph.x, tau=tau, lamb=lamb, beta=beta, theta=theta, dev=dev)
new_sig_tv2 = pd_tv2.run(itr, graph.x)
100%|██████████| 5000/5000 [00:02<00:00, 2249.51it/s]
# Displays the processed pointcld. Fix: `.cpu()` was called twice on the
# position tensor (`new_sig_tv2.cpu().cpu()`); one call suffices.
displaySur(position=new_sig_tv2.cpu().numpy(), texture=graph.tex.cpu().numpy())
class PD_L2DF(PD):
    """Primal-dual variant: non-squared L2 data-fidelity term, anisotropic TV dual."""

    def prox(self, U):
        """Primal proximal step for the non-squared L2 data-fidelity term.

        Shrinks U towards f0 row-wise: f0 + (1 - t / max(t, ||U - f0||_2)) * (U - f0)
        with t = tau / lamb. Rewritten from a one-liner that computed the row
        norm twice and rebuilt the threshold tensor on every call.
        """
        diff = U - self.f0
        thresh = self.tau / self.lamb
        # Row-wise L2 norm, floored at thresh (equivalent to the old torch.where),
        # so the scaling factor below stays in [0, 1).
        nrm = torch.clamp(torch.norm(diff, dim=1).view(-1, 1), min=thresh)
        return self.f0 + (1 - thresh / nrm) * diff

    def proj(self, T):
        """Projection onto the B_{inf,inf} ball: divide each entry by max(1, |entry|)."""
        return (T/torch.clamp(torch.abs(T),1))
# Hyper-parameters for the non-squared L2 data-fidelity run.
lamb = 0.15
itr = 5000
theta = 1

# Calculates k' = \max_{i \in V}\sum_{j\in N(i)}\omega_{i,j} (largest weighted degree).
per_node_weight = scatter(graph.edge_attr, graph.edge_index[1], dim=0, dim_size=graph.x.shape[0], reduce="add")
k_prime = per_node_weight.max().item()

# Step sizes derived from k' to keep the primal-dual iteration stable.
beta = 1 / (2 * k_prime)
tau = 1 / (8 * k_prime)

pd_l2df = PD_L2DF(graph, dim_size=graph.x.shape[0], noisy_sig=graph.x, tau=tau, lamb=lamb, beta=beta, theta=theta, dev=dev)
new_sig_l2df = pd_l2df.run(itr, graph.x)
100%|██████████| 5000/5000 [00:02<00:00, 1907.74it/s]
# Displays the processed pointcld for the non-squared L2 data-fidelity run.
displaySur(position=new_sig_l2df.cpu().numpy(), texture=graph.tex.cpu().numpy())
# You should see something like this, in the following order:
# Original pointcloud
# Case-1 (anisotropic TV)
# Case-2 (isotropic TV)
# Non-squared L2 data-fidelity term
from matplotlib import image as mpimg
from matplotlib import pyplot as plt
# Show the reference image of the expected results.
res = mpimg.imread("./data/out_shape.png")
f = plt.figure(figsize=(12, 12))
plt.imshow(res, cmap="gray")
plt.show()
def prox()
method to solve for the non-squared L1 data-fidelity term in the following functional: