1414from .import cvx
1515
1616# import compiled emd
17- from .emd_wrap import emd_c , check_result
17+ from .emd_wrap import emd_c , check_result , emd_1d_sorted
1818from ..utils import parmap
1919from .cvx import barycenter
2020from ..utils import dist
2121
22- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ]
22+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' , 'emd_1d_sorted' ]
2323
2424
2525def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False):
9494 b = np .asarray (b , dtype = np .float64 )
9595 M = np .asarray (M , dtype = np .float64 )
9696
97- # if empty array given then use unifor distributions
97+ # if empty array given then use uniform distributions
9898 if len (a ) == 0 :
9999 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
100100 if len (b ) == 0 :
@@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
187187 b = np .asarray (b , dtype = np .float64 )
188188 M = np .asarray (M , dtype = np .float64 )
189189
190- # if empty array given then use unifor distributions
190+ # if empty array given then use uniform distributions
191191 if len (a ) == 0 :
192192 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
193193 if len (b ) == 0 :
@@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
308308 log_dict ['displacement_square_norms' ] = displacement_square_norms
309309 return X , log_dict
310310 else :
311- return X
311+ return X
312+
313+
314+ def emd_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , log = False ):
315+ """Solves the Earth Movers distance problem between 1d measures and returns
316+ the OT matrix
317+
318+ """
319+ assert x_a .shape [1 ] == x_b .shape [1 ] == 1 , "emd_1d should only be used " + \
320+ "with monodimensional data"
321+
322+ a = np .asarray (a , dtype = np .float64 )
323+ b = np .asarray (b , dtype = np .float64 )
324+
325+ # if empty array given then use uniform distributions
326+ if len (a ) == 0 :
327+ a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
328+ if len (b ) == 0 :
329+ b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
330+
331+ perm_a = np .argsort (x_a .reshape ((- 1 , )))
332+ perm_b = np .argsort (x_b .reshape ((- 1 , )))
333+ inv_perm_a = np .argsort (perm_a )
334+ inv_perm_b = np .argsort (perm_b )
335+
336+ M = dist (x_a [perm_a ], x_b [perm_b ], metric = metric )
337+
338+ G_sorted , cost = emd_1d_sorted (a , b , M )
339+ G = G_sorted [inv_perm_a , :][:, inv_perm_b ]
340+ if log :
341+ log = {}
342+ log ['cost' ] = cost
343+ return G , log
344+ return G
0 commit comments