"""Particle interaction flow on [0, 1]^d: by hand, with PyTorch autograd, and with KeOps.

Exported notebook: sections run top to bottom. Plots and the ipywidgets frame
sliders are meant to be viewed inside Jupyter/Colab.
"""
import subprocess
import sys
from time import time

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import progressbar
import torch
import torch.nn as nn  # noqa: F401  (unused; kept from the original notebook)
import torch.nn.functional as F  # noqa: F401
import torch.optim as optim  # noqa: F401

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# ---------------------------------------------------------------------------
# Two random point clouds in the unit cube.
n = 10000  # number of points
m = 10100  # number of points on the second cloud
d = 2  # dimension
X = torch.rand(n, d)
Y = torch.rand(m, d)
print(X.is_cuda)
X = X.to(device)  # put it on gpu (when available)
Y = Y.to(device)  # put it on gpu (when available)
print(X.is_cuda)

# ---------------------------------------------------------------------------
# Boundary conditions. Toggle by commenting out one of the two lines below.
boundary = 'no'  # no boundary condition
boundary = 'per'  # periodic
if boundary not in ('no', 'per'):
    # Both the hand-written code and the KeOps code below branch on this
    # value; any other string would make them silently disagree.
    raise ValueError(f"boundary must be 'no' or 'per', got {boundary!r}")

if boundary == 'no':
    print('No boundary.')

    def bc_pos(X):
        """Return positions unchanged (free space)."""
        return X

    def bc_diff(D):
        """Return differences unchanged (free space)."""
        return D
else:
    print('Periodic boundary.')

    def bc_pos(X):
        """Wrap positions back into [0, 1)."""
        return torch.remainder(X, 1.0)

    def bc_diff(D):
        """Wrap differences into [-0.5, 0.5), i.e. to the nearest periodic image."""
        return torch.remainder(D - .5, 1.0) - .5


# Visualize the two boundary maps on a 1-D grid.
t = torch.tensor(np.linspace(-1.5, 1.5, 1000))
plt.figure()  # separate figure so plots don't overlay when run as a plain script
plt.plot(t, bc_pos(t))
plt.plot(t, bc_diff(t), '--')


def distmat_square(X, Y):
    """Return the matrix of squared distances |bc_diff(x_i - y_j)|^2.

    For X of shape (N, d) and Y of shape (M, d) the result is (N, M). An
    (N, M, d) tensor of pairwise differences is materialized on the way.
    """
    return torch.sum(bc_diff(X[:, None, :] - Y[None, :, :]) ** 2, axis=2)


plt.figure()
plt.imshow(distmat_square(t[:, None], t[:, None]))


def distmat_square2(X, Y):
    """Return squared Euclidean distances via |x|^2 + |y|^2 - 2 <x, y>.

    This avoids the (N, M, d) intermediate of distmat_square but does NOT
    apply bc_diff, so it only matches distmat_square when boundary == 'no'.
    The result is clamped at 0 because the expansion can cancel to tiny
    negative values in floating point.
    """
    X_sq = (X ** 2).sum(axis=-1)
    Y_sq = (Y ** 2).sum(axis=-1)
    cross_term = X.matmul(Y.T)
    return (X_sq[:, None] + Y_sq[None, :] - 2 * cross_term).clamp_min(0)


def _elapsed(fn, *args):
    """Return the wall-clock seconds spent in fn(*args), including queued GPU work."""
    if device.type == 'cuda':
        torch.cuda.synchronize()  # don't charge previously queued kernels to fn
    t0 = time()
    fn(*args)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # CUDA launches are async; wait for fn's kernels
    return time() - t0


print(_elapsed(distmat_square, X, Y))
print(_elapsed(distmat_square2, X, Y))


def kernel(X, Y):
    """Negative-distance kernel matrix k(x_i, y_j) = -|bc_diff(x_i - y_j)|."""
    return -torch.sqrt(distmat_square(X, Y))


def MMD(X, Y):
    """Return the V-statistic MMD^2 between point clouds X and Y for `kernel`.

    With k(x, y) = -|x - y| this is the energy distance. Returns a Python float.
    """
    n = X.shape[0]
    m = Y.shape[0]
    a = (torch.sum(kernel(X, X)) / n**2
         + torch.sum(kernel(Y, Y)) / m**2
         - 2 * torch.sum(kernel(X, Y)) / (n * m))
    return a.item()


print(MMD(X, X))  # should be 0
print(MMD(X, Y))  # should be >0

# ---------------------------------------------------------------------------
# Gaussian interaction flow, gradient computed by hand.
sigma = .1  # width of the Gaussian interaction


def psi(r):
    """Gaussian profile exp(-r / (2 sigma^2)), applied to a squared distance r."""
    return torch.exp(-r / (2 * sigma**2))


def Speed(X):
    """Return the hand-computed gradient of L(X) = -1/n sum_{i,j} psi(r_ij).

    Here r_ij = |bc_diff(x_i - x_j)|^2. Row i of the (n, d) result is
    2 / (n sigma^2) * sum_j psi(r_ij) * bc_diff(x_i - x_j).
    """
    weights = psi(distmat_square(X, X))[:, :, None]
    diffs = bc_diff(X[:, None, :] - X[None, :, :])
    return 2 / X.shape[0] * 1 / sigma**2 * torch.sum(weights * diffs, axis=1)


tau = 1/500 if boundary == 'no' else 1/200  # time step
niter = 200
save_per = 10  # periodicity of saving

Zsvg = torch.zeros((n, d, niter // save_per))  # stores the saved intermediate states
Z = X
for it in progressbar.progressbar(range(niter)):
    if it % save_per == 0:
        Zsvg[:, :, it // save_per] = Z.detach().cpu()  # for later display
    Z = bc_pos(Z - tau * Speed(Z))


def display_frame(t=0):
    """Scatter saved frame t of the global Zsvg, colored from blue (first) to red (last)."""
    nframes = Zsvg.shape[2]
    s = t / max(nframes - 1, 1)  # guard: a single saved frame would divide by 0
    plt.scatter(Zsvg[:, 0, t], Zsvg[:, 1, t], color=[s, 0, 1 - s])
    plt.axis('equal')
    plt.axis([0, 1, 0, 1])


widgets.interact(display_frame, t=(0, niter // save_per - 1))


def _rel_diff_to_speed(g, X):
    """Return |g - Speed(X)| / |g| without recording an autograd graph."""
    with torch.no_grad():
        return torch.norm(g - Speed(X)).item() / torch.norm(g).item()


# ---------------------------------------------------------------------------
# Check Speed against PyTorch autograd.
X.requires_grad = True
L = -1 / X.shape[0] * torch.sum(psi(distmat_square(X, X)), axis=(0, 1))
[g] = torch.autograd.grad(L, [X])
# compare with the "by hand" computation
print('Difference "hand" vs. "pytorch" : ' + str(_rel_diff_to_speed(g, X)))

# ---------------------------------------------------------------------------
# KeOps. Installs into the running interpreter; this replaces the
# notebook-only `!pip install pykeops[colab] > install.log`, which is not
# valid Python.
with open('install.log', 'w') as install_log:
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'pykeops[colab]'],
                   stdout=install_log, check=True)

import pykeops  # noqa: E402  (must come after the install above)
import pykeops.torch as keops  # noqa: E402

pykeops.clean_pykeops()  # just in case old build files are still present
pykeops.test_torch_bindings()  # perform the compilation


def keops_energy(Z):
    """Return sum_{i,j} exp(-|bc_diff(z_i - z_j)|^2 / (2 sigma^2)) via KeOps.

    Same quantity as torch.sum(psi(distmat_square(Z, Z))). It is evaluated
    with KeOps LazyTensor reductions instead of a dense torch (n, n) matrix.
    """
    D = keops.Vi(Z) - keops.Vj(Z)
    # Periodic wrap to [-0.5, 0.5), mirroring bc_diff.
    D1 = (D - .5).mod(1.0) - .5 if boundary == 'per' else D
    D2 = (D1 ** 2).sum(dim=2)
    K = (-D2 / (2 * sigma**2)).exp()
    # NOTE(review): `** 1` is numerically a no-op, but the original author
    # reported a bug without it ("There is a bug, I needed to add **1 here").
    # Presumably it works around KeOps' backward receiving the expanded,
    # non-contiguous gradient produced by `.sum()` -- confirm before removing.
    return (K.sum(dim=1) ** 1).sum()


X.requires_grad = True
L = -1 / X.shape[0] * keops_energy(X)
[g] = torch.autograd.grad(L, [X])
# compare with the "by hand" computation
print('Difference "hand" vs. "keops+pytorch" : ' + str(_rel_diff_to_speed(g, X)))

Zsvg = torch.zeros((n, d, niter // save_per))  # stores the saved intermediate states
Z = X.detach().clone().requires_grad_(True)
for it in progressbar.progressbar(range(niter)):
    if it % save_per == 0:
        Zsvg[:, :, it // save_per] = Z.detach().cpu()  # for later display
    L = 1 / Z.shape[0] * keops_energy(Z)
    [g] = torch.autograd.grad(L, [Z])
    # Ascent on +energy/n equals descent on the L that Speed differentiates.
    # Take the step outside autograd and restart from a fresh leaf; otherwise
    # every update is recorded and the graph keeps chaining back to X.
    with torch.no_grad():
        Z = bc_pos(Z + tau * g)
    Z.requires_grad_(True)

widgets.interact(display_frame, t=(0, niter // save_per - 1))