@@ -217,25 +217,23 @@ def df(G):
217217 return G ,L
218218
219219
220- def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kernel = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
220+ def joint_OT_mapping_kernel (xs ,xt ,mu = 1 ,eta = 0.001 ,kerneltype = 'gaussian' ,sigma = 1 ,bias = False ,verbose = False ,verbose2 = False ,numItermax = 100 ,numInnerItermax = 20 ,stopInnerThr = 1e-9 ,stopThr = 1e-6 ,log = False ,** kwargs ):
221221 """Joint Ot and mapping estimation (uniform weights and )
222222 """
223223
224224 ns ,nt ,d = xs .shape [0 ],xt .shape [0 ],xt .shape [1 ]
225225
226+ K = kernel (xs ,xs ,method = kerneltype ,sigma = sigma )
226227 if bias :
227- K =
228- xs1 = np .hstack ((xs ,np .ones ((ns ,1 ))))
229- xstxs = xs1 .T .dot (xs1 )
230- I = np .eye (d + 1 )
228+ K1 = np .hstack ((K ,np .ones ((ns ,1 ))))
229+ I = np .eye (ns + 1 )
231230 I [- 1 ]= 0
232- I0 = I [:,: - 1 ]
231+ K0 = K1 . T . dot ( K1 ) + eta * I
233232 sel = lambda x : x [:- 1 ,:]
234233 else :
235- xs1 = xs
236- xstxs = xs1 .T .dot (xs1 )
237- I = np .eye (d )
238- I0 = I
234+ K1 = K
235+ I = np .eye (ns )
236+ K0 = K + eta * I
239237 sel = lambda x : x
240238
241239 if log :
@@ -249,23 +247,32 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=
249247
250248 def loss (L ,G ):
251249 """Compute full loss"""
252- return np .sum ((xs1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L - I0 )** 2 )
250+ return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum (sel (L )** 2 )
253251
254- def solve_L (G ):
252+ def solve_L_nobias (G ):
255253 """ solve L problem with fixed G (least square)"""
256254 xst = ns * G .dot (xt )
257- return np .linalg .solve (xstxs + eta * I ,xs1 .T .dot (xst )+ eta * I0 )
255+ return np .linalg .solve (K0 ,xst )
256+
257+ def solve_L_bias (G ):
258+ """ solve L problem with fixed G (least square)"""
259+ xst = ns * G .dot (xt )
260+ return np .linalg .solve (K0 ,K1 .T .dot (xst ))
258261
259262 def solve_G (L ,G0 ):
260263 """Update G with CG algorithm"""
261- xsi = xs1 .dot (L )
264+ xsi = K1 .dot (L )
262265 def f (G ):
263266 return np .sum ((xsi - ns * G .dot (xt ))** 2 )
264267 def df (G ):
265268 return - 2 * ns * (xsi - ns * G .dot (xt )).dot (xt .T )
266269 G = cg (a ,b ,M ,1.0 / mu ,f ,df ,G0 = G0 ,numItermax = numInnerItermax ,stopThr = stopInnerThr )
267270 return G
268271
272+ if bias :
273+ solve_L = solve_L_bias
274+ else :
275+ solve_L = solve_L_nobias
269276
270277 L = solve_L (G )
271278
0 commit comments