The goal of this practical session is to use sliced optimal transport in order to transfer the color distribution from one image to another. At the end of the notebook, regularization methods are also studied in order to reduce the artifacts that may be generated by the color transfer.
For some parts of the session (cells with commands written as todo_something...), you are supposed to code by yourself.
Authors:
Below is a list of packages required in the notebook:
numpy
matplotlib (display of images and graphics)
os (interactions with the operating system)
Axes3D (display of 3D graphics)
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
To import the solutions, execute the following cell. If you are using a Windows system, comment the os.system
line, download the file by hand, and place it in the same folder as the notebook.
#os.system("wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/ContrastAndColor/Solutions/TP_color_transfer.py")
#from TP_color_transfer import *
A color image is made of three channels: red, green and blue. A color image with $N\times M$ pixels is stored as an $N\times M\times 3$ matrix.
Be careful with the functions plt.imread() and plt.imshow() of matplotlib.
plt.imread() reads png images as numpy arrays of floating-point values between 0 and 1, but it reads jpg or bmp images as numpy arrays of 8-bit integers.
In this practical session, we assume images are encoded as floating point values between 0 and 1, so if you load a jpg or bmp file you must convert the image to float type and normalize its values to $[0,1]$.
If 'im' is an image encoded as a float numpy array, plt.imshow(im)
will do a linear scaling, mapping the lowest value to 0 (black) and the highest to 1 (white). If the 'im' image is 8-bit encoded, plt.imshow(im)
will display 0 in black and 255 in white.
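For instance, the conversion of an 8-bit jpg image to this convention can be sketched as follows (on a small synthetic array, so that the sketch does not depend on an actual file):

```python
import numpy as np

# simulate what plt.imread returns for a jpg: 8-bit integers in [0, 255]
im_uint8 = np.array([[[0, 128, 255]]], dtype=np.uint8)

# convert to floating-point values in [0, 1], the convention of this session
im_float = im_uint8.astype(float) / 255
```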
os.system("wget -c https://raw.githubusercontent.com/storimaging/Images/main/img/renoir.jpg")
os.system("wget -c https://raw.githubusercontent.com/storimaging/Images/main/img/gauguin.jpg")
imrgb1 = plt.imread("renoir.jpg")/255
imrgb2 = plt.imread("gauguin.jpg")/255
# useful if the image is a png with a transparency channel
imrgb1=imrgb1[:,:,0:3]
imrgb2=imrgb2[:,:,0:3]
#we display the images
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 7))
axes[0].imshow(imrgb1)
axes[0].set_title('First image')
axes[0].axis('off')
axes[1].imshow(imrgb2)
axes[1].set_title('Second image')
axes[1].axis('off')
fig.tight_layout()
In the following cell, we write a function affine_transfer which applies an affine transform to an image $u$ so that it has the same mean and the same standard deviation as $v$ on each channel.
def affine_transfer(u,v):
    w = np.zeros(u.shape)
    for i in range(0,3):
        w[:,:,i] = (u[:,:,i] - np.mean(u[:,:,i]))/np.std(u[:,:,i])*np.std(v[:,:,i]) + np.mean(v[:,:,i])
    return w
Display and comment the result.
w = affine_transfer(imrgb1,imrgb2)
w = np.clip(w, 0, 1) # w should be in [0,1]
plt.figure(figsize=(7, 7))
plt.title('result of separable affine color transfer')
plt.axis('off')
plt.imshow(w)
Another solution consists in applying histogram specification separately on each channel of the images $u$ and $v$.
Write a function todo_specification_separate_channels(imrgb1,imrgb2) which takes two images $u$ and $v$ as input, and applies color specification separately on each channel (see the practical session on radiometry).
w = todo_specification_separate_channels(imrgb1,imrgb2)
w = np.clip(w, 0, 1) # w should be in [0,1]
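One possible implementation of the per-channel specification is rank matching: a sketch, assuming both images have the same number of pixels (the name specification_separate_channels_sketch is ours, not that of the solutions file):

```python
import numpy as np

def specification_separate_channels_sketch(u, v):
    # per-channel histogram specification: give each channel of u the
    # empirical distribution of the corresponding channel of v, while
    # preserving the ordering of the pixel values of u
    # (assumes u and v have the same number of pixels)
    w = np.zeros(u.shape)
    for i in range(3):
        uc = u[:, :, i].ravel()
        vc = v[:, :, i].ravel()
        ranks = np.argsort(np.argsort(uc))   # rank of each pixel value of u
        w[:, :, i] = np.sort(vc)[ranks].reshape(u.shape[:2])
    return w
```

Each output channel then has exactly the values of the corresponding channel of $v$, arranged in the order of $u$.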
Display and comment the result.
plt.figure(figsize=(7, 7))
plt.title('result of color separable transfer')
plt.axis('off')
plt.imshow(w)
In order to transport the full color distribution of $u$ on the one of $v$, we propose to write these distributions as 3D point clouds
$$X = \begin{pmatrix} X_1 \\ \vdots\\ X_n \end{pmatrix}, \;\;\text{ and }\;\; Y = \begin{pmatrix} Y_1 \\ \vdots\\ Y_n \end{pmatrix}, $$with $n$ the number of pixels in $u$ and $v$, $X_i$ the color $(R_i,G_i,B_i)$ of the $i$-th pixel in $u$ and $Y_i$ the color of the $i$-th pixel in $v$ ($X$ and $Y$ are $n\times 3$ matrices). We look for the permutation $\sigma$ of $\{1,\dots, n\}$ which minimizes the quantity
$$ \sum_{i=1}^n \|X_i - Y_{\sigma(i)}\|^2_2. $$The assignment $\sigma$ defines a mapping between the set of colors $\{X_i\}_{i=1,\dots n}$ and the set of colors $\{Y_j\}_{j=1,\dots n}$.
In one dimension, the assignment $\sigma$ which minimizes the previous cost is the one which preserves the ordering of the points. More precisely, if $\sigma_X$ and $\sigma_Y$ are the permutations such that $$X_{\sigma_X(1)} \leq X_{\sigma_X(2)} \leq \dots \leq X_{\sigma_X(n)},$$ $$Y_{\sigma_Y(1)} \leq Y_{\sigma_Y(2)} \leq \dots \leq Y_{\sigma_Y(n)},$$ then $\sigma = \sigma_Y \circ \sigma_X^{-1}$ minimizes the previous cost.
The following function transport1D
takes two 1D point clouds $X$ and $Y$ and returns the permutations $\sigma_X$ and $\sigma_Y$.
def transport1D(X,Y):
    sx = np.argsort(X) # argsort returns the indices that would sort the array in increasing order
    sy = np.argsort(Y)
    return((sx,sy))
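The assignment $\sigma = \sigma_Y \circ \sigma_X^{-1}$ itself is easy to recover from these two permutations. The small check below (a sketch) compares its cost to a brute-force search over all permutations:

```python
import numpy as np
from itertools import permutations

rng = np.random.RandomState(0)
X = rng.rand(6)
Y = rng.rand(6) + 2

# the permutations returned by transport1D
sx, sy = np.argsort(X), np.argsort(Y)

# sigma = sigma_Y o sigma_X^{-1}: X_i is matched with Y_sigma(i)
sigma = np.empty(len(X), dtype=int)
sigma[sx] = sy

cost = np.sum((X - Y[sigma])**2)
# brute force over all 6! permutations for comparison
best = min(np.sum((X - Y[list(p)])**2) for p in permutations(range(6)))
```

The sorting assignment attains the brute-force optimum, as stated above.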
We can visualize this 1D transport on random point sets. Let's first create two point sets in 1D
n = 200
X = 2 * np.random.rand(n)
Y = np.concatenate((
np.random.rand(n//4)-2,
.25*np.random.rand(3*n//4)+3))
plt.figure(figsize=(12, 5))
plt.plot(X,np.zeros(n),'.b')
plt.plot(Y,np.zeros(n),'+r')
plt.show()
# evolution of Z
sx, sy = transport1D(X,Y)
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(15, 4))
Z = np.zeros(n)
k = 0
for i in range(4):
    for j in range(4):
        Z[sx] = X[sx] + (k/15) * (Y[sy]-X[sx])
        axes[i,j].plot(Z,np.zeros(n),'.')
        axes[i,j].axis([-3,4,-1,1])
        k += 1
fig.tight_layout()
# Another visualization of Z from X to Y:
plt.figure(figsize=(12, 5))
plt.plot(X,np.zeros(n),'.b')
plt.plot(Y,np.ones(n),'.r')
n_steps = 12
for k in range(1,n_steps):
    Z[sx] = X[sx] + (k/n_steps) * (Y[sy]-X[sx])
    plt.plot(Z,(k/n_steps)*np.ones(n),'.g')
plt.tight_layout()
Minimizing the quantity $$ \sum_{i=1}^n \|X_i - Y_{\sigma(i)}\|^2_2 $$ in more than one dimension is computationally demanding. Instead of computing the explicit solution, we make use of a stochastic gradient descent algorithm on a slightly modified problem, called sliced optimal transport (see the following paper).
The sliced distance between the point clouds $X$ and $Y$ can be written as $$SW_2^2 ( X , Y) = \frac 1 {|S^2|} \int_{u\in S^{2}} \sum_{i=1}^n \left|<u,X_{\sigma_X^u(i)}> - <u,Y_{\sigma_Y^u(i)}> \right|^2 du,$$ where the integration is done on the sphere $S^2$, $<u,X_i>$ is the scalar product between $X_i$ and the unit vector $u$ of $\mathbb{R}^3$, and where $\sigma_X^u$ and $\sigma_Y^u$ are the permutations ordering the one-dimensional projected point clouds $\{<X_j,u>\}_{j=1\dots n}$ and $\{<Y_j,u>\}_{j=1\dots n}$.
$Y$ being fixed, the idea is to apply a gradient descent to $X \mapsto SW_2^2 ( X , Y)$, in order to compute an assignment between $X$ and $Y$. We make use of a stochastic gradient descent algorithm, where the integral is replaced at each step by a single direction $u$ drawn uniformly on the sphere.
The algorithm starts from the point cloud $Z=X$. At each iteration, it picks a random unit vector $u$ of $\mathbb{R}^3$ and it computes the gradient step $$ Z\circ{\sigma_Z^u} \leftarrow Z\circ{\sigma_Z^u} + \varepsilon \left( <Y,u>\circ{\sigma_Y^u} - <Z,u>\circ{\sigma_Z^u} \right) u.$$
The permutations $\sigma_Z^u$ and $\sigma_Y^u$ at each iteration can be computed thanks to the function transport1D
defined above.
In practice, instead of drawing a single unit vector $u$ at each step, it is often more efficient to directly draw an orthonormal basis of $\mathbb{R}^3$, denoted by $(u_1,u_2,u_3)$, and to apply the previous step for each $u_i$, $i=1,\dots,3$.
Write a function todo_transport3D(X,Y,N,e) implementing the previous iterations. The function should take as input the point clouds $X$ and $Y$, a number of iterations $N$ and a step $\varepsilon$. The function should output a point cloud $Z$ and the permutations $\sigma_Z$ and $\sigma_Y$.
Hints:
to initiate the point cloud $Z$ at $X$, you can use
Z = np.copy(X)
to pick an orthonormal basis, you can draw a random $3\times 3$ matrix with i.i.d. $\mathcal{N}(0,1)$ entries and keep the orthogonal factor of its QR decomposition,
u=np.random.randn(3,3)
q=np.linalg.qr(u)[0]
the scalar product between two 3D vectors $x$ and $y$ can be computed thanks to
np.dot(x,y)
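A possible implementation following these hints (a sketch, not necessarily identical to the todo_transport3D of the solutions file; in particular, the permutations returned at the end are those of the last drawn direction, which pair the matched points once the cloud has converged):

```python
import numpy as np

def transport3D_sketch(X, Y, N, e, seed=0):
    # sliced optimal transport by stochastic gradient descent:
    # at each iteration, draw a random orthonormal basis of R^3 and,
    # for each basis vector u, move the sorted projections of Z toward
    # the sorted projections of Y along u (step size e)
    rng = np.random.RandomState(seed)
    Z = np.copy(X)
    for _ in range(N):
        q = np.linalg.qr(rng.randn(3, 3))[0]      # random orthonormal basis
        for i in range(3):
            u = q[:, i]
            pZ, pY = Z @ u, Y @ u                 # 1D projections
            sZ, sY = np.argsort(pZ), np.argsort(pY)
            # gradient step on the sorted projections
            Z[sZ, :] += e * (pY[sY] - pZ[sZ])[:, None] * u[None, :]
    # once Z is close to a permutation of Y, sorting both clouds along
    # one generic direction pairs each point with its match
    return Z, sZ, sY
```

With this convention, a call such as Z, sZ, sY = transport3D_sketch(X, Y, 1000, 0.1) mirrors the calls used below.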
Implement and test the previous code with random 3D point clouds.
X = np.vstack((
np.random.rand(50,3),
.25 * np.random.rand(150,3) + np.array([[0,2,2]])
))
Y = np.vstack((
np.random.rand(100,3) + np.array([[2,2,2]]),
.5 * np.random.rand(100,3) + np.array([[2,0,0]])
))
fig = plt.figure(figsize=(10, 10))
axis = fig.add_subplot(1, 1, 1, projection="3d")
axis.scatter(X[:,0],X[:,1],X[:,2],c='blue',s=80)
axis.scatter(Y[:,0],Y[:,1],Y[:,2],c='red',s=80)
plt.show()
Visualize the displacement of the point cloud X toward Y.
# correction
Z, sZ, sY = todo_transport3D(X,Y,1000,0.1)
fig = plt.figure(figsize=(16, 8))
Z = np.zeros(X.shape)
k = 0
for i in range(4):
    for j in range(2):
        Z[sZ,:] = X[sZ,:] + (k/7) * (Y[sY,:]-X[sZ,:])
        ax = fig.add_subplot(2, 4, k+1, projection="3d")
        ax.scatter(X[:,0],X[:,1],X[:,2],c='blue',s=80)
        ax.scatter(Y[:,0],Y[:,1],Y[:,2],c='red',s=80)
        ax.scatter(Z[:,0],Z[:,1],Z[:,2],c='black',s=30)
        k += 1
fig.tight_layout()
First, we subsample both images to reduce the number of points and we create the color point clouds from the subsampled color images.
The color transfer can take a very long time if we use images at full resolution! In this practical session, we illustrate the color transfer on full images, but for experimentation, it is advised to subsample both images by a factor $s$ large enough.
u = imrgb1
v = imrgb2
s = 1
# choose s=10 or larger in your session!!!
if s==1:
    usubsample = u
    vsubsample = v
else:
    usubsample = u[1::s,1::s,0:3]
    vsubsample = v[1::s,1::s,0:3]
X = usubsample.reshape((usubsample.shape[0]*usubsample.shape[1],3))
Y = vsubsample.reshape((vsubsample.shape[0]*vsubsample.shape[1],3))
We display (a subsample of) the corresponding point clouds.
nb = 3000
r = np.random.RandomState(42)
idX = r.randint(X.shape[0], size=(nb,))
idY = r.randint(Y.shape[0], size=(nb,))
Xs = X[idX, :]
Ys = Y[idY, :]
fig = plt.figure(2, figsize=(20, 10))
axis = fig.add_subplot(1, 2, 1, projection="3d")
axis.scatter(Xs[:, 0], Xs[:,1],Xs[:, 2], c=Xs,s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue")
axis = fig.add_subplot(1, 2, 2, projection="3d")
axis.scatter(Ys[:, 0], Ys[:,1],Ys[:, 2], c=Ys,s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue")
We apply the previous sliced optimal transport algorithm to compute the optimal assignment between the point clouds $X$ and $Y$. For each $i$, the 3D point $X[i,:]$ is displaced at the position $X_{new}[i,:]$. Since $X$ is merely a reshaping of the color values of $u$, we obtain the new image $u$ after color transfer by reshaping the array $X_{new}$.
Don't forget to subsample your images (use s = 10 in the cells above) before computing the color transfer! On full images, the following code may run for a few dozen minutes!
X_new, order_X_new, s_Y = todo_transport3D(X,Y,1000,0.2)
wsliced = X_new.reshape(usubsample.shape[0],usubsample.shape[1],3)
We can now display the result.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 7))
axes[0].imshow(usubsample)
axes[0].set_title('u')
axes[0].axis('off')
axes[1].imshow(wsliced)
axes[1].set_title('wsliced')
axes[1].axis('off')
axes[2].imshow(vsubsample)
axes[2].set_title('v')
axes[2].axis('off')
fig.tight_layout()
We display the corresponding color point clouds.
idZ = r.randint(X_new.shape[0], size=(nb,))
Zs = X_new[idZ, :]
Zs = np.clip(Zs, 0, 1) # to ensure that colors are in [0,1]^3
fig = plt.figure(2, figsize=(20, 10))
axis = fig.add_subplot(1, 2, 1, projection="3d")
axis.scatter(Ys[:, 0], Ys[:,1],Ys[:, 2], c=Ys,s=100)
axis.set_xlabel("Red")
axis.set_ylabel("Green")
axis.set_zlabel("Blue")
axis.set_title('Color distribution of the second image')
axis = fig.add_subplot(1, 2, 2, projection="3d")
axis.scatter(Zs[:, 0], Zs[:,1],Zs[:, 2], c=Zs,s=100)
axis.set_xlabel("Red")
axis.set_ylabel("Green")
axis.set_zlabel("Blue")
axis.set_title('Color distribution of the first image after color transfer')
Before starting this part, if you have worked with subsampled images in the previous part, you can load the results of the color transfer on full images.
usubsample = plt.imread('renoir.jpg')/255
#wsliced = plt.imread('renoir_by_gauguin.png')
wsliced = wsliced[:,:,0:3] # to remove transparency channel
A common drawback of classical methods aiming at color and contrast modifications is the appearance of artefacts (JPEG blocks, color inconsistencies, noise enhancement) or the attenuation of details and textures (see for instance the following web page). Let $u$ be an image and $g(u)$ the same image after color or contrast modification; we write $\mathcal{M}(u) = g(u) - u$. All artefacts observable in $g(u)$ can be seen as irregularities in this difference map $\mathcal{M}(u)$. In order to reduce these artefacts, we propose to filter the difference map with an operator $Y_u$ and to reconstruct the image:
$$T(g(u)) = u + Y_u(g(u)-u).$$
Let us begin with a very simple averaging filter for $Y_u$.
A simple averaging filter can be implemented with the function convolve2d
of scipy.signal
for instance. However, it is much more efficient to compute this average filter thanks to integral images, as follows.
def average_filter(u,r):
    # uniform filter with a square (2*r+1)x(2*r+1) window
    # u is a 2d image
    # r is the radius for the filter
    (nrow, ncol) = u.shape
    big_uint = np.zeros((nrow+2*r+1,ncol+2*r+1))
    big_uint[r+1:nrow+r+1,r+1:ncol+r+1] = u
    big_uint = np.cumsum(np.cumsum(big_uint,0),1) # integral image
    out = big_uint[2*r+1:nrow+2*r+1,2*r+1:ncol+2*r+1] + big_uint[0:nrow,0:ncol] - big_uint[0:nrow,2*r+1:ncol+2*r+1] - big_uint[2*r+1:nrow+2*r+1,0:ncol]
    out = out/(2*r+1)**2
    return out
diff = wsliced-usubsample
out = np.zeros_like(usubsample)
r = 10
for i in range(3):
    out[:,:,i] = average_filter(diff[:,:,i], r)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
axes[0].imshow(wsliced)
axes[0].set_title('before regularization')
axes[0].axis('off')
axes[1].set_title('after regularization')
axes[1].axis('off')
axes[1].imshow(usubsample+out)
Let's zoom on the result.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
axes[0].imshow(wsliced[400:700,600:900,:])
axes[0].set_title('before regularization')
axes[0].axis('off')
axes[1].set_title('after regularization')
axes[1].imshow(usubsample[400:700,600:900,:]+out[400:700,600:900,:])
axes[1].axis('off')
The result is not bad, but some blur appears: the edges in the image out no longer coincide with the edges of the original image.
A more interesting result can be obtained by guiding the filtering of the difference map $\mathcal{M}(u)$ in a way such that it will follow the regularity of $u$. To this aim, we define the operator $Y_u$ as the guided filter described in the paper
Guided Image Filtering, Kaiming He, Jian Sun, and Xiaoou Tang, ECCV 2010.
The idea is the following: let $u$ be an image that we want to regularize, and $guide$ be a guidance image of the same size. We want to construct an output image $q$ which looks like $u$ but follows the regularity of $guide$. To this aim, we will try to ensure that the gradient of $q$ is (almost) proportional to the gradient of $guide$. The algorithm is as follows (see page 4 of the aforementioned paper):
where $\mathrm{mean}(guide)_k$ is the average value of the image $guide$ on the square $\omega_k$ and $\sigma(guide)_k^2$ is the empirical variance of $guide$ on $\omega_k$, i.e.
$$\sigma(guide)_k^2 = \frac{1}{|\omega_k|}\sum_{i\in \omega_k} guide(i)^2 - \mathrm{mean}(guide)_k^2. $$In order to avoid divisions by $0$, $\sigma(guide)_k^2$ is replaced in practice by $\sigma(guide)_k^2+\epsilon$, where $\epsilon$ is a small regularization parameter.
Write a function todo_guided_filter which computes the filter described above. The function should take as input two gray-level images $u$ and $guide$, an integer $r$ and a regularization parameter $\epsilon$, and should output a gray-level image. Everything can be implemented without for loops, with 2D convolutions (function convolve2d of scipy.signal) or, much more efficiently, with the $O(1)$ average_filter function using integral images that we defined above.
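A possible implementation (a sketch using scipy.ndimage.uniform_filter for the window averages, which uses reflecting boundaries, slightly different from the zero padding of average_filter above; the latter can be substituted for it):

```python
import numpy as np
from scipy.ndimage import uniform_filter

def guided_filter_sketch(u, guide, r, eps):
    # guided filter (He et al., ECCV 2010): fit a local affine model
    # q = a*guide + b on each (2r+1)x(2r+1) window, then average the
    # coefficients a and b per pixel
    size = 2 * r + 1
    def mean(x):
        return uniform_filter(x, size=size)      # window average
    mu_g, mu_u = mean(guide), mean(u)
    var_g = mean(guide * guide) - mu_g**2        # variance of the guide per window
    cov_gu = mean(guide * u) - mu_g * mu_u       # covariance between guide and input
    a = cov_gu / (var_g + eps)                   # eps prevents division by zero
    b = mu_u - a * mu_g
    return mean(a) * guide + mean(b)             # q_i = mean(a)_i * guide_i + mean(b)_i
```

When the input and the guide coincide and $\epsilon$ is small, the filter is close to the identity; this is the edge-preserving behaviour exploited here.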
Apply the previous filter to $\mathcal{M}(u)$ and display the result $T(g(u))$. You can use $\epsilon = 10^{-4}$ and $r=20$ for instance.
diff = wsliced-usubsample
out = np.zeros_like(usubsample)
for i in range(3):
    out[:,:,i] = todo_guided_filter(diff[:,:,i], usubsample[:,:,i], 20, 1e-4)
We display the result.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
axes[0].imshow(wsliced)
axes[0].set_title('before regularization')
axes[0].axis('off')
axes[1].set_title('after regularization')
axes[1].imshow(out+usubsample)
axes[1].axis('off')
Zoom on the result.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
axes[0].imshow(wsliced[400:700,600:900,:])
axes[0].set_title('before regularization')
axes[0].axis('off')
axes[1].set_title('after regularization')
axes[1].imshow(usubsample[400:700,600:900,:]+out[400:700,600:900,:])
axes[1].axis('off')
Knowing the optimal assignment (or an approximation of it) between the point clouds $X$ and $Y$, you can compute barycenters between these discrete distributions, and hence the midway color distribution between them.
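In 1D, for instance, once the points are matched by sorting, the barycenter at weight $t$ is obtained by interpolating the sorted samples (a sketch):

```python
import numpy as np

def barycenter1D(X, Y, t):
    # Wasserstein barycenter of two 1D point clouds of the same size:
    # interpolate between matched (i.e. sorted) samples with weight t
    Xs, Ys = np.sort(X), np.sort(Y)
    return (1 - t) * Xs + t * Ys

rng = np.random.RandomState(0)
X = rng.rand(200)
Y = rng.rand(200) + 3
Z = barycenter1D(X, Y, 0.5)   # midway distribution
```

In 3D, interpolating $Z = X + t\,(Y\circ\sigma - X)$ with the assignment computed by the sliced transport gives the midway color palette, exactly as in the displacement visualizations above.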