@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
247247
248248 def loss (L ,G ):
249249 """Compute full loss"""
250- return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .sum ( sel ( L ) ** 2 )
250+ return np .sum ((K1 .dot (L )- ns * G .dot (xt ))** 2 )+ mu * np .sum (G * M )+ eta * np .trace ( L . T . dot ( K0 ). dot ( L ) )
251251
252252 def solve_L_nobias (G ):
253253 """ solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
450450 self .G = sinkhorn_lpl1_mm (ws ,ys ,wt ,self .M ,reg ,eta ,** kwargs )
451451 self .computed = True
452452
453- class OTDA_mapping (OTDA ):
453+ class OTDA_mapping_linear (OTDA ):
454454 """Class for optimal transport with joint linear mapping estimation"""
455455
456456
457- def __init__ (self , metric = 'sqeuclidean' ):
457+ def __init__ (self ):
458458 """ Class initialization"""
459459
460460
@@ -463,8 +463,8 @@ def __init__(self,metric='sqeuclidean'):
463463 self .G = 0
464464 self .L = 0
465465 self .bias = False
466- self .metric = metric
467466 self .computed = False
467+ self .metric = 'sqeuclidean'
468468
469469 def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,** kwargs ):
470470 """ Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
473473 self .xt = xt
474474 self .bias = bias
475475
476+
476477 self .ws = unif (xs .shape [0 ])
477478 self .wt = unif (xt .shape [0 ])
478479
@@ -498,3 +499,42 @@ def predict(self,x):
498499 print ("Warning, model not fitted yet, returning None" )
499500 return None
500501
502+ class OTDA_mapping_kernel (OTDA_mapping_linear ):
503+ """Class for optimal transport with joint linear mapping estimation"""
504+
505+
506+
507+ def fit (self ,xs ,xt ,mu = 1 ,eta = 1 ,bias = False ,kerneltype = 'gaussian' ,sigma = 1 ,** kwargs ):
508+ """ Fit domain adaptation between samples is xs and xt (with optional
509+ weights)"""
510+ self .xs = xs
511+ self .xt = xt
512+ self .bias = bias
513+
514+ self .ws = unif (xs .shape [0 ])
515+ self .wt = unif (xt .shape [0 ])
516+ self .kernel = kerneltype
517+ self .sigma = sigma
518+ self .kwargs = kwargs
519+
520+
521+ self .G ,self .L = joint_OT_mapping_kernel (xs ,xt ,mu = mu ,eta = eta ,bias = bias ,** kwargs )
522+ self .computed = True
523+
524+
525+ def predict (self ,x ):
526+ """ Out of sample mapping using the formulation from Ferradans
527+
528+ It basically find the source sample the nearset to the nex sample and
529+ apply the difference to the displaced source sample.
530+
531+ """
532+
533+ if self .computed :
534+ K = kernel (x ,self .xs ,method = self .kernel ,sigma = self .sigma ,** self .kwargs )
535+ if self .bias :
536+ K = np .hstack ((K ,np .ones ((x .shape [0 ],1 ))))
537+ return K .dot (self .L )
538+ else :
539+ print ("Warning, model not fitted yet, returning None" )
540+ return None
0 commit comments