1414import multiprocessing
1515
1616
17- def emd (a , b , M , max_iter = 100000 ):
17+ def emd (a , b , M , numItermax = 100000 ):
1818 """Solves the Earth Movers distance problem and returns the OT matrix
1919
2020
@@ -39,7 +39,7 @@ def emd(a, b, M, max_iter=100000):
3939 Target histogram (uniform weigth if empty list)
4040 M : (ns,nt) ndarray, float64
4141 loss matrix
42- max_iter : int, optional (default=100000)
42+ numItermax : int, optional (default=100000)
4343 The maximum number of iterations before stopping the optimization
4444 algorithm if it has not converged.
4545
@@ -86,10 +86,10 @@ def emd(a, b, M, max_iter=100000):
8686 if len (b ) == 0 :
8787 b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
8888
89- return emd_c (a , b , M , max_iter )
89+ return emd_c (a , b , M , numItermax )
9090
9191
92- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), max_iter = 100000 ):
92+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 ):
9393 """Solves the Earth Movers distance problem and returns the loss
9494
9595 .. math::
@@ -113,7 +113,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
113113 Target histogram (uniform weigth if empty list)
114114 M : (ns,nt) ndarray, float64
115115 loss matrix
116- max_iter : int, optional (default=100000)
116+ numItermax : int, optional (default=100000)
117117 The maximum number of iterations before stopping the optimization
118118 algorithm if it has not converged.
119119
@@ -161,12 +161,12 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
161161 b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
162162
163163 if len (b .shape ) == 1 :
164- return emd2_c (a , b , M , max_iter )
164+ return emd2_c (a , b , M , numItermax )
165165 else :
166166 nb = b .shape [1 ]
167- # res = [emd2_c(a, b[:, i].copy(), M, max_iter ) for i in range(nb)]
167+ # res = [emd2_c(a, b[:, i].copy(), M, numItermax ) for i in range(nb)]
168168
169169 def f (b ):
170- return emd2_c (a , b , M , max_iter )
170+ return emd2_c (a , b , M , numItermax )
171171 res = parmap (f , [b [:, i ] for i in range (nb )], processes )
172172 return np .array (res )
0 commit comments