@@ -658,7 +658,7 @@ def __init__(self, metric='sqeuclidean'):
658658 self .metric = metric
659659 self .computed = False
660660
661- def fit (self , xs , xt , ws = None , wt = None , norm = None , numItermax = 10000 ):
661+ def fit (self , xs , xt , ws = None , wt = None , norm = None , max_iter = 100000 ):
662662 """Fit domain adaptation between samples in xs and xt
663663 (with optional weights)"""
664664 self .xs = xs
@@ -674,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
674674
675675 self .M = dist (xs , xt , metric = self .metric )
676676 self .normalizeM (norm )
677- self .G = emd (ws , wt , self .M , numItermax )
677+ self .G = emd (ws , wt , self .M , max_iter )
678678 self .computed = True
679679
680680 def interp (self , direction = 1 ):
@@ -1001,6 +1001,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10011001
10021002 # pairwise distance
10031003 self .cost_ = dist (Xs , Xt , metric = self .metric )
1004+ self .normalizeCost_ (self .norm )
10041005
10051006 if (ys is not None ) and (yt is not None ):
10061007
@@ -1182,6 +1183,26 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11821183
11831184 return transp_Xt
11841185
def normalizeCost_(self, norm):
    """Apply normalization to the loss matrix stored in ``self.cost_``.

    Parameters
    ----------
    norm : str or None
        Type of normalization, one of 'median', 'max', 'log', 'loglog'.
        If None, the loss matrix is left untouched.

    Raises
    ------
    ValueError
        If `norm` is neither None nor one of the supported names. An
        unknown name used to be ignored silently, which hid typos such
        as "Median" and left the cost matrix un-normalized.
    """
    if norm is None:
        return

    if norm == "median":
        # In-place division: any other reference to the same array also
        # sees the rescaled values.
        self.cost_ /= float(np.median(self.cost_))
    elif norm == "max":
        self.cost_ /= float(np.max(self.cost_))
    elif norm == "log":
        # These two branches rebind self.cost_ to a new array.
        self.cost_ = np.log(1 + self.cost_)
    elif norm == "loglog":
        self.cost_ = np.log(1 + np.log(1 + self.cost_))
    else:
        raise ValueError(
            "Unknown normalization %r; expected one of "
            "'median', 'max', 'log', 'loglog' or None" % (norm,)
        )
1205+
11851206
11861207class SinkhornTransport (BaseTransport ):
11871208 """Domain Adaptation OT method based on Sinkhorn Algorithm
@@ -1202,6 +1223,9 @@ class SinkhornTransport(BaseTransport):
12021223 be transported from a domain to another one.
12031224 metric : string, optional (default="sqeuclidean")
12041225 The ground metric for the Wasserstein problem
1226+ norm : string, optional (default=None)
1227+ If given, normalize the ground metric to avoid numerical errors that
1228+ can occur with large metric values.
12051229 distribution : string, optional (default="uniform")
12061230 The kind of distribution estimation to employ
12071231 verbose : int, optional (default=0)
@@ -1231,7 +1255,7 @@ class SinkhornTransport(BaseTransport):
12311255
12321256 def __init__ (self , reg_e = 1. , max_iter = 1000 ,
12331257 tol = 10e-9 , verbose = False , log = False ,
1234- metric = "sqeuclidean" ,
1258+ metric = "sqeuclidean" , norm = None ,
12351259 distribution_estimation = distribution_estimation_uniform ,
12361260 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
12371261
@@ -1241,6 +1265,7 @@ def __init__(self, reg_e=1., max_iter=1000,
12411265 self .verbose = verbose
12421266 self .log = log
12431267 self .metric = metric
1268+ self .norm = norm
12441269 self .limit_max = limit_max
12451270 self .distribution_estimation = distribution_estimation
12461271 self .out_of_sample_map = out_of_sample_map
@@ -1296,6 +1321,9 @@ class EMDTransport(BaseTransport):
12961321 be transported from a domain to another one.
12971322 metric : string, optional (default="sqeuclidean")
12981323 The ground metric for the Wasserstein problem
1324+ norm : string, optional (default=None)
1325+ If given, normalize the ground metric to avoid numerical errors that
1326+ can occur with large metric values.
12991327 distribution : string, optional (default="uniform")
13001328 The kind of distribution estimation to employ
13011329 verbose : int, optional (default=0)
@@ -1306,6 +1334,9 @@ class EMDTransport(BaseTransport):
13061334 Controls the semi supervised mode. Transport between labeled source
13071335 and target samples of different classes will exhibit an infinite cost
13081336 (10 times the maximum value of the cost matrix)
1337+ max_iter : int, optional (default=100000)
1338+ The maximum number of iterations before stopping the optimization
1339+ algorithm if it has not converged.
13091340
13101341 Attributes
13111342 ----------
@@ -1319,14 +1350,17 @@ class EMDTransport(BaseTransport):
13191350 on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13201351 """
13211352
1322- def __init__ (self , metric = "sqeuclidean" ,
1353+ def __init__ (self , metric = "sqeuclidean" , norm = None ,
13231354 distribution_estimation = distribution_estimation_uniform ,
1324- out_of_sample_map = 'ferradans' , limit_max = 10 ):
1355+ out_of_sample_map = 'ferradans' , limit_max = 10 ,
1356+ max_iter = 100000 ):
13251357
13261358 self .metric = metric
1359+ self .norm = norm
13271360 self .limit_max = limit_max
13281361 self .distribution_estimation = distribution_estimation
13291362 self .out_of_sample_map = out_of_sample_map
1363+ self .max_iter = max_iter
13301364
13311365 def fit (self , Xs , ys = None , Xt = None , yt = None ):
13321366 """Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1387,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13531387
13541388 # coupling estimation
13551389 self .coupling_ = emd (
1356- a = self .mu_s , b = self .mu_t , M = self .cost_ ,
1390+ a = self .mu_s , b = self .mu_t , M = self .cost_ , max_iter = self . max_iter
13571391 )
13581392
13591393 return self
@@ -1376,6 +1410,9 @@ class SinkhornLpl1Transport(BaseTransport):
13761410 be transported from a domain to another one.
13771411 metric : string, optional (default="sqeuclidean")
13781412 The ground metric for the Wasserstein problem
1413+ norm : string, optional (default=None)
1414+ If given, normalize the ground metric to avoid numerical errors that
1415+ can occur with large metric values.
13791416 distribution : string, optional (default="uniform")
13801417 The kind of distribution estimation to employ
13811418 max_iter : int, float, optional (default=10)
@@ -1410,7 +1447,7 @@ class SinkhornLpl1Transport(BaseTransport):
14101447 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
14111448 max_iter = 10 , max_inner_iter = 200 ,
14121449 tol = 10e-9 , verbose = False ,
1413- metric = "sqeuclidean" ,
1450+ metric = "sqeuclidean" , norm = None ,
14141451 distribution_estimation = distribution_estimation_uniform ,
14151452 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
14161453
@@ -1421,6 +1458,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14211458 self .tol = tol
14221459 self .verbose = verbose
14231460 self .metric = metric
1461+ self .norm = norm
14241462 self .distribution_estimation = distribution_estimation
14251463 self .out_of_sample_map = out_of_sample_map
14261464 self .limit_max = limit_max
@@ -1477,6 +1515,9 @@ class SinkhornL1l2Transport(BaseTransport):
14771515 be transported from a domain to another one.
14781516 metric : string, optional (default="sqeuclidean")
14791517 The ground metric for the Wasserstein problem
1518+ norm : string, optional (default=None)
1519+ If given, normalize the ground metric to avoid numerical errors that
1520+ can occur with large metric values.
14801521 distribution : string, optional (default="uniform")
14811522 The kind of distribution estimation to employ
14821523 max_iter : int, float, optional (default=10)
@@ -1516,7 +1557,7 @@ class SinkhornL1l2Transport(BaseTransport):
15161557 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
15171558 max_iter = 10 , max_inner_iter = 200 ,
15181559 tol = 10e-9 , verbose = False , log = False ,
1519- metric = "sqeuclidean" ,
1560+ metric = "sqeuclidean" , norm = None ,
15201561 distribution_estimation = distribution_estimation_uniform ,
15211562 out_of_sample_map = 'ferradans' , limit_max = 10 ):
15221563
@@ -1528,6 +1569,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15281569 self .verbose = verbose
15291570 self .log = log
15301571 self .metric = metric
1572+ self .norm = norm
15311573 self .distribution_estimation = distribution_estimation
15321574 self .out_of_sample_map = out_of_sample_map
15331575 self .limit_max = limit_max
@@ -1588,6 +1630,9 @@ class MappingTransport(BaseEstimator):
15881630 Estimate linear mapping with constant bias
15891631 metric : string, optional (default="sqeuclidean")
15901632 The ground metric for the Wasserstein problem
1633+ norm : string, optional (default=None)
1634+ If given, normalize the ground metric to avoid numerical errors that
1635+ can occur with large metric values.
15911636 kernel : string, optional (default="linear")
15921637 The kernel to use either linear or gaussian
15931638 sigma : float, optional (default=1)
@@ -1627,11 +1672,12 @@ class MappingTransport(BaseEstimator):
16271672 """
16281673
16291674 def __init__ (self , mu = 1 , eta = 0.001 , bias = False , metric = "sqeuclidean" ,
1630- kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1675+ norm = None , kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
16311676 max_inner_iter = 10 , inner_tol = 1e-6 , log = False , verbose = False ,
16321677 verbose2 = False ):
16331678
16341679 self .metric = metric
1680+ self .norm = norm
16351681 self .mu = mu
16361682 self .eta = eta
16371683 self .bias = bias
0 commit comments