@@ -1088,26 +1088,23 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10881088 # pairwise distance
10891089 self .Cost = dist (Xs , Xt , metric = self .metric )
10901090
1091- if self .mode == "semisupervised" :
1092-
1093- if (ys is not None ) and (yt is not None ):
1094-
1095- # assumes labeled source samples occupy the first rows
1096- # and labeled target samples occupy the first columns
1097- classes = np .unique (ys )
1098- for c in classes :
1099- ids = np .where (ys == c )
1100- idt = np .where (yt == c )
1101-
1102- # all the coefficients corresponding to a source sample
1103- # and a target sample with the same label gets a 0
1104- # transport cost
1105- for j in idt [0 ]:
1106- self .Cost [ids [0 ], j ] = 0
1107- else :
1108- print ("Warning: using unsupervised mode\
1109- \n to use semisupervised mode, please provide ys and yt" )
1110- pass
1091+ if (ys is not None ) and (yt is not None ):
1092+
1093+ if self .limit_max != np .infty :
1094+ self .limit_max = self .limit_max * np .max (self .Cost )
1095+
1096+ # assumes labeled source samples occupy the first rows
1097+ # and labeled target samples occupy the first columns
1098+ classes = np .unique (ys )
1099+ for c in classes :
1100+ idx_s = np .where ((ys != c ) & (ys != - 1 ))
1101+ idx_t = np .where (yt == c )
1102+
1103+ # all the coefficients corresponding to a source sample
1104+ # and a target sample :
1105+ # with different labels get a cost of limit_max
1106+ for j in idx_t [0 ]:
1107+ self .Cost [idx_s [0 ], j ] = self .limit_max
11111108
11121109 # distribution estimation
11131110 self .mu_s = self .distribution_estimation (Xs )
@@ -1243,6 +1240,9 @@ class SinkhornTransport(BaseTransport):
12431240 Controls the verbosity of the optimization algorithm
12441241 log : int, optional (default=0)
12451242 Controls the logs of the optimization algorithm
1243+ limit_max: float, optional (default=np.infty)
1244+ Controls the semi supervised mode. Transport between labeled source
1245+ and target samples of different classes will exhibit an infinite cost
12461246 Attributes
12471247 ----------
12481248 Coupling_ : the optimal coupling
@@ -1257,19 +1257,19 @@ class SinkhornTransport(BaseTransport):
12571257 26, 2013
12581258 """
12591259
1260- def __init__ (self , reg_e = 1. , mode = "unsupervised" , max_iter = 1000 ,
1260+ def __init__ (self , reg_e = 1. , max_iter = 1000 ,
12611261 tol = 10e-9 , verbose = False , log = False ,
12621262 metric = "sqeuclidean" ,
12631263 distribution_estimation = distribution_estimation_uniform ,
1264- out_of_sample_map = 'ferradans' ):
1264+ out_of_sample_map = 'ferradans' , limit_max = np . infty ):
12651265
12661266 self .reg_e = reg_e
1267- self .mode = mode
12681267 self .max_iter = max_iter
12691268 self .tol = tol
12701269 self .verbose = verbose
12711270 self .log = log
12721271 self .metric = metric
1272+ self .limit_max = limit_max
12731273 self .distribution_estimation = distribution_estimation
12741274 self .out_of_sample_map = out_of_sample_map
12751275
@@ -1326,6 +1326,10 @@ class EMDTransport(BaseTransport):
13261326 Controls the verbosity of the optimization algorithm
13271327 log : int, optional (default=0)
13281328 Controls the logs of the optimization algorithm
1329+ limit_max: float, optional (default=10)
1330+ Controls the semi supervised mode. Transport between labeled source
1331+ and target samples of different classes will exhibit an infinite cost
1332+ (10 times the maximum value of the cost matrix)
13291333 Attributes
13301334 ----------
13311335 Coupling_ : the optimal coupling
@@ -1337,15 +1341,15 @@ class EMDTransport(BaseTransport):
13371341 on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13381342 """
13391343
1340- def __init__ (self , mode = "unsupervised" , verbose = False ,
1344+ def __init__ (self , verbose = False ,
13411345 log = False , metric = "sqeuclidean" ,
13421346 distribution_estimation = distribution_estimation_uniform ,
1343- out_of_sample_map = 'ferradans' ):
1347+ out_of_sample_map = 'ferradans' , limit_max = 10 ):
13441348
1345- self .mode = mode
13461349 self .verbose = verbose
13471350 self .log = log
13481351 self .metric = metric
1352+ self .limit_max = limit_max
13491353 self .distribution_estimation = distribution_estimation
13501354 self .out_of_sample_map = out_of_sample_map
13511355
@@ -1414,6 +1418,10 @@ class SinkhornLpl1Transport(BaseTransport):
14141418 Controls the verbosity of the optimization algorithm
14151419 log : int, optional (default=0)
14161420 Controls the logs of the optimization algorithm
1421+ limit_max: float, optional (default=np.infty)
1422+ Controls the semi supervised mode. Transport between labeled source
1423+ and target samples of different classes will exhibit an infinite cost
1424+
14171425 Attributes
14181426 ----------
14191427 Coupling_ : the optimal coupling
@@ -1431,16 +1439,15 @@ class SinkhornLpl1Transport(BaseTransport):
14311439
14321440 """
14331441
1434- def __init__ (self , reg_e = 1. , reg_cl = 0.1 , mode = "unsupervised" ,
1442+ def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
14351443 max_iter = 10 , max_inner_iter = 200 ,
14361444 tol = 10e-9 , verbose = False , log = False ,
14371445 metric = "sqeuclidean" ,
14381446 distribution_estimation = distribution_estimation_uniform ,
1439- out_of_sample_map = 'ferradans' ):
1447+ out_of_sample_map = 'ferradans' , limit_max = np . infty ):
14401448
14411449 self .reg_e = reg_e
14421450 self .reg_cl = reg_cl
1443- self .mode = mode
14441451 self .max_iter = max_iter
14451452 self .max_inner_iter = max_inner_iter
14461453 self .tol = tol
@@ -1449,6 +1456,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
14491456 self .metric = metric
14501457 self .distribution_estimation = distribution_estimation
14511458 self .out_of_sample_map = out_of_sample_map
1459+ self .limit_max = limit_max
14521460
14531461 def fit (self , Xs , ys = None , Xt = None , yt = None ):
14541462 """Build a coupling matrix from source and target sets of samples
@@ -1514,6 +1522,11 @@ class SinkhornL1l2Transport(BaseTransport):
15141522 Controls the verbosity of the optimization algorithm
15151523 log : int, optional (default=0)
15161524 Controls the logs of the optimization algorithm
1525+ limit_max: float, optional (default=10)
1526+ Controls the semi supervised mode. Transport between labeled source
1527+ and target samples of different classes will exhibit an infinite cost
1528+ (10 times the maximum value of the cost matrix)
1529+
15171530 Attributes
15181531 ----------
15191532 Coupling_ : the optimal coupling
@@ -1531,16 +1544,15 @@ class SinkhornL1l2Transport(BaseTransport):
15311544
15321545 """
15331546
1534- def __init__ (self , reg_e = 1. , reg_cl = 0.1 , mode = "unsupervised" ,
1547+ def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
15351548 max_iter = 10 , max_inner_iter = 200 ,
15361549 tol = 10e-9 , verbose = False , log = False ,
15371550 metric = "sqeuclidean" ,
15381551 distribution_estimation = distribution_estimation_uniform ,
1539- out_of_sample_map = 'ferradans' ):
1552+ out_of_sample_map = 'ferradans' , limit_max = 10 ):
15401553
15411554 self .reg_e = reg_e
15421555 self .reg_cl = reg_cl
1543- self .mode = mode
15441556 self .max_iter = max_iter
15451557 self .max_inner_iter = max_inner_iter
15461558 self .tol = tol
@@ -1549,6 +1561,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
15491561 self .metric = metric
15501562 self .distribution_estimation = distribution_estimation
15511563 self .out_of_sample_map = out_of_sample_map
1564+ self .limit_max = limit_max
15521565
15531566 def fit (self , Xs , ys = None , Xt = None , yt = None ):
15541567 """Build a coupling matrix from source and target sets of samples
0 commit comments