@@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):
13151315
13161316 Attributes
13171317 ----------
1318- coupling_ : the optimal coupling
1318+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1319+ The optimal coupling
1320+ log_ : dictionary
1321+ The dictionary of log, empty dic if parameter log is not True
13191322
13201323 References
13211324 ----------
@@ -1367,11 +1370,18 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
13671370 super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
13681371
13691372 # coupling estimation
1370- self . coupling_ = sinkhorn (
1373+ returned_ = sinkhorn (
13711374 a = self .mu_s , b = self .mu_t , M = self .cost_ , reg = self .reg_e ,
13721375 numItermax = self .max_iter , stopThr = self .tol ,
13731376 verbose = self .verbose , log = self .log )
13741377
1378+ # deal with the value of log
1379+ if self .log :
1380+ self .coupling_ , self .log_ = returned_
1381+ else :
1382+ self .coupling_ = returned_
1383+ self .log_ = dict ()
1384+
13751385 return self
13761386
13771387
@@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):
14001410
14011411 Attributes
14021412 ----------
1403- coupling_ : the optimal coupling
1413+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1414+ The optimal coupling
14041415
14051416 References
14061417 ----------
@@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
14751486 The number of iteration in the inner loop
14761487 verbose : int, optional (default=0)
14771488 Controls the verbosity of the optimization algorithm
1478- log : int, optional (default=0)
1479- Controls the logs of the optimization algorithm
14801489 limit_max: float, optional (defaul=np.infty)
14811490 Controls the semi supervised mode. Transport between labeled source
14821491 and target samples of different classes will exhibit an infinite cost
14831492
14841493 Attributes
14851494 ----------
1486- coupling_ : the optimal coupling
1495+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1496+ The optimal coupling
14871497
14881498 References
14891499 ----------
@@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):
15001510
15011511 def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
15021512 max_iter = 10 , max_inner_iter = 200 ,
1503- tol = 10e-9 , verbose = False , log = False ,
1513+ tol = 10e-9 , verbose = False ,
15041514 metric = "sqeuclidean" ,
15051515 distribution_estimation = distribution_estimation_uniform ,
15061516 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
@@ -1511,7 +1521,6 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15111521 self .max_inner_iter = max_inner_iter
15121522 self .tol = tol
15131523 self .verbose = verbose
1514- self .log = log
15151524 self .metric = metric
15161525 self .distribution_estimation = distribution_estimation
15171526 self .out_of_sample_map = out_of_sample_map
@@ -1544,7 +1553,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15441553 a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
15451554 reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
15461555 numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1547- verbose = self .verbose , log = self . log )
1556+ verbose = self .verbose )
15481557
15491558 return self
15501559
@@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):
15841593
15851594 Attributes
15861595 ----------
1587- coupling_ : the optimal coupling
1596+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1597+ The optimal coupling
1598+ log_ : dictionary
1599+ The dictionary of log, empty dic if parameter log is not True
15881600
15891601 References
15901602 ----------
@@ -1641,12 +1653,19 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
16411653
16421654 super (SinkhornL1l2Transport , self ).fit (Xs , ys , Xt , yt )
16431655
1644- self . coupling_ = sinkhorn_l1l2_gl (
1656+ returned_ = sinkhorn_l1l2_gl (
16451657 a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
16461658 reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
16471659 numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
16481660 verbose = self .verbose , log = self .log )
16491661
1662+ # deal with the value of log
1663+ if self .log :
1664+ self .coupling_ , self .log_ = returned_
1665+ else :
1666+ self .coupling_ = returned_
1667+ self .log_ = dict ()
1668+
16501669 return self
16511670
16521671
@@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):
16831702
16841703 Attributes
16851704 ----------
1686- coupling_ : array-like, shape (n_source_samples, n_features )
1705+ coupling_ : array-like, shape (n_source_samples, n_target_samples )
16871706 The optimal coupling
16881707 mapping_ : array-like, shape (n_features (+ 1), n_features)
16891708 (if bias) for kernel == linear
16901709 The associated mapping
1691-
16921710 array-like, shape (n_source_samples (+ 1), n_features)
16931711 (if bias) for kernel == gaussian
1712+ log_ : dictionary
1713+ The dictionary of log, empty dic if parameter log is not True
16941714
16951715 References
16961716 ----------
@@ -1745,19 +1765,26 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
17451765 self .Xt = Xt
17461766
17471767 if self .kernel == "linear" :
1748- self . coupling_ , self . mapping_ = joint_OT_mapping_linear (
1768+ returned_ = joint_OT_mapping_linear (
17491769 Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
17501770 verbose = self .verbose , verbose2 = self .verbose2 ,
17511771 numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
17521772 stopThr = self .tol , stopInnerThr = self .inner_tol , log = self .log )
17531773
17541774 elif self .kernel == "gaussian" :
1755- self . coupling_ , self . mapping_ = joint_OT_mapping_kernel (
1775+ returned_ = joint_OT_mapping_kernel (
17561776 Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
17571777 sigma = self .sigma , verbose = self .verbose , verbose2 = self .verbose ,
17581778 numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
17591779 stopInnerThr = self .inner_tol , stopThr = self .tol , log = self .log )
17601780
1781+ # deal with the value of log
1782+ if self .log :
1783+ self .coupling_ , self .mapping_ , self .log_ = returned_
1784+ else :
1785+ self .coupling_ , self .mapping_ = returned_
1786+ self .log_ = dict ()
1787+
17611788 return self
17621789
17631790 def transform (self , Xs ):
0 commit comments