@@ -1417,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
14171417 Sciences, 7(3), 1853-1882.
14181418 """
14191419
1420- def __init__ (self , reg_e = 1. , max_iter = 1000 ,
1420+ def __init__ (self , reg_e = 1. , method = "sinkhorn" , max_iter = 1000 ,
14211421 tol = 10e-9 , verbose = False , log = False ,
14221422 metric = "sqeuclidean" , norm = None ,
14231423 distribution_estimation = distribution_estimation_uniform ,
14241424 out_of_sample_map = 'ferradans' , limit_max = np .infty ):
14251425 self .reg_e = reg_e
1426+ self .method = method
14261427 self .max_iter = max_iter
14271428 self .tol = tol
14281429 self .verbose = verbose
@@ -1463,7 +1464,7 @@ class label
14631464 # coupling estimation
14641465 returned_ = sinkhorn (
14651466 a = self .mu_s , b = self .mu_t , M = self .cost_ , reg = self .reg_e ,
1466- numItermax = self .max_iter , stopThr = self .tol ,
1467+ method = self . method , numItermax = self .max_iter , stopThr = self .tol ,
14671468 verbose = self .verbose , log = self .log )
14681469
14691470 # deal with the value of log
0 commit comments