Important: Please read the installation page for details about how to install the toolboxes. $\newcommand{\dotp}[2]{\langle #1, #2 \rangle}$ $\newcommand{\enscond}[2]{\lbrace #1, #2 \rbrace}$ $\newcommand{\pd}[2]{ \frac{ \partial #1}{\partial #2} }$ $\newcommand{\umin}[1]{\underset{#1}{\min}\;}$ $\newcommand{\umax}[1]{\underset{#1}{\max}\;}$ $\newcommand{\umin}[1]{\underset{#1}{\min}\;}$ $\newcommand{\uargmin}[1]{\underset{#1}{argmin}\;}$ $\newcommand{\norm}[1]{\|#1\|}$ $\newcommand{\abs}[1]{\left|#1\right|}$ $\newcommand{\choice}[1]{ \left\{ \begin{array}{l} #1 \end{array} \right. }$ $\newcommand{\pa}[1]{\left(#1\right)}$ $\newcommand{\diag}[1]{{diag}\left( #1 \right)}$ $\newcommand{\qandq}{\quad\text{and}\quad}$ $\newcommand{\qwhereq}{\quad\text{where}\quad}$ $\newcommand{\qifq}{ \quad \text{if} \quad }$ $\newcommand{\qarrq}{ \quad \Longrightarrow \quad }$ $\newcommand{\ZZ}{\mathbb{Z}}$ $\newcommand{\CC}{\mathbb{C}}$ $\newcommand{\RR}{\mathbb{R}}$ $\newcommand{\EE}{\mathbb{E}}$ $\newcommand{\Zz}{\mathcal{Z}}$ $\newcommand{\Ww}{\mathcal{W}}$ $\newcommand{\Vv}{\mathcal{V}}$ $\newcommand{\Nn}{\mathcal{N}}$ $\newcommand{\NN}{\mathcal{N}}$ $\newcommand{\Hh}{\mathcal{H}}$ $\newcommand{\Bb}{\mathcal{B}}$ $\newcommand{\Ee}{\mathcal{E}}$ $\newcommand{\Cc}{\mathcal{C}}$ $\newcommand{\Gg}{\mathcal{G}}$ $\newcommand{\Ss}{\mathcal{S}}$ $\newcommand{\Pp}{\mathcal{P}}$ $\newcommand{\Ff}{\mathcal{F}}$ $\newcommand{\Xx}{\mathcal{X}}$ $\newcommand{\Mm}{\mathcal{M}}$ $\newcommand{\Ii}{\mathcal{I}}$ $\newcommand{\Dd}{\mathcal{D}}$ $\newcommand{\Ll}{\mathcal{L}}$ $\newcommand{\Tt}{\mathcal{T}}$ $\newcommand{\si}{\sigma}$ $\newcommand{\al}{\alpha}$ $\newcommand{\la}{\lambda}$ $\newcommand{\ga}{\gamma}$ $\newcommand{\Ga}{\Gamma}$ $\newcommand{\La}{\Lambda}$ $\newcommand{\si}{\sigma}$ $\newcommand{\Si}{\Sigma}$ $\newcommand{\be}{\beta}$ $\newcommand{\de}{\delta}$ $\newcommand{\De}{\Delta}$ $\newcommand{\phi}{\varphi}$ $\newcommand{\th}{\theta}$ $\newcommand{\om}{\omega}$ $\newcommand{\Om}{\Omega}$ $\newcommand{\eqdef}{\equiv}$
This numerical tour exposes the general methodology of regularizing the optimal transport (OT) linear program using entropy. This allows one to derive fast computational algorithms based on iterative projections according to a Kullback-Leibler divergence. $$ \DeclareMathOperator{\KL}{KL} \newcommand{\KLdiv}[2]{\KL\pa{#1 | #2}} \newcommand{\KLproj}{\text{Proj}^{\tiny\KL}} \renewcommand{\epsilon}{\varepsilon} \def\ones{\mathbb{I}} $$
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
import scipy as scp
import pylab as pyl
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload 2
We consider two input histograms $a \in \Si_n$ and $b \in \Si_m$, where we denote by $$ \Si_n \eqdef \enscond{ a \in \RR_+^n }{ \sum_i a_i = 1 } $$ the simplex in $\RR^n$. We consider the following discrete regularized transport $$ W_\epsilon(a,b) \eqdef \umin{P \in U(a,b)} \dotp{C}{P} - \epsilon E(P), $$ where the polytope of couplings is defined as $$ U(a,b) \eqdef \enscond{P \in (\RR^+)^{n \times m}}{ P \ones_m = a, P^\top \ones_n = b }, $$ where $\ones_n \eqdef (1,\ldots,1)^\top \in \RR^n $, and where the entropy of $P \in \RR_+^{n \times m}$ is defined as $$ E(P) \eqdef -\sum_{i,j} P_{i,j} ( \log(P_{i,j}) - 1). $$
When $\epsilon=0$ one recovers the classical (discrete) optimal transport problem. We refer to the monograph of Villani for more details about OT. The idea of regularizing transport to allow for faster computation was introduced by Cuturi.
Here the matrix $C \in (\RR^+)^{n \times m} $ defines the ground cost, i.e. $C_{i,j}$ is the cost of moving mass from a bin indexed by $i$ to a bin indexed by $j$.
The regularized transportation problem can be rewritten as the projection $$ W_\epsilon(a,b) = \epsilon \umin{P \in U(a,b)} \KLdiv{P}{K} \qwhereq K_{i,j} \eqdef e^{ -\frac{C_{i,j}}{\epsilon} } $$ of the Gibbs kernel $K$ onto the constraint set $U(a,b)$, according to the Kullback-Leibler divergence. The Kullback-Leibler divergence between $P, K \in \RR_+^{n \times m}$ is $$ \KLdiv{P}{K} \eqdef \sum_{i,j} P_{i,j} \pa{ \log\pa{ \frac{P_{i,j}}{K_{i,j}} } - 1}. $$
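As a side note (not part of the original tour), this divergence is simple to evaluate numerically. The small helper below assumes $P$ and $K$ are nonnegative NumPy arrays of the same shape with $K>0$, and uses the convention $0 \log 0 = 0$.
def kl_div(P, K):
    # KL(P|K) = sum_ij P_ij*(log(P_ij/K_ij) - 1), with the convention 0*log(0) = 0
    mask = P > 0
    return np.sum(P[mask] * (np.log(P[mask] / K[mask]) - 1))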
This interpretation of regularized transport as a KL projection and its numerical applications are detailed in BenamouEtAl.
Given a convex set $\Cc \subset \RR^N$, the projection according to the Kullback-Leibler divergence is defined as $$ \KLproj_\Cc(\xi) = \uargmin{ \pi \in \Cc } \KLdiv{\pi}{\xi}. $$
Given affine constraint sets $ (\Cc_1,\Cc_2) $, we aim at computing $$ \KLproj_\Cc(K) \qwhereq \Cc = \Cc_1 \cap \Cc_2 $$ (this description can of course be extended to more than 2 sets).
This can be achieved, starting from $P_0=K$, by iterating, for all $\ell \geq 0$, $$ P_{2\ell+1} = \KLproj_{\Cc_1}(P_{2\ell}) \qandq P_{2\ell+2} = \KLproj_{\Cc_2}(P_{2\ell+1}). $$
One can indeed show that $P_\ell \rightarrow \KLproj_\Cc(K)$. We refer to BauschkeLewis for more details about this algorithm and its extension to compute the projection onto the intersection of general convex sets (Dykstra's algorithm).
A fundamental remark is that the optimality condition of the entropic regularized problem shows that the optimal coupling $P_\epsilon$ necessarily has the form $$P_\epsilon = \diag{u} K \diag{v}$$ where the Gibbs kernel is defined as $$K \eqdef e^{-\frac{C}{\epsilon}}.$$
One thus needs to find two positive scaling vectors $u \in \RR_+^n$ and $v \in \RR_+^m$ such that the following two equalities hold $$P \ones = u \odot (K v) = a \qandq P^\top \ones = v \odot (K^\top u) = b.$$
Sinkhorn's algorithm alternates between the resolution of these two equations, and reads $$u \longleftarrow \frac{a}{K v} \qandq v \longleftarrow \frac{b}{K^\top u}.$$ This algorithm was shown by Sinkhorn to converge to a solution of the entropic regularized problem.
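As an illustration, here is a minimal NumPy sketch of this alternating scheme (assuming the kernel $K$ and the histograms $a,b$ are already defined, as they will be below); the exercises later in this tour ask you to implement and instrument this loop yourself.
def sinkhorn(K, a, b, niter=1000):
    # alternately enforce the row and column marginal constraints
    v = np.ones(len(b))
    for _ in range(niter):
        u = a / np.dot(K, v)
        v = b / np.dot(K.T, u)
    # the coupling is then recovered as diag(u) K diag(v)
    return u, v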
We first test the method for two input measures that are uniform measures (i.e. constant histograms) supported on two point clouds (that do not necessarily have the same size).
We thus first load two point clouds $x=(x_i)_{i=1}^{n}, y=(y_j)_{j=1}^{m}$, where $x_i, y_j \in \RR^2$.
Number of points in each cloud, $N=(n,m)$.
N = [300,200]
Dimension of the clouds.
d = 2
Point cloud $x$, of $n$ points inside a square.
x = np.random.rand(2,N[0])-.5
Point cloud $y$, of $m$ points inside an annulus.
theta = 2*np.pi*np.random.rand(1,N[1])
r = .8 + .2*np.random.rand(1,N[1])
y = np.vstack((np.cos(theta)*r,np.sin(theta)*r))
Shortcut for displaying point clouds.
plotp = lambda x,col: plt.scatter(x[0,:], x[1,:], s=200, edgecolors="k", c=col, linewidths=2)
Display of the two clouds.
plt.figure(figsize=(10,10))
plotp(x, 'b')
plotp(y, 'r')
plt.axis("off")
plt.xlim(np.min(y[0,:])-.1,np.max(y[0,:])+.1)
plt.ylim(np.min(y[1,:])-.1,np.max(y[1,:])+.1)
plt.show()
Cost matrix $C_{i,j} = \norm{x_i-y_j}^2$.
x2 = np.sum(x**2,0)
y2 = np.sum(y**2,0)
C = np.tile(y2,(N[0],1)) + np.tile(x2[:,np.newaxis],(1,N[1])) - 2*np.dot(np.transpose(x),y)
Target histograms $(a,b)$, here uniform histograms.
a = np.ones(N[0])/N[0]
b = np.ones(N[1])/N[1]
Regularization strength $\epsilon>0$.
epsilon = .01;
Gibbs Kernel $K$.
K = np.exp(-C/epsilon)
Initialization of $v=\ones_{m}$ ($u$ does not need to be initialized).
v = np.ones(N[1])
One Sinkhorn iteration.
u = a / (np.dot(K,v))
v = b / (np.dot(np.transpose(K),u))
Exercise 1
Implement the Sinkhorn algorithm. Display the evolution of the constraint satisfaction errors $$ \norm{ P \ones - a }_1 \qandq \norm{ P^\top \ones - b }_1 $$ (you need to think about how to compute these residuals from $(u,v)$ alone). Display the constraint violation errors on a log scale.
run -i nt_solutions/optimaltransp_5_entropic/exo1
Compute the final matrix $P$.
P = np.dot(np.dot(np.diag(u),K),np.diag(v))
Display it.
plt.imshow(P);
Exercise 2
Display the regularized transport solution for various values of $\epsilon$. For too small a value of $\epsilon$, what do you observe?
run -i nt_solutions/optimaltransp_5_entropic/exo2
Compute the obtained optimal $P$.
P = np.dot(np.dot(np.diag(u),K),np.diag(v))
Keep only the largest entries of the coupling matrix, and use them to draw a map between the two clouds. First we draw the "strong" connections, i.e. the links $(i,j)$ corresponding to large values of $P_{i,j}$. We then draw the weaker connections.
plt.figure(figsize=(10,10))
plotp(x, 'b')
plotp(y, 'r')
A = P * (P > np.max(P)*.8)
i,j = np.where(A != 0)
plt.plot([x[0,i],y[0,j]],[x[1,i],y[1,j]],'k',lw = 2)
A = P * (P > np.max(P)*.2)
i,j = np.where(A != 0)
plt.plot([x[0,i],y[0,j]],[x[1,i],y[1,j]],'k:',lw = 1)
plt.axis("off")
plt.xlim(np.min(y[0,:])-.1,np.max(y[0,:])+.1)
plt.ylim(np.min(y[1,:])-.1,np.max(y[1,:])+.1)
plt.show()
We now consider a different setup, where the histogram values $a,b$ are not uniform, but the measures are defined on a uniform grid $x_i=y_i=i/n$. They are thus often referred to as "histograms".
Size $n$ of the histograms.
N = 200
We use here a 1-D squared Euclidean metric.
t = np.arange(0,N)/N
Define the histogram $a,b$ as translated Gaussians.
Gaussian = lambda t0,sigma: np.exp(-(t-t0)**2/(2*sigma**2))
normalize = lambda p: p/np.sum(p)
sigma = .06;
a = Gaussian(.25,sigma)
b = Gaussian(.8,sigma)
Add some minimal mass and normalize.
vmin = .02;
a = normalize( a+np.max(a)*vmin)
b = normalize( b+np.max(b)*vmin)
Display the histograms.
plt.figure(figsize = (10,7))
plt.subplot(2, 1, 1)
plt.bar(t, a, width = 1/len(t), color = "darkblue")
plt.subplot(2, 1, 2)
plt.bar(t, b, width = 1/len(t), color = "darkblue")
plt.show()
Regularization strength $\epsilon$.
epsilon = (.03)**2
The Gibbs kernel is a Gaussian convolution, $$ K_{i,j} \eqdef e^{ -(i/N-j/N)^2/\epsilon }. $$
[Y,X] = np.meshgrid(t,t)
K = np.exp(-(X-Y)**2/epsilon)
Initialization of $v=\ones_{N}$.
v = np.ones(N)
One Sinkhorn iteration.
u = a / (np.dot(K,v))
v = b / (np.dot(np.transpose(K),u))
Exercise 3
Implement the Sinkhorn algorithm. Display the evolution of the constraint satisfaction errors $ \norm{ P \ones - a }_1, \norm{ P^\top \ones - b }_1$ (you need to think about how to compute them from $(u,v)$ alone). Display the constraint violation errors on a log scale.
run -i nt_solutions/optimaltransp_5_entropic/exo3
Display the coupling. Use a log-domain plot to better visualize it.
P = np.dot(np.dot(np.diag(u),K),np.diag(v))
plt.figure(figsize=(5,5))
plt.imshow(np.log(P+1e-5))
plt.axis('off');
One can compute an approximation of the transport map between the two measures by computing the so-called barycentric projection map $$ t_i \in [0,1] \longmapsto s_i \eqdef \frac{\sum_{j} P_{i,j} t_j }{ \sum_{j} P_{i,j} } = \frac{ [u \odot K(v \odot t)]_i }{ a_i }, $$ where $\odot$ and $\frac{\cdot}{\cdot}$ denote the entry-wise multiplication and division.
This computation can thus be done using only multiplication with the kernel $K$.
s = np.dot(K,v*t)*u/a
Display the transport map, super-imposed over the coupling.
plt.figure(figsize=(5,5))
plt.imshow(np.log(P+1e-5))
plt.plot(s*N,t*N, 'r', linewidth=3);
plt.axis('off');
Exercise (bonus)
Try different regularization strength $\epsilon$.
We will now use PyTorch to implement Sinkhorn on the GPU. If you are running the code on Google Colab, this means you need to enable GPU acceleration in the runtime settings.
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print( device )
Since CUDA uses 32-bit floating point numbers, one needs to use a fairly large value of $\epsilon$: otherwise the kernel entries underflow to zero and the Sinkhorn divisions blow up.
epsilon = (.06)**2
K = np.exp(-(X-Y)**2/epsilon)
Convert the Sinkhorn variables and host them on the GPU (if available).
u = torch.ones(N);
v = torch.ones(N);
K1 = torch.from_numpy(K).type(torch.FloatTensor);
a1 = torch.from_numpy(a).type(torch.FloatTensor);
b1 = torch.from_numpy(b).type(torch.FloatTensor);
# send them to the GPU
K1 = K1.to(device);
u = u.to(device);
v = v.to(device);
a1 = a1.to(device);
b1 = b1.to(device);
When using PyTorch, it is convenient to implement matrix-vector products using broadcasting and a summation over a dummy axis. We show here how to implement one Sinkhorn iteration this way.
u = a1 / (K1 * v[None,:]).sum(1)
v = b1 / (K1 * u[:,None]).sum(0)
Exercise:
Implement the full algorithm.
v = torch.ones(N, device=device)
niter = 2000
Err_p = torch.zeros(niter)
Err_q = torch.zeros(niter)
for i in range(niter):
    # sinkhorn step 1
    u = a1 / (K1 * v[None,:]).sum(1)
    # error on the second marginal, computed from (u,v) alone
    r = v*(K1 * u[:,None]).sum(0)
    Err_q[i] = torch.norm(r - b1, p=1).item()
    # sinkhorn step 2
    v = b1 / (K1 * u[:,None]).sum(0)
    # error on the first marginal
    s = u*(K1 * v[None,:]).sum(1)
    Err_p[i] = torch.norm(s - a1, p=1).item()
plt.figure(figsize = (10,7))
plt.subplot(2,1,1)
plt.title("$||P1 -a||_1$")
plt.plot(np.log(np.asarray(Err_p)), linewidth = 2)
plt.subplot(2,1,2)
plt.title("$||P^T 1 -b||_1$")
plt.plot(np.log(np.asarray(Err_q)), linewidth = 2)
plt.show()
Exercise: to avoid underflow, replace the matrix/vector multiplications by computations in the log domain, using the log-sum-exp stabilization trick.
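Here is a minimal sketch of such a log-domain iteration (one possible solution, not the only one). It introduces the rescaled variables $f = \epsilon \log(u)$ and $g = \epsilon \log(v)$, and assumes the grid $X,Y$, the histograms $a_1,b_1$, the value of $\epsilon$ and the iteration count defined above.
C1 = torch.from_numpy((X-Y)**2).type(torch.FloatTensor).to(device)
f = torch.zeros(N, device=device)  # f = epsilon*log(u)
g = torch.zeros(N, device=device)  # g = epsilon*log(v)
for it in range(niter):
    # u = a/(K v)  becomes  f = epsilon*(log(a) - logsumexp((g - C)/epsilon, dim=1))
    f = epsilon * (torch.log(a1) - torch.logsumexp((g[None,:] - C1)/epsilon, dim=1))
    # v = b/(K^T u)  becomes  g = epsilon*(log(b) - logsumexp((f - C)/epsilon, dim=0))
    g = epsilon * (torch.log(b1) - torch.logsumexp((f[:,None] - C1)/epsilon, dim=0))
# the coupling is recovered as P = exp((f + g^T - C)/epsilon)
P1 = torch.exp((f[:,None] + g[None,:] - C1)/epsilon)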
Instead of computing a single transport, we now turn to the problem of computing the barycenter of $R$ input measures $(a_k)_{k=1}^R$. A barycenter $b$ solves $$ \umin{b} \sum_{k=1}^R \la_k W_\epsilon(a_k,b), $$ where the $\la_k$ are positive weights with $\sum_k \la_k=1$. This follows the definition of barycenters proposed in AguehCarlier.
Dimension (width of the images) $N$ of the histograms.
N = 70
You need to install imageio, for instance using
conda install -c conda-forge imageio
If you need to rescale the image size, you can use
skimage.transform.resize
Load input histograms $(a_k)_{k=1}^R$, store them in a tensor $A$.
import imageio
rescale = lambda x: (x-x.min())/(x.max()-x.min())
names = ['disk','twodisks','letter-x','letter-z']
vmin = .01
A = np.zeros([N,N,len(names)])
for i in range(len(names)):
    a = imageio.imread("nt_toolbox/data/" + names[i] + ".bmp")
    a = normalize(rescale(a)+vmin)
    A[:,:,i] = a
R = len(names)
Display the input histograms.
plt.figure(figsize=(5,5))
for i in range(R):
    plt.subplot(2,2,i+1)
    plt.imshow(A[:,:,i])
    plt.axis('off');
In this specific case, the kernel $K$ associated with the squared Euclidean norm is a convolution with a Gaussian filter $$ K_{i,j} = e^{ -\norm{i/N-j/N}^2/\epsilon } $$ where here $(i,j)$ are 2-D indexes.
The multiplication against the kernel, i.e. $K(a)$, can now be computed efficiently, using fast convolution methods. This crucial point was exploited and generalized in SolomonEtAl to design fast optimal transport algorithms.
Regularization strength $\epsilon>0$.
epsilon = (.04)**2
Define the $K$ kernel. We use here the fact that the convolution is separable to implement it using only 1-D convolution, which further speeds up computations.
t = np.linspace(0,1,N)
[Y,X] = np.meshgrid(t,t)
K1 = np.exp(-(X-Y)**2/epsilon)
K = lambda x: np.dot(np.dot(K1,x),K1)
Display the application of the $K$ kernel on one of the input histogram.
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.imshow(A[:,:,0])
plt.title("$a$")
plt.axis('off');
plt.subplot(1,2,2)
plt.imshow(K(A[:,:,0]))
plt.title("$K(a)$")
plt.axis('off');
Weights $\la_k$ for isobarycenter.
lambd = np.ones(R)/R
It is shown in BenamouEtAl that the barycenter computation problem boils down to optimizing over couplings $(P_k)_{k=1}^R$, and that this can be achieved using an iterative Sinkhorn-like algorithm, since each optimal coupling has the scaling form $$P_k = \diag{u_k} K \diag{v_k}$$ for some unknown positive weights $(u_k,v_k)$.
Initialize the scaling factors $(u_k,v_k)_k$, store them in matrices.
v = np.ones([N,N,R])
u = np.copy(v)
The first step of the Bregman projection method corresponds to the projection onto the fixed marginal constraints $P_k \ones = a_k$. This is achieved by updating $$ \forall k=1,\ldots,R, \quad u_k \longleftarrow \frac{a_k}{ K( v_k ) }. $$
for k in range(R):
    u[:,:,k] = A[:,:,k]/K(v[:,:,k])
The second step of the Bregman projection method corresponds to the projection onto the equal marginal constraints $\forall k, P_k^\top \ones=b$ for a common barycenter target $b$. This is achieved by first computing the target barycenter $b$ as a geometric mean $$ \log(b) \eqdef \sum_k \lambda_k \log( v_{k} \odot K ( u_{k} ) ). $$
b = np.zeros([N,N])
for k in range(R):
    b = b + lambd[k] * np.log(np.maximum(1e-19, v[:,:,k]*K(u[:,:,k])))
b = np.exp(b)
Display $b$.
plt.imshow(b);
plt.axis('off');
And then one can update the scalings by a Sinkhorn step using this newly computed histogram $b$ as follows (note that $K=K^\top$ here): $$ \forall k=1,\ldots,R, \quad v_{k} \longleftarrow \frac{b}{ K(u_{k}) }. $$
for k in range(R):
    v[:,:,k] = b/K(u[:,:,k])
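Putting the two projection steps together gives the following compact sketch of the barycenter iterations (Exercise 4 below asks you to implement this yourself and to monitor the marginal errors, so treat this only as a guide):
def barycenter_sketch(A, K, lambd, niter=300):
    # A has shape (N, N, R); K is the separable kernel operator defined above
    N, _, R = A.shape
    v = np.ones([N, N, R])
    u = np.copy(v)
    for _ in range(niter):
        # projection onto the fixed marginals P_k 1 = a_k
        for k in range(R):
            u[:,:,k] = A[:,:,k] / K(v[:,:,k])
        # geometric mean defining the current barycenter estimate
        logb = np.zeros([N, N])
        for k in range(R):
            logb = logb + lambd[k] * np.log(np.maximum(1e-19, v[:,:,k]*K(u[:,:,k])))
        b = np.exp(logb)
        # projection onto the equal marginals P_k^T 1 = b
        for k in range(R):
            v[:,:,k] = b / K(u[:,:,k])
    return b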
Exercise 4
Implement the iterative algorithm to compute the iso-barycenter of the measures. Plot the decay of the error $\sum_k \norm{P_k \ones - a_k} $.
run -i nt_solutions/optimaltransp_5_entropic/exo4
Display the barycenter.
plt.imshow(b)
plt.axis('off');
Exercise 5
Compute barycenters for varying weights $\la$ corresponding to a bilinear interpolation inside a square.
run -i nt_solutions/optimaltransp_5_entropic/exo5