1313
1414from .bregman import sinkhorn
1515from .lp import emd
16- from .utils import unif , dist , kernel
16+ from .utils import unif , dist , kernel , cost_normalization
1717from .utils import check_params , deprecated , BaseEstimator
1818from .optim import cg
1919from .optim import gcg
@@ -650,15 +650,16 @@ class OTDA(object):
650650
651651 """
652652
653- def __init__ (self , metric = 'sqeuclidean' ):
653+ def __init__ (self , metric = 'sqeuclidean' , norm = None ):
654654 """ Class initialization"""
655655 self .xs = 0
656656 self .xt = 0
657657 self .G = 0
658658 self .metric = metric
659+ self .norm = norm
659660 self .computed = False
660661
661- def fit (self , xs , xt , ws = None , wt = None , norm = None ):
662+ def fit (self , xs , xt , ws = None , wt = None , max_iter = 100000 ):
662663 """Fit domain adaptation between samples is xs and xt
663664 (with optional weights)"""
664665 self .xs = xs
@@ -673,8 +674,8 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None):
673674 self .wt = wt
674675
675676 self .M = dist (xs , xt , metric = self .metric )
676- self .normalizeM ( norm )
677- self .G = emd (ws , wt , self .M )
677+ self .M = cost_normalization ( self . M , self . norm )
678+ self .G = emd (ws , wt , self .M , max_iter )
678679 self .computed = True
679680
680681 def interp (self , direction = 1 ):
@@ -741,26 +742,6 @@ def predict(self, x, direction=1):
741742 # aply the delta to the interpolation
742743 return xf [idx , :] + x - x0 [idx , :]
743744
744- def normalizeM (self , norm ):
745- """ Apply normalization to the loss matrix
746-
747-
748- Parameters
749- ----------
750- norm : str
751- type of normalization from 'median','max','log','loglog'
752-
753- """
754-
755- if norm == "median" :
756- self .M /= float (np .median (self .M ))
757- elif norm == "max" :
758- self .M /= float (np .max (self .M ))
759- elif norm == "log" :
760- self .M = np .log (1 + self .M )
761- elif norm == "loglog" :
762- self .M = np .log (1 + np .log (1 + self .M ))
763-
764745
765746@deprecated ("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766747 " removed in 0.5 \n Use class SinkhornTransport instead." )
@@ -772,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
772753
773754 """
774755
775- def fit (self , xs , xt , reg = 1 , ws = None , wt = None , norm = None , ** kwargs ):
756+ def fit (self , xs , xt , reg = 1 , ws = None , wt = None , ** kwargs ):
776757 """Fit regularized domain adaptation between samples is xs and xt
777758 (with optional weights)"""
778759 self .xs = xs
@@ -787,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787768 self .wt = wt
788769
789770 self .M = dist (xs , xt , metric = self .metric )
790- self .normalizeM ( norm )
771+ self .M = cost_normalization ( self . M , self . norm )
791772 self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
792773 self .computed = True
793774
@@ -799,8 +780,7 @@ class OTDA_lpl1(OTDA):
799780 """Class for domain adaptation with optimal transport with entropic and
800781 group regularization"""
801782
802- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
803- ** kwargs ):
783+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
804784 """Fit regularized domain adaptation between samples is xs and xt
805785 (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
806786 parameters"""
@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816796 self .wt = wt
817797
818798 self .M = dist (xs , xt , metric = self .metric )
819- self .normalizeM ( norm )
799+ self .M = cost_normalization ( self . M , self . norm )
820800 self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
821801 self .computed = True
822802
@@ -828,8 +808,7 @@ class OTDA_l1l2(OTDA):
828808 """Class for domain adaptation with optimal transport with entropic
829809 and group lasso regularization"""
830810
831- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
832- ** kwargs ):
811+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
833812 """Fit regularized domain adaptation between samples is xs and xt
834813 (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
835814 parameters"""
@@ -845,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845824 self .wt = wt
846825
847826 self .M = dist (xs , xt , metric = self .metric )
848- self .normalizeM ( norm )
827+ self .M = cost_normalization ( self . M , self . norm )
849828 self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
850829 self .computed = True
851830
@@ -1001,6 +980,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001980
1002981 # pairwise distance
1003982 self .cost_ = dist (Xs , Xt , metric = self .metric )
983+ self .cost_ = cost_normalization (self .cost_ , self .norm )
1004984
1005985 if (ys is not None ) and (yt is not None ):
1006986
@@ -1202,6 +1182,9 @@ class SinkhornTransport(BaseTransport):
12021182 be transported from a domain to another one.
12031183 metric : string, optional (default="sqeuclidean")
12041184 The ground metric for the Wasserstein problem
1185+ norm : string, optional (default=None)
1186+ If given, normalize the ground metric to avoid numerical errors that
1187+ can occur with large metric values.
12051188 distribution : string, optional (default="uniform")
12061189 The kind of distribution estimation to employ
12071190 verbose : int, optional (default=0)
@@ -1231,7 +1214,7 @@ class SinkhornTransport(BaseTransport):
12311214
12321215 def __init__ (self , reg_e = 1. , max_iter = 1000 ,
12331216 tol = 10e-9 , verbose = False , log = False ,
1234- metric = "sqeuclidean" ,
1217+ metric = "sqeuclidean" , norm = None ,
12351218 distribution_estimation = distribution_estimation_uniform ,
12361219 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
12371220
@@ -1241,6 +1224,7 @@ def __init__(self, reg_e=1., max_iter=1000,
12411224 self .verbose = verbose
12421225 self .log = log
12431226 self .metric = metric
1227+ self .norm = norm
12441228 self .limit_max = limit_max
12451229 self .distribution_estimation = distribution_estimation
12461230 self .out_of_sample_map = out_of_sample_map
@@ -1296,6 +1280,9 @@ class EMDTransport(BaseTransport):
12961280 be transported from a domain to another one.
12971281 metric : string, optional (default="sqeuclidean")
12981282 The ground metric for the Wasserstein problem
1283+ norm : string, optional (default=None)
1284+ If given, normalize the ground metric to avoid numerical errors that
1285+ can occur with large metric values.
12991286 distribution : string, optional (default="uniform")
13001287 The kind of distribution estimation to employ
13011288 verbose : int, optional (default=0)
@@ -1306,6 +1293,9 @@ class EMDTransport(BaseTransport):
13061293 Controls the semi supervised mode. Transport between labeled source
13071294 and target samples of different classes will exhibit an infinite cost
13081295 (10 times the maximum value of the cost matrix)
1296+ max_iter : int, optional (default=100000)
1297+ The maximum number of iterations before stopping the optimization
1298+ algorithm if it has not converged.
13091299
13101300 Attributes
13111301 ----------
@@ -1319,14 +1309,17 @@ class EMDTransport(BaseTransport):
13191309 on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13201310 """
13211311
1322- def __init__ (self , metric = "sqeuclidean" ,
1312+ def __init__ (self , metric = "sqeuclidean" , norm = None ,
13231313 distribution_estimation = distribution_estimation_uniform ,
1324- out_of_sample_map = 'ferradans' , limit_max = 10 ):
1314+ out_of_sample_map = 'ferradans' , limit_max = 10 ,
1315+ max_iter = 100000 ):
13251316
13261317 self .metric = metric
1318+ self .norm = norm
13271319 self .limit_max = limit_max
13281320 self .distribution_estimation = distribution_estimation
13291321 self .out_of_sample_map = out_of_sample_map
1322+ self .max_iter = max_iter
13301323
13311324 def fit (self , Xs , ys = None , Xt = None , yt = None ):
13321325 """Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1346,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13531346
13541347 # coupling estimation
13551348 self .coupling_ = emd (
1356- a = self .mu_s , b = self .mu_t , M = self .cost_ ,
1349+ a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self . max_iter
13571350 )
13581351
13591352 return self
@@ -1376,6 +1369,9 @@ class SinkhornLpl1Transport(BaseTransport):
13761369 be transported from a domain to another one.
13771370 metric : string, optional (default="sqeuclidean")
13781371 The ground metric for the Wasserstein problem
1372+ norm : string, optional (default=None)
1373+ If given, normalize the ground metric to avoid numerical errors that
1374+ can occur with large metric values.
13791375 distribution : string, optional (default="uniform")
13801376 The kind of distribution estimation to employ
13811377 max_iter : int, float, optional (default=10)
@@ -1410,7 +1406,7 @@ class SinkhornLpl1Transport(BaseTransport):
14101406 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
14111407 max_iter = 10 , max_inner_iter = 200 ,
14121408 tol = 10e-9 , verbose = False ,
1413- metric = "sqeuclidean" ,
1409+ metric = "sqeuclidean" , norm = None ,
14141410 distribution_estimation = distribution_estimation_uniform ,
14151411 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
14161412
@@ -1421,6 +1417,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14211417 self .tol = tol
14221418 self .verbose = verbose
14231419 self .metric = metric
1420+ self .norm = norm
14241421 self .distribution_estimation = distribution_estimation
14251422 self .out_of_sample_map = out_of_sample_map
14261423 self .limit_max = limit_max
@@ -1477,6 +1474,9 @@ class SinkhornL1l2Transport(BaseTransport):
14771474 be transported from a domain to another one.
14781475 metric : string, optional (default="sqeuclidean")
14791476 The ground metric for the Wasserstein problem
1477+ norm : string, optional (default=None)
1478+ If given, normalize the ground metric to avoid numerical errors that
1479+ can occur with large metric values.
14801480 distribution : string, optional (default="uniform")
14811481 The kind of distribution estimation to employ
14821482 max_iter : int, float, optional (default=10)
@@ -1516,7 +1516,7 @@ class SinkhornL1l2Transport(BaseTransport):
15161516 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
15171517 max_iter = 10 , max_inner_iter = 200 ,
15181518 tol = 10e-9 , verbose = False , log = False ,
1519- metric = "sqeuclidean" ,
1519+ metric = "sqeuclidean" , norm = None ,
15201520 distribution_estimation = distribution_estimation_uniform ,
15211521 out_of_sample_map = 'ferradans' , limit_max = 10 ):
15221522
@@ -1528,6 +1528,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15281528 self .verbose = verbose
15291529 self .log = log
15301530 self .metric = metric
1531+ self .norm = norm
15311532 self .distribution_estimation = distribution_estimation
15321533 self .out_of_sample_map = out_of_sample_map
15331534 self .limit_max = limit_max
@@ -1588,6 +1589,9 @@ class MappingTransport(BaseEstimator):
15881589 Estimate linear mapping with constant bias
15891590 metric : string, optional (default="sqeuclidean")
15901591 The ground metric for the Wasserstein problem
1592+ norm : string, optional (default=None)
1593+ If given, normalize the ground metric to avoid numerical errors that
1594+ can occur with large metric values.
15911595 kernel : string, optional (default="linear")
15921596 The kernel to use either linear or gaussian
15931597 sigma : float, optional (default=1)
@@ -1627,11 +1631,12 @@ class MappingTransport(BaseEstimator):
16271631 """
16281632
16291633 def __init__ (self , mu = 1 , eta = 0.001 , bias = False , metric = "sqeuclidean" ,
1630- kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1634+ norm = None , kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
16311635 max_inner_iter = 10 , inner_tol = 1e-6 , log = False , verbose = False ,
16321636 verbose2 = False ):
16331637
16341638 self .metric = metric
1639+ self .norm = norm
16351640 self .mu = mu
16361641 self .eta = eta
16371642 self .bias = bias
0 commit comments