@@ -202,7 +202,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
202202 return center_ot_dual (alpha , beta , a , b )
203203
204204
205- def emd (a , b , M , numItermax = 100000 , log = False , center_dual = True , numThreads = 1 ):
205+ def emd (a , b , M , numItermax = 100000 , log = False , center_dual = True , numThreads = 1 , check_marginals = True ):
206206 r"""Solves the Earth Movers distance problem and returns the OT matrix
207207
208208
@@ -259,6 +259,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
259259 numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
260260 If compiled with OpenMP, chooses the number of threads to parallelize.
261261 "max" selects the highest number possible.
262+ check_marginals: bool, optional (default=True)
263+ If True, checks that the marginals mass are equal. If False, skips the
264+ check.
265+
262266
263267 Returns
264268 -------
@@ -328,9 +332,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
328332 "Dimension mismatch, check dimensions of M with a and b"
329333
330334 # ensure that same mass
331- np .testing .assert_almost_equal (a .sum (0 ),
332- b .sum (0 ), err_msg = 'a and b vector must have the same sum' ,
333- decimal = 6 )
335+ if check_marginals :
336+ np .testing .assert_almost_equal (a .sum (0 ),
337+ b .sum (0 ), err_msg = 'a and b vector must have the same sum' ,
338+ decimal = 6 )
334339 b = b * a .sum () / b .sum ()
335340
336341 asel = a != 0
@@ -368,7 +373,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
368373
369374def emd2 (a , b , M , processes = 1 ,
370375 numItermax = 100000 , log = False , return_matrix = False ,
371- center_dual = True , numThreads = 1 ):
376+ center_dual = True , numThreads = 1 , check_marginals = True ):
372377 r"""Solves the Earth Movers distance problem and returns the loss
373378
374379 .. math::
@@ -425,7 +430,11 @@ def emd2(a, b, M, processes=1,
425430 numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
426431 If compiled with OpenMP, chooses the number of threads to parallelize.
427432 "max" selects the highest number possible.
428-
433+ check_marginals: bool, optional (default=True)
434+ If True, checks that the marginals mass are equal. If False, skips the
435+ check.
436+
437+
429438 Returns
430439 -------
431440 W: float, array-like
@@ -492,8 +501,10 @@ def emd2(a, b, M, processes=1,
492501 "Dimension mismatch, check dimensions of M with a and b"
493502
494503 # ensure that same mass
495- np .testing .assert_almost_equal (a .sum (0 ),
496- b .sum (0 ,keepdims = True ), err_msg = 'a and b vector must have the same sum' )
504+ if check_marginals :
505+ np .testing .assert_almost_equal (a .sum (0 ),
506+ b .sum (0 ,keepdims = True ), err_msg = 'a and b vector must have the same sum' ,
507+ decimal = 6 )
497508 b = b * a .sum (0 ) / b .sum (0 ,keepdims = True )
498509
499510 asel = a != 0
0 commit comments