@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
123123
124124 return transp
125125
126- def joint_OT_mapping_linear (xs ,xt ,mu = 1 ,eta = 0.001 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
126+ def joint_OT_mapping_linear (xs ,xt ,mu = 1 ,eta = 0.001 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 10 ,stopInnerThr = 1e-6 ,stopThr = 1e-5 ,log = False ,** kwargs ):
127127 """Joint Ot and mapping estimation (uniform weights and )
128128 """
129129
@@ -209,15 +209,15 @@ def df(G):
209209 if verbose :
210210 if it % 20 == 0 :
211211 print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
212- print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
212+ print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],(vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
213213 if log :
214214 log ['loss' ]= vloss
215215 return G ,L ,log
216216 else :
217217 return G ,L
218218
219219
220- def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
220+ def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 10 ,stopInnerThr = 1e-6 ,stopThr = 1e-5 ,log = False ,** kwargs ):
221221 """Joint Ot and mapping estimation (uniform weights and )
222222 """
223223
@@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
228228 K1 = np .hstack ((K ,np .ones ((ns ,1 ))))
229229 I = np .eye (ns + 1 )
230230 I [- 1 ]= 0
231- K0 = K1 .T .dot (K1 )+ eta * I
232- Kreg = I
233- sel = lambda x : x [:- 1 ,:]
231+ Kp = np .eye (ns + 1 )
232+ Kp [:ns ,:ns ]= K
233+
234+ # ls regu
235+ #K0 = K1.T.dot(K1)+eta*I
236+ #Kreg=I
237+
238+ # RKHS regul
239+ K0 = K1 .T .dot (K1 )+ eta * Kp
240+ Kreg = Kp
241+
234242 else :
235243 K1 = K
236244 I = np .eye (ns )
245+
246+ # ls regul
247+ #K0 = K1.T.dot(K1)+eta*I
248+ #Kreg=I
249+
250+ # proper kernel ridge
237251 K0 = K + eta * I
238252 Kreg = K
239- sel = lambda x : x
253+
254+
255+
240256
241257 if log :
242258 log = {'err' :[]}
@@ -313,7 +329,7 @@ def df(G):
313329 if verbose :
314330 if it % 20 == 0 :
315331 print ('{:5s}|{:12s}|{:8s}' .format ('It.' ,'Loss' ,'Delta loss' )+ '\n ' + '-' * 32 )
316- print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],abs (vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
332+ print ('{:5d}|{:8e}|{:8e}' .format (it ,vloss [- 1 ],(vloss [- 1 ]- vloss [- 2 ])/ abs (vloss [- 2 ])))
317333 if log :
318334 log ['loss' ]= vloss
319335 return G ,L ,log
0 commit comments