1+ # -*- coding: utf-8 -*-
2+ """
3+ demo of Optimal transport for domain adaptation with image color adaptation as in [6]
4+
5+ [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
6+ """
7+
8+ import numpy as np
9+ import scipy .ndimage as spi
10+ import matplotlib .pylab as pl
11+ import ot
12+
13+
14+ #%% Loading images
15+
16+ I1 = spi .imread ('../data/ocean_day.jpg' ).astype (np .float64 )/ 256
17+ I2 = spi .imread ('../data/ocean_sunset.jpg' ).astype (np .float64 )/ 256
18+
19+ #%% Plot images
20+
21+ pl .figure (1 )
22+
23+ pl .subplot (1 ,2 ,1 )
24+ pl .imshow (I1 )
25+ pl .title ('Image 1' )
26+
27+ pl .subplot (1 ,2 ,2 )
28+ pl .imshow (I2 )
29+ pl .title ('Image 2' )
30+
31+ pl .show ()
32+
33+ #%% Image conversion and dataset generation
34+
35+ def im2mat (I ):
36+ """Converts and image to matrix (one pixel per line)"""
37+ return I .reshape ((I .shape [0 ]* I .shape [1 ],I .shape [2 ]))
38+
39+ def mat2im (X ,shape ):
40+ """Converts back a matrix to an image"""
41+ return X .reshape (shape )
42+
43+ X1 = im2mat (I1 )
44+ X2 = im2mat (I2 )
45+
46+ # training samples
47+ nb = 1000
48+ idx1 = np .random .randint (X1 .shape [0 ],size = (nb ,))
49+ idx2 = np .random .randint (X2 .shape [0 ],size = (nb ,))
50+
51+ xs = X1 [idx1 ,:]
52+ xt = X2 [idx2 ,:]
53+
54+ #%% domain adaptation between images
55+
56+ # LP problem
57+ da_emd = ot .da .OTDA () # init class
58+ da_emd .fit (xs ,xt ) # fit distributions
59+
60+
61+ # sinkhorn regularization
62+ lambd = 1e-1
63+ da_entrop = ot .da .OTDA_sinkhorn ()
64+ da_entrop .fit (xs ,xt ,reg = lambd )
65+
66+
67+
68+ #%% prediction between images (using out of sample prediction as in [6])
69+
70+ X1t = da_emd .predict (X1 )
71+ X2t = da_emd .predict (X2 ,- 1 )
72+
73+
74+ X1te = da_entrop .predict (X1 )
75+ X2te = da_entrop .predict (X2 ,- 1 )
76+
77+
78+ def minmax (I ):
79+ return np .minimum (np .maximum (I ,0 ),1 )
80+
81+ I1t = minmax (mat2im (X1t ,I1 .shape ))
82+ I2t = minmax (mat2im (X2t ,I2 .shape ))
83+
84+ I1te = minmax (mat2im (X1te ,I1 .shape ))
85+ I2te = minmax (mat2im (X2te ,I2 .shape ))
86+
87+ #%% plot all images
88+
89+ pl .figure (2 ,(10 ,8 ))
90+
91+ pl .subplot (2 ,3 ,1 )
92+
93+ pl .imshow (I1 )
94+ pl .title ('Image 1' )
95+
96+ pl .subplot (2 ,3 ,2 )
97+ pl .imshow (I1t )
98+ pl .title ('Image 1 Adapt' )
99+
100+
101+ pl .subplot (2 ,3 ,3 )
102+ pl .imshow (I1te )
103+ pl .title ('Image 1 Adapt (reg)' )
104+
105+ pl .subplot (2 ,3 ,4 )
106+
107+ pl .imshow (I2 )
108+ pl .title ('Image 2' )
109+
110+ pl .subplot (2 ,3 ,5 )
111+ pl .imshow (I2t )
112+ pl .title ('Image 2 Adapt' )
113+
114+
115+ pl .subplot (2 ,3 ,6 )
116+ pl .imshow (I2te )
117+ pl .title ('Image 2 Adapt (reg)' )
118+
119+ pl .show ()
0 commit comments