This notebook illustrates the use of the distance $MW_2$ for color transfer, as described in
Delon, Desolneux, A Wasserstein-type distance in the space of Gaussian Mixture Models, 2019.
Authors:
Below is a list of packages required in the notebook:
numpy
matplotlib (display of images and graphics)
scipy.linalg (algebra functions)
scipy.stats (probability density functions)
sklearn.mixture (Gaussian mixture models, fitted with EM)
sklearn.cluster (KMeans algorithm)
ot (Optimal Transport library https://github.com/rflamary/POT)
os (interactions with the operating system)
To use the Optimal Transport library, we must first install it with pip.
!pip install POT
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image
from mpl_toolkits.mplot3d import Axes3D
import scipy.linalg as spl
import scipy.stats as sps
import os
import sklearn.mixture # for EM
from sklearn.cluster import KMeans # for kmeans
import ot
To import the solutions, uncomment and execute the following cell. If you are using a Windows system, leave the os.system line commented out, download the file by hand, and place it in the same folder as the notebook.
#os.system("wget -nc https://raw.githubusercontent.com/storimaging/Notebooks/main/ContrastAndColor/Solutions/GMM_OT_color_transfer.py")
#from GMM_OT_color_transfer import *
We read two color images into numpy arrays. JPEG images are read as 8-bit integer arrays, so they must be converted to floating point and normalized to [0,1].
os.system("wget -c https://raw.githubusercontent.com/storimaging/Images/main/img/renoir.jpg")
os.system("wget -c https://raw.githubusercontent.com/storimaging/Images/main/img/gauguin.jpg")
# Caution: if the images are of type jpg or bmp, they must be normalized to [0,1]
u = plt.imread('renoir.jpg')/255
v = plt.imread('gauguin.jpg')/255
nru,ncu,nch = u.shape
nrv,ncv,nch = v.shape
# image display thanks to the function imshow of the pyplot library of matplotlib
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 20))
#we display the images
axes[0].imshow(u)
axes[0].set_title('First image')
axes[0].axis('off')
axes[1].imshow(v)
axes[1].set_title('Second image')
axes[1].axis('off')
fig.tight_layout()
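The normalization caveat above can be made explicit with a small helper. The sketch below is only an illustration: `to_unit_float` is a hypothetical name that does not appear in the notebook, and it assumes that integer images span the full range of their dtype while float images (such as PNG files read by `plt.imread`) are already in [0,1].

```python
import numpy as np

def to_unit_float(img):
    """Return img as float64 in [0, 1] (hypothetical helper, not part of the notebook)."""
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.integer):
        # Integer images (e.g. uint8 from a JPEG): divide by the dtype maximum.
        return img.astype(np.float64) / np.iinfo(img.dtype).max
    # Float images are assumed to be in [0, 1] already.
    return img.astype(np.float64)

# Tiny synthetic example standing in for an 8-bit image.
demo = to_unit_float(np.array([[0, 128, 255]], dtype=np.uint8))
```

With such a helper, `u = to_unit_float(plt.imread('renoir.jpg'))` would give the same array as the division by 255 used above.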
We display their 3D color scatter plots.
X = u.reshape((nru*ncu,3))
Y = v.reshape((nrv*ncv,3))
nb = 3000
r = np.random.RandomState(42)
idX = r.randint(X.shape[0], size=(nb,))
idY = r.randint(Y.shape[0], size=(nb,))
Xs = X[idX, :]
Ys = Y[idY, :]
fig = plt.figure(2, figsize=(20, 10))
axis = fig.add_subplot(1, 2, 1, projection="3d")
axis.scatter(Xs[:, 0], Xs[:,1],Xs[:, 2], c=Xs,s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue");
axis = fig.add_subplot(1, 2, 2, projection="3d")
axis.scatter(Ys[:, 0], Ys[:,1],Ys[:, 2], c=Ys,s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue");
First, we show the result of separable optimal transport (implemented here for two images of the same size).
# BE CAREFUL: this implementation works only if u and v have the same number of pixels
uout = np.copy(u)
for k in range(3):
    uk = u[:,:,k]
    vk = v[:,:,k]
    # sorting permutation of the channel of u, sorted values of the channel of v
    index_u = np.argsort(uk,axis=None)
    vk_sort = np.sort(vk,axis=None)
    # the pixel of rank i in u receives the value of rank i in v
    uspecifv = np.zeros(nru*ncu)
    uspecifv[index_u] = vk_sort
    uout[:,:,k] = uspecifv.reshape(nru,ncu)
#Display images
plt.figure(figsize=(7, 7))
plt.axis('off')
plt.imshow(uout)
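The sorting-based transfer above requires both images to have the same number of pixels. One possible way around this restriction, sketched here under the assumption that matching quantiles by linear interpolation is acceptable, is to evaluate the quantile function of each channel of the target at the rank of every source pixel. The names `transfer_channel` and `separable_transfer` are hypothetical and are not part of the notebook.

```python
import numpy as np

def transfer_channel(uk, vk):
    """Monotone 1D transport of the values of uk onto the distribution of vk."""
    # Rank of every pixel of uk, converted to a quantile level in (0, 1).
    ranks = np.argsort(np.argsort(uk, axis=None))
    levels = (ranks + 0.5) / uk.size
    # Empirical quantile function of vk, evaluated at those levels.
    v_sorted = np.sort(vk, axis=None)
    v_levels = (np.arange(vk.size) + 0.5) / vk.size
    return np.interp(levels, v_levels, v_sorted).reshape(uk.shape)

def separable_transfer(u, v):
    """Apply the 1D transfer independently to each color channel."""
    channels = [transfer_channel(u[:, :, k], v[:, :, k]) for k in range(u.shape[2])]
    return np.stack(channels, axis=2)

# Synthetic images of different sizes stand in for u and v.
rng = np.random.RandomState(0)
u_demo = rng.rand(8, 6, 3)
v_demo = rng.rand(5, 7, 3)
out_demo = separable_transfer(u_demo, v_demo)
```

When the two images have the same number of pixels, the interpolation nodes coincide with the quantile levels and this reduces to the sorting-based assignment used above.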
We use scikit-learn to compute the GMM from the two color 3d point clouds.
X = u.reshape((nru*ncu,3))
Y = v.reshape((nrv*ncv,3))
k = 10 # number of classes
ninit = 1
K0,K1 = k,k
gmmX = sklearn.mixture.GaussianMixture(n_components=K0, covariance_type='full',n_init=ninit).fit(X) # spherical or full
pi0,m0,S0 = gmmX.weights_, gmmX.means_, gmmX.covariances_
ClassesX = gmmX.predict(X)
ProbaClassesX = gmmX.predict_proba(X)
gmmY = sklearn.mixture.GaussianMixture(n_components=K1, covariance_type='full',n_init=ninit).fit(Y) # spherical or full
pi1,m1,S1 = gmmY.weights_, gmmY.means_, gmmY.covariances_
ClassesY = gmmY.predict(Y)
ProbaClassesY = gmmY.predict_proba(Y)
We can display the resulting classes on the scatter plots.
nb = 3000
r = np.random.RandomState(42)
idX = r.randint(X.shape[0], size=(nb,))
idY = r.randint(Y.shape[0], size=(nb,))
Xs = X[idX, :]
Ys = Y[idY, :]
ClassesXsubsample = ClassesX[idX]
ClassesYsubsample = ClassesY[idY]
fig = plt.figure(2, figsize=(20, 10))
axis = fig.add_subplot(1, 2, 1, projection="3d")
axis.scatter(Xs[:, 0], Xs[:,1],Xs[:, 2], c=m0[ClassesXsubsample,:],s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue");
axis = fig.add_subplot(1, 2, 2, projection="3d")
axis.scatter(Ys[:, 0], Ys[:,1],Ys[:, 2], c=m1[ClassesYsubsample,:],s=100)
axis.set_xlabel("Red"), axis.set_ylabel("Green"), axis.set_zlabel("Blue");
We can also display the segmentations of $u$ and $v$ provided by the GMMs (each pixel is shown with the mean color of its class).
# display the corresponding image segmentations
useg = m0[ClassesX]
useg = useg.reshape((nru,ncu,3))
vseg = m1[ClassesY]
vseg = vseg.reshape((nrv,ncv,3))
# display
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 7))
axes[0].imshow(useg)
axes[0].set_title('segmentation of u')
axes[0].axis('off')
axes[1].imshow(vseg)
axes[1].set_title('segmentation of v')
axes[1].axis('off')
fig.tight_layout()
Then, we can use the $MW_2$ transport map between these two GMMs for color transfer.
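The cell that follows calls `MW2` and `GaussianMap`, which are not defined in the notebook cells themselves. The sketch below shows what such functions might look like, under the assumption that `MW2` solves a discrete optimal transport problem between the mixture weights with pairwise squared Gaussian $W_2$ distances as cost, and that `GaussianMap` applies the optimal affine map between two Gaussians to the rows of an array. The names `gaussian_w2_sketch`, `mw2_sketch` and `gaussian_map_sketch` are hypothetical. To depend on SciPy only, the small transport problem is solved here with `scipy.optimize.linprog`; in this notebook, `ot.emd(pi0, pi1, M)` from POT would be the more direct choice.

```python
import numpy as np
import scipy.linalg as spl
from scipy.optimize import linprog

def gaussian_w2_sketch(m0, m1, S0, S1):
    """Squared W2 distance between the Gaussians N(m0, S0) and N(m1, S1)."""
    S0_half = spl.sqrtm(S0).real
    cross = spl.sqrtm(S0_half @ S1 @ S0_half).real
    return np.sum((m0 - m1) ** 2) + np.trace(S0 + S1 - 2 * cross)

def mw2_sketch(pi0, pi1, m0, m1, S0, S1):
    """Discrete OT between mixture components; returns (plan, transport cost)."""
    K0, K1 = len(pi0), len(pi1)
    # Cost matrix: pairwise squared W2 distances between Gaussian components.
    M = np.array([[gaussian_w2_sketch(m0[k], m1[l], S0[k], S1[l])
                   for l in range(K1)] for k in range(K0)])
    # Marginal constraints on the flattened plan, w[k, l] -> index k*K1 + l.
    A_eq = np.vstack([np.kron(np.eye(K0), np.ones((1, K1))),   # rows sum to pi0
                      np.kron(np.ones((1, K0)), np.eye(K1))])  # columns sum to pi1
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=np.concatenate([pi0, pi1]),
                  bounds=(0, None), method="highs")
    wstar = res.x.reshape(K0, K1)
    return wstar, float(np.sum(wstar * M))

def gaussian_map_sketch(m0, m1, S0, S1, x):
    """Apply the optimal affine map from N(m0, S0) to N(m1, S1) to the rows of x."""
    S0_half = spl.sqrtm(S0).real
    S0_half_inv = np.linalg.inv(S0_half)
    A = S0_half_inv @ spl.sqrtm(S0_half @ S1 @ S0_half).real @ S0_half_inv
    # A is symmetric, so right-multiplication applies it to every row of x.
    return m1 + (x - m0) @ A

# Two-component toy mixtures in 3D (synthetic values, for illustration only).
pi_demo = np.array([0.5, 0.5])
m_demo0 = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
m_demo1 = m_demo0 + 0.5
S_demo = np.stack([0.01 * np.eye(3)] * 2)
w_demo, d_demo = mw2_sketch(pi_demo, pi_demo, m_demo0, m_demo1, S_demo, S_demo)
```

Here `gaussian_map_sketch` returns an array of shape (N, 3), which matches the transpose taken in the cell below. Whether the second value returned by the actual `MW2` is the squared distance or its square root is not stated in this notebook; the sketch returns the transport cost.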
# Compute the K0xK1 OT matrix between the members of the mixtures
wstar,dist = MW2(pi0/np.sum(pi0),pi1/np.sum(pi1),m0,m1,S0,S1)
# Compute all Tkl maps at all points of u
T = np.zeros((K0,K1,3,nru*ncu))
for k in range(K0):
    for l in range(K1):
        T[k,l,:,:] = GaussianMap(m0[k,:],m1[l,:],S0[k,:,:],S1[l,:,:],X).T
# Compute mean color transfer on all points
Tmeanx = np.zeros((3,nru*ncu))
for k in range(K0):
    for l in range(K1):
        Tmeanx += wstar[k,l]/pi0[k]*ProbaClassesX[:,k].T*T[k,l,:,:]
# Compute random color transfer on all points (random sample with posterior distribution)
Trandx = np.zeros((3,nru*ncu))
tmp = np.zeros((K0*K1,nru*ncu))
for k in range(K0):
    for l in range(K1):
        # probability of choosing the pair (k,l) at each point, stored at index k+K0*l
        tmp[k+K0*l,:]= wstar[k,l]/pi0[k]*ProbaClassesX[:,k]
for i in range(nru*ncu):
    mr = np.random.choice(K0*K1,p=tmp[:,i])
    # recover the pair (k,l) from the flat index
    l = mr//K0
    k = mr - K0*l
    Trandx[:,i] = T[k,l,:,i]
# Compute best color transfer on all points (only best class for each point)
Tmaxx = np.zeros((3,nru*ncu))
normalisation = np.zeros((nru*ncu))
for k in range(K0):
    for l in range(K1):
        Tmaxx += wstar[k,l]*T[k,l,:,:]*(ClassesX==k).T
        normalisation +=wstar[k,l]*(ClassesX==k).T
Tmaxx = Tmaxx/normalisation
# Display result as an image
w=Tmeanx.T.reshape((nru,ncu,3))
wmax=Tmaxx.T.reshape((nru,ncu,3))
wrand=Trandx.T.reshape((nru,ncu,3))
#we display the images
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 20))
axes[0].imshow(w)
axes[0].set_title('w')
axes[0].axis('off')
axes[1].imshow(wrand)
axes[1].set_title('wrand')
axes[1].axis('off')
axes[2].imshow(wmax)
axes[2].set_title('wmax')
axes[2].axis('off')
fig.tight_layout()
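For reference, the three maps computed above can be written as follows. The notation is introduced here for illustration and may differ from that of the paper: $p_k(x)$ is the posterior probability of class $k$ at $x$ (the array ProbaClassesX), $T_{k,l}$ is the map returned by GaussianMap for the pair of components $(k,l)$, $\pi_0^k$ are the weights of the first mixture, and $w^*$ is the matrix wstar.

$$T_{mean}(x) = \sum_{k,l} \frac{w^*_{k,l}}{\pi_0^k}\, p_k(x)\, T_{k,l}(x),$$

$$T_{rand}(x) = T_{k,l}(x) \quad \text{with probability} \quad \frac{w^*_{k,l}}{\pi_0^k}\, p_k(x),$$

$$T_{max}(x) = \frac{\sum_{l} w^*_{k^*,l}\, T_{k^*,l}(x)}{\sum_{l} w^*_{k^*,l}}, \qquad k^* = \arg\max_k\, p_k(x).$$

Since $\sum_l w^*_{k,l} = \pi_0^k$ and $\sum_k p_k(x) = 1$, the weights used in $T_{mean}$ and $T_{rand}$ sum to one at every point $x$, which is what allows them to be used as probabilities in the random sampling step.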
A common drawback of classical methods for color and contrast modification is that they reveal artefacts (JPEG blocks, color inconsistencies, noise enhancement) or attenuate details and textures (see for instance the following web page). Let $u$ be an image and $g(u)$ the same image after color or contrast modification; we write $\mathcal{M}(u) = g(u) - u$. All artefacts observable in $g(u)$ can be seen as irregularities in this difference map $\mathcal{M}(u)$. To reduce these artefacts, we propose to filter the difference map with an operator $Y_u$ and to reconstruct the image:
$$T(g(u)) = u + Y_u(g(u)-u).$$
We will use for $Y_u$ the guided filter described in the paper
Kaiming He, Jian Sun, and Xiaoou Tang, Guided Image Filtering, ECCV 2010.
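The cell below calls `guided_filter`, which is not defined in the notebook cells themselves. A minimal sketch of a gray-guide guided filter is given here, assuming the signature `(p, I, r, eps)`, where `p` is the image to filter, `I` the guide, `r` the window radius and `eps` the regularization. That argument order is inferred from the call in the next cell and is an assumption. The name `guided_filter_sketch` is hypothetical, and the box means use `scipy.ndimage.uniform_filter` with reflected borders, which may differ from the border handling of the original method.

```python
import numpy as np
from scipy.ndimage import uniform_filter

def guided_filter_sketch(p, I, r, eps):
    """Filter the 2D image p using the 2D image I of the same shape as guide."""
    size = 2 * r + 1

    def box(x):
        # Mean over a (2r+1) x (2r+1) window, with reflected borders.
        return uniform_filter(x, size=size, mode="reflect")

    mean_I, mean_p = box(I), box(p)
    var_I = box(I * I) - mean_I ** 2
    cov_Ip = box(I * p) - mean_I * mean_p
    # Local linear model q = a * I + b fitted in each window.
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    # Average the coefficients of all windows covering each pixel.
    return box(a) * I + box(b)

# Synthetic guide and noisy input, for illustration only.
rng = np.random.RandomState(0)
guide_demo = rng.rand(32, 32)
noisy_demo = guide_demo + 0.1 * rng.randn(32, 32)
smooth_demo = guided_filter_sketch(noisy_demo, guide_demo, 2, 1e-4)
```

A larger `eps` gives a smoother output, while a small `eps` lets the output follow the edges of the guide more closely.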
diff = w-u
out = np.zeros_like(u)
for i in range(3):
    # filter each channel of the difference map, guided by the same channel of u
    out[:,:,i] = guided_filter(diff[:,:,i],u[:,:,i], 10,1e-4 )
#we display the result
plt.figure(figsize=(15, 15))
plt.axis('off')
plt.imshow(np.clip(out+u,0,1)) # clip to [0,1], the valid range for float RGB data