@@ -650,15 +650,16 @@ class OTDA(object):
650650
651651 """
652652
653- def __init__ (self , metric = 'sqeuclidean' ):
653+ def __init__ (self , metric = 'sqeuclidean' , norm = None ):
654654 """ Class initialization"""
655655 self .xs = 0
656656 self .xt = 0
657657 self .G = 0
658658 self .metric = metric
659+ self .norm = norm
659660 self .computed = False
660661
661- def fit (self , xs , xt , ws = None , wt = None , norm = None , max_iter = 100000 ):
662+ def fit (self , xs , xt , ws = None , wt = None , max_iter = 100000 ):
662663 """Fit domain adaptation between samples is xs and xt
663664 (with optional weights)"""
664665 self .xs = xs
@@ -673,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673674 self .wt = wt
674675
675676 self .M = dist (xs , xt , metric = self .metric )
676- self .M = cost_normalization (self .M , norm )
677+ self .M = cost_normalization (self .M , self . norm )
677678 self .G = emd (ws , wt , self .M , max_iter )
678679 self .computed = True
679680
@@ -752,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
752753
753754 """
754755
755- def fit (self , xs , xt , reg = 1 , ws = None , wt = None , norm = None , ** kwargs ):
756+ def fit (self , xs , xt , reg = 1 , ws = None , wt = None , ** kwargs ):
756757 """Fit regularized domain adaptation between samples is xs and xt
757758 (with optional weights)"""
758759 self .xs = xs
@@ -767,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
767768 self .wt = wt
768769
769770 self .M = dist (xs , xt , metric = self .metric )
770- self .M = cost_normalization (self .M , norm )
771+ self .M = cost_normalization (self .M , self . norm )
771772 self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
772773 self .computed = True
773774
@@ -779,8 +780,7 @@ class OTDA_lpl1(OTDA):
779780 """Class for domain adaptation with optimal transport with entropic and
780781 group regularization"""
781782
782- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
783- ** kwargs ):
783+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
784784 """Fit regularized domain adaptation between samples is xs and xt
785785 (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
786786 parameters"""
@@ -796,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
796796 self .wt = wt
797797
798798 self .M = dist (xs , xt , metric = self .metric )
799- self .M = cost_normalization (self .M , norm )
799+ self .M = cost_normalization (self .M , self . norm )
800800 self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
801801 self .computed = True
802802
@@ -808,8 +808,7 @@ class OTDA_l1l2(OTDA):
808808 """Class for domain adaptation with optimal transport with entropic
809809 and group lasso regularization"""
810810
811- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
812- ** kwargs ):
811+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
813812 """Fit regularized domain adaptation between samples is xs and xt
814813 (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
815814 parameters"""
@@ -825,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
825824 self .wt = wt
826825
827826 self .M = dist (xs , xt , metric = self .metric )
828- self .M = cost_normalization (self .M , norm )
827+ self .M = cost_normalization (self .M , self . norm )
829828 self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
830829 self .computed = True
831830
0 commit comments