66import numpy as np
77from .bregman import sinkhorn
88from .lp import emd
9- from .utils import unif ,dist
9+ from .utils import unif ,dist , kernel
1010from .optim import cg
1111
1212
def indices(a, func):
    """Return the positions of the elements of ``a`` accepted by ``func``.

    ``a`` is walked once with ``enumerate``; the index of every element for
    which ``func(element)`` is truthy is collected, in iteration order.
    The result is always a new list (empty when nothing matches).
    """
    matches = []
    for position, item in enumerate(a):
        if func(item):
            matches.append(position)
    return matches
1515
16+
17+
1618def sinkhorn_lpl1_mm (a ,labels_a , b , M , reg , eta = 0.1 ,numItermax = 10 ,numInnerItermax = 200 ,stopInnerThr = 1e-9 ,verbose = False ,log = False ):
1719 """
1820 Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
@@ -129,34 +131,38 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
129131
130132 if bias :
131133 xs1 = np .hstack ((xs ,np .ones ((ns ,1 ))))
132- I = eta * np .eye (d + 1 )
134+ xstxs = xs1 .T .dot (xs1 )
135+ I = np .eye (d + 1 )
133136 I [- 1 ]= 0
134137 I0 = I [:,:- 1 ]
135138 sel = lambda x : x [:- 1 ,:]
136139 else :
137140 xs1 = xs
138- I = eta * np .eye (d )
141+ xstxs = xs1 .T .dot (xs1 )
142+ I = np .eye (d )
139143 I0 = I
140144 sel = lambda x : x
141145
142146 if log :
143147 log = {'err' :[]}
144148
145149 a ,b = unif (ns ),unif (nt )
146- M = dist (xs ,xt )
150+ M = dist (xs ,xt )* ns
147151 G = emd (a ,b ,M )
148152
149153 vloss = []
150154
151155 def loss (L ,G ):
156+ """Compute full loss"""
152157 return np .sum ((xs1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L - I0 )** 2 )
153158
154159 def solve_L (G ):
155- """ solve problem with fixed G"""
160+ """ solve L problem with fixed G (least square) """
156161 xst = ns * G .dot (xt )
157- return np .linalg .solve (xs1 . T . dot ( xs1 ) + I ,xs1 .T .dot (xst )+ I0 )
162+ return np .linalg .solve (xstxs + eta * I ,xs1 .T .dot (xst )+ eta * I0 )
158163
159164 def solve_G (L ,G0 ):
165+ """Update G with CG algorithm"""
160166 xsi = xs1 .dot (L )
161167 def f (G ):
162168 return np .sum ((xsi - ns * G .dot (xt ))** 2 )
@@ -175,8 +181,11 @@ def df(G):
175181 print ('{:5d}|{:8e}|{:8e}' .format (0 ,vloss [- 1 ],0 ))
176182
177183
178- # regul matrix
179- loop = 1
184+ # init loop
185+ if numItermax > 0 :
186+ loop = 1
187+ else :
188+ loop = 0
180189 it = 0
181190
182191 while loop :
@@ -191,18 +200,116 @@ def df(G):
191200
192201 vloss .append (loss (L ,G ))
193202
203+ if it >= numItermax :
204+ loop = 0
205+
194206 if abs (vloss [- 1 ]- vloss [- 2 ])< stopThr :
195207 loop = 0
196208
197209 if verbose :
198210 if it % 20 == 0 :
199211 print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
200212 print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
213+ if log :
214+ log ['loss' ]= vloss
215+ return G ,L ,log
216+ else :
217+ return G ,L
201218
202- return G ,L
203219
def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kernel='gaussian', sigma=1, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=20, stopInnerThr=1e-9, stopThr=1e-6, log=False, **kwargs):
    """Jointly estimate a transport plan G and a mapping L (uniform weights).

    Alternates between two updates until the loss stalls or ``numItermax``
    outer iterations have been run:

    * ``L`` is obtained in closed form with ``np.linalg.solve`` for a fixed
      ``G`` (regularized least squares),
    * ``G`` is updated with ``cg`` for a fixed ``L``.

    NOTE(review): ``kernel``, ``sigma``, ``verbose2`` and ``**kwargs`` are
    accepted but not read anywhere in this body, so the computation is
    currently the linear one; the kernel matrix is not built yet. The
    original code contained a dangling ``K =`` statement (a syntax error),
    which has been removed.

    Parameters
    ----------
    xs, xt : arrays with ``shape[0]`` samples each. ``xs`` must have as many
        columns as ``xt`` (``xt.shape[1]``), otherwise the linear system
        built below is not conformable.
    mu : weight of the transport cost term ``sum(G * M)``; also passed to
        ``cg`` as ``1.0 / mu``.
    eta : weight of the penalty ``sum(sel(L - I0) ** 2)``.
    bias : when true, a column of ones is appended to ``xs`` and the last
        row of ``L`` is excluded from the penalty.
    verbose : print the loss after each outer iteration.
    numItermax : maximum number of outer iterations (``<= 0`` skips the loop).
    numInnerItermax, stopInnerThr : forwarded to ``cg`` as ``numItermax``
        and ``stopThr``.
    stopThr : the outer loop stops when the absolute loss change is below it.
    log : when truthy, a dict is returned as a third value.

    Returns
    -------
    ``(G, L)``, or ``(G, L, log)`` when ``log`` is truthy, where ``log`` has
    the keys ``'err'`` (left empty here) and ``'loss'`` (list of loss values,
    one before the loop plus one per outer iteration).
    """
    ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]

    if bias:
        # augment the source samples with a constant feature
        xs1 = np.hstack((xs, np.ones((ns, 1))))
        xstxs = xs1.T.dot(xs1)
        I = np.eye(d + 1)
        I[-1] = 0  # zero last row: the bias row is not regularized
        I0 = I[:, :-1]
        sel = lambda x: x[:-1, :]
    else:
        xs1 = xs
        xstxs = xs1.T.dot(xs1)
        I = np.eye(d)
        I0 = I
        sel = lambda x: x

    if log:
        log = {'err': []}

    a, b = unif(ns), unif(nt)
    M = dist(xs, xt) * ns
    G = emd(a, b, M)

    vloss = []

    def loss(L, G):
        """Compute full loss"""
        return (np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2)
                + mu * np.sum(G * M)
                + eta * np.sum(sel(L - I0) ** 2))

    def solve_L(G):
        """Solve L problem with fixed G (least square)"""
        xst = ns * G.dot(xt)
        return np.linalg.solve(xstxs + eta * I, xs1.T.dot(xst) + eta * I0)

    def solve_G(L, G0):
        """Update G with CG algorithm"""
        xsi = xs1.dot(L)

        def f(G):
            return np.sum((xsi - ns * G.dot(xt)) ** 2)

        def df(G):
            return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)

        return cg(a, b, M, 1.0 / mu, f, df, G0=G0,
                  numItermax=numInnerItermax, stopThr=stopInnerThr)

    L = solve_L(G)

    vloss.append(loss(L, G))

    if verbose:
        print('{:5s}|{:12s}|{:8s}'.format('It.', 'Loss', 'Delta loss') + '\n ' + '-' * 32)
        print('{:5d}|{:8e}|{:8e}'.format(0, vloss[-1], 0))

    # alternate G / L updates; at most numItermax outer iterations
    it = 0
    while it < numItermax:
        it += 1

        # update G
        G = solve_G(L, G)

        # update L
        L = solve_L(G)

        vloss.append(loss(L, G))
        delta = abs(vloss[-1] - vloss[-2])

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}'.format('It.', 'Loss', 'Delta loss') + '\n ' + '-' * 32)
            print('{:5d}|{:8e}|{:8e}'.format(it, vloss[-1], delta / abs(vloss[-2])))

        # stop once the loss no longer changes (checked after printing so
        # the final iteration is still reported)
        if delta < stopThr:
            break

    if log:
        log['loss'] = vloss
        return G, L, log
    else:
        return G, L
206313
207314
208315class OTDA (object ):
@@ -294,6 +401,7 @@ def predict(self,x,direction=1):
294401
295402class OTDA_sinkhorn (OTDA ):
296403 """Class for domain adaptation with optimal transport with entropic regularization"""
404+
297405 def fit (self ,xs ,xt ,reg = 1 ,ws = None ,wt = None ,** kwargs ):
298406 """ Fit domain adaptation between samples is xs and xt (with optional
299407 weights)"""
@@ -335,3 +443,51 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
335443 self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
336444 self .computed = True
337445
class OTDA_mapping(OTDA):
    """Class for optimal transport with joint linear mapping estimation"""

    def __init__(self, metric='sqeuclidean'):
        """Create an unfitted model.

        ``metric`` is only stored on the instance here. The remaining
        attributes are placeholders that ``fit`` overwrites.
        """
        self.metric = metric
        self.bias = False
        self.computed = False

        # placeholders until fit() is called
        self.xs = 0
        self.xt = 0
        self.G = 0
        self.L = 0

    def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
        """Fit the coupling G and the linear map L between xs and xt.

        Uniform sample weights are stored in ``self.ws`` / ``self.wt`` and
        the estimation itself is delegated to ``joint_OT_mapping_linear``,
        to which ``mu``, ``eta``, ``bias`` and any extra keyword arguments
        are forwarded.
        """
        self.xs = xs
        self.xt = xt
        self.bias = bias

        n_source = xs.shape[0]
        n_target = xt.shape[0]
        self.ws = unif(n_source)
        self.wt = unif(n_target)

        # the solver is expected to return exactly (G, L) here
        coupling, linear_map = joint_OT_mapping_linear(
            xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
        self.G = coupling
        self.L = linear_map
        self.computed = True

    def mapping(self):
        """Return a one-argument callable that forwards to ``predict``."""
        def _apply(x):
            return self.predict(x)
        return _apply

    def predict(self, x):
        """Apply the fitted linear map to new samples.

        Returns ``x.dot(self.L)``; when the model was fitted with
        ``bias=True`` a column of ones is appended to ``x`` first. If the
        model has not been fitted, a warning is printed and ``None`` is
        returned.
        """
        if not self.computed:
            print("Warning, model not fitted yet, returning None")
            return None

        if self.bias:
            # same augmentation as the one used on xs during fitting
            ones_column = np.ones((x.shape[0], 1))
            x = np.hstack((x, ones_column))
        return x.dot(self.L)
493+
0 commit comments