@@ -313,7 +313,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313 return X
314314
315315
316- def emd_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
316+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
317317 """Solves the Earth Movers distance problem between 1d measures and returns
318318 the OT matrix
319319
@@ -338,10 +338,10 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
338338 Source dirac locations (on the real line)
339339 x_b : (nt,) or (ns, 1) ndarray, float64
340340 Target dirac locations (on the real line)
341- a : (ns,) ndarray, float64
342- Source histogram (uniform weight if empty list )
343- b : (nt,) ndarray, float64
344- Target histogram (uniform weight if empty list )
341+ a : (ns,) ndarray, float64, optional
342+ Source histogram (default is uniform weight )
343+ b : (nt,) ndarray, float64, optional
344+ Target histogram (default is uniform weight )
345345 metric: str, optional (default='sqeuclidean')
346346 Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347347 Due to implementation details, this function runs faster when
@@ -375,6 +375,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
375375 >>> x_a = [2., 0.]
376376 >>> x_b = [0., 3.]
377377 >>> ot.emd_1d(x_a, x_b, a, b)
378+ array([[0. , 0.5],
379+ [0.5, 0. ]])
380+ >>> ot.emd_1d(x_a, x_b)
378381 array([[0. , 0.5],
379382 [0.5, 0. ]])
380383
@@ -401,9 +404,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
401404 "emd_1d should only be used with monodimensional data"
402405
403406 # if empty array given then use uniform distributions
404- if len (a ) == 0 :
407+ if a . ndim == 0 or len (a ) == 0 :
405408 a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
406- if len (b ) == 0 :
409+ if b . ndim == 0 or len (b ) == 0 :
407410 b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
408411
409412 x_a_1d = x_a .reshape ((- 1 , ))
@@ -424,7 +427,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
424427 return G
425428
426429
427- def emd2_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
430+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
428431 """Solves the Earth Movers distance problem between 1d measures and returns
429432 the loss
430433
@@ -449,10 +452,10 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
449452 Source dirac locations (on the real line)
450453 x_b : (nt,) or (ns, 1) ndarray, float64
451454 Target dirac locations (on the real line)
452- a : (ns,) ndarray, float64
453- Source histogram (uniform weight if empty list )
454- b : (nt,) ndarray, float64
455- Target histogram (uniform weight if empty list )
455+ a : (ns,) ndarray, float64, optional
456+ Source histogram (default is uniform weight )
457+ b : (nt,) ndarray, float64, optional
458+ Target histogram (default is uniform weight )
456459 metric: str, optional (default='sqeuclidean')
457460 Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
458461 Due to implementation details, this function runs faster when
@@ -488,6 +491,8 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
488491 >>> x_b = [0., 3.]
489492 >>> ot.emd2_1d(x_a, x_b, a, b)
490493 0.5
494+ >>> ot.emd2_1d(x_a, x_b)
495+ 0.5
491496
492497 References
493498 ----------
0 commit comments