# import the required libs
import numpy as np
import faiss
import torch
from utilities import *
from torch_geometric.data import Data
from torch_geometric.utils import is_undirected
# Load the point cloud / mesh via utilities.getPositionTexture, which returns
# (positions, texture, faces). Keep exactly one of the two lines active.
pos, tex, fac = getPositionTexture("data/noisy_girl_skate.ply")  # Color Processing.
#pos, tex, fac = getPositionTexture("data/noisy_3d_signal.ply")  # Shape Processing.
# Min-max normalise every coordinate into [0, 1] using ONE global min/max, so all
# axes are scaled by the same factor and the shape's proportions are preserved.
pos_min, pos_max = np.min(pos), np.max(pos)
pos = (pos - pos_min) / (pos_max - pos_min)
displaySur(position=pos, texture=tex)  # Open3d visualization works only locally!
# (notebook output) [Open3D WARNING] geometry::TriangleMesh appears to be a geometry::PointCloud (only contains vertices, but no triangles).
"""
- Create a knn graph using the position coords of the pointcloud.
- Assign scalar weigths to the edges.
"""
# Faiss graph construction: exact (brute-force) L2 kNN search on GPU device 0.
k = 8  # neighbours kept per point
pos_f32 = pos.astype(np.float32)  # faiss works on float32 arrays; convert once
# NOTE: keep `res` referenced for as long as the GPU index is in use; the GPU
# index is created from these resources.
res = faiss.StandardGpuResources()
index = faiss.IndexFlatL2(pos_f32.shape[1])
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index)
gpu_index_flat.add(pos_f32)
# Request k + 1 hits: the nearest hit for each query is normally the query itself.
D, I = gpu_index_flat.search(pos_f32, k + 1)
#Convert to torch_geometric Data class
# Row i of I lists the k+1 nearest point ids to point i, nearest first. The
# query normally comes back as I[i, 0], but duplicate points (and possibly
# float32 rounding in the GPU distance computation) can place another point
# first, so the centre of row i is taken to be i itself rather than I[i, 0].
if (I < 0).any():
    # faiss pads results with -1 when fewer than k+1 points are indexed;
    # indexing with -1 would silently pick the last point.
    raise ValueError("kNN search returned -1 ids; use a smaller k")
num_points = I.shape[0]
centers = np.arange(num_points)
# Drop exactly one candidate per row: the query itself, or the farthest
# candidate when the query was not among its own k+1 hits.
drop = I == centers[:, None]
drop[~drop.any(axis=1), -1] = True
neighbors = I[~drop].reshape(num_points, k)
# Edges run neighbour -> centre: row 0 = source, row 1 = target.
edge_index = np.vstack((neighbors.ravel(), np.repeat(centers, k)))
### Shape Processing ###
#edge_attr = np.ones(edge_index.shape[1]) # Lets keep the weights equal to 1 !
### Shape Processing ###
### Color Processing ###
# RBF kernel on the texture difference between each centre and its neighbours:
#   w = exp(-||tex_neighbour - tex_centre||^2 / sigma^2)
# NOTE(review): assumes tex is floating point; an unsigned-integer texture would
# wrap around on subtraction - TODO confirm.
rbf_sigma = 0.2
sq_dist = np.sum((tex[neighbors] - tex[centers][:, None]) ** 2, axis=2)
edge_attr = np.exp(-sq_dist / rbf_sigma**2).ravel()
### Color Processing ###
edge_index = torch.from_numpy(edge_index).type(torch.long)  # PyG edge indices must be int64 (torch.long)
edge_attr = torch.from_numpy(edge_attr).type(torch.float32)
edge_attr = edge_attr.view(-1, 1)  # one scalar weight per edge: shape (E, 1)
#getWgStats(edge_attr.numpy())
# Check whether the directed kNN graph is already symmetric (the kNN relation
# generally is not: j being a neighbour of i does not imply the converse).
print(is_undirected(edge_index))
# Symmetrise the weighted adjacency as W_sym = (W + W^T) / 2.
# Mirroring every edge and coalescing sums duplicate (row, col) pairs, which
# builds W + W^T. After the division below, mutual edges get (w_ij + w_ji) / 2
# and one-sided edges get w / 2.
# This is done with a plain torch sparse tensor rather than ToSparseTensor():
# since PyG 2.0 that transform reads its values from `edge_weight` by default,
# so `edge_attr` would be silently dropped and every weight would become 1.
num_nodes = len(pos)
sym_index = torch.cat((edge_index, edge_index.flip(0)), dim=1)
sym_value = torch.cat((edge_attr, edge_attr), dim=0)
# Hybrid COO tensor: 2 sparse dims (nodes x nodes) plus the dense edge-feature dims.
nG = torch.sparse_coo_tensor(
    sym_index, sym_value, (num_nodes, num_nodes) + tuple(sym_value.shape[1:])
).coalesce()
# coalesce() sorts entries by (row, col). Swapping the two index rows yields an
# edge list ordered by target node, matching the previous adj_t-based ordering.
new_edge_index = nG.indices().flip(0)
new_edge_attr = nG.values()
# Create a new graph; dividing by 2 completes the (W + W^T) / 2 average.
graph = Data(edge_index=new_edge_index.type(torch.long), edge_attr=new_edge_attr / 2,
             x=torch.from_numpy(pos).type(torch.float32), tex=torch.from_numpy(tex).type(torch.float32))
print(is_undirected(graph.edge_index))
# (notebook output) False  <- before symmetrization;  True  <- after symmetrization
# Serialise the symmetrised graph with torch.save. Keep exactly one path
# active, matching the dataset loaded at the top of the script.
out_path = "./data/girl_skate.pt"  # Color Processing
#out_path = "./data/3d_signal.pt"  # Shape Processing
torch.save(graph, out_path)