1616from ..utils import parmap
1717
1818
19- def emd (a , b , M , num_iter_max = 100000 , log = False ):
19+ def emd (a , b , M , numItermax = 100000 , log = False ):
2020 """Solves the Earth Movers distance problem and returns the OT matrix
2121
2222
@@ -41,7 +41,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
4141 Target histogram (uniform weigth if empty list)
4242 M : (ns,nt) ndarray, float64
4343 loss matrix
44- num_iter_max : int, optional (default=100000)
44+ numItermax : int, optional (default=100000)
4545 The maximum number of iterations before stopping the optimization
4646 algorithm if it has not converged.
4747 log: boolean, optional (default=False)
@@ -94,7 +94,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
9494 if len (b ) == 0 :
9595 b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
9696
97- G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
97+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
9898 result_code_string = check_result (result_code )
9999 if log :
100100 log = {}
@@ -107,7 +107,7 @@ def emd(a, b, M, num_iter_max=100000, log=False):
107107 return G
108108
109109
110- def emd2 (a , b , M , processes = multiprocessing .cpu_count (), num_iter_max = 100000 , log = False , return_matrix = False ):
110+ def emd2 (a , b , M , processes = multiprocessing .cpu_count (), numItermax = 100000 , log = False , return_matrix = False ):
111111 """Solves the Earth Movers distance problem and returns the loss
112112
113113 .. math::
@@ -131,7 +131,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
131131 Target histogram (uniform weigth if empty list)
132132 M : (ns,nt) ndarray, float64
133133 loss matrix
134- num_iter_max : int, optional (default=100000)
134+ numItermax : int, optional (default=100000)
135135 The maximum number of iterations before stopping the optimization
136136 algorithm if it has not converged.
137137 log: boolean, optional (default=False)
@@ -188,7 +188,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), num_iter_max=100000, lo
188188
189189 if log or return_matrix :
190190 def f (b ):
191- G , cost , u , v , resultCode = emd_c (a , b , M , num_iter_max )
191+ G , cost , u , v , resultCode = emd_c (a , b , M , numItermax )
192192 result_code_string = check_result (resultCode )
193193 log = {}
194194 if return_matrix :
@@ -200,7 +200,7 @@ def f(b):
200200 return [cost , log ]
201201 else :
202202 def f (b ):
203- G , cost , u , v , result_code = emd_c (a , b , M , num_iter_max )
203+ G , cost , u , v , result_code = emd_c (a , b , M , numItermax )
204204 check_result (result_code )
205205 return cost
206206
0 commit comments