1010import multiprocessing
1111
1212import numpy as np
13+ from scipy .sparse import coo_matrix
1314
1415from .import cvx
1516
1920from .cvx import barycenter
2021from ..utils import dist
2122
22- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' , 'emd_1d_sorted' ]
23+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
24+ 'emd_1d' , 'emd2_1d' ]
2325
2426
2527def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -311,33 +313,57 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
311313 return X
312314
313315
314- def emd_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , log = False ):
316+ def emd_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , dense = True , log = False ):
315317 """Solves the Earth Movers distance problem between 1d measures and returns
316318 the OT matrix
317319
318320 """
319- assert x_a .shape [1 ] == x_b .shape [1 ] == 1 , "emd_1d should only be used " + \
320- "with monodimensional data"
321-
322321 a = np .asarray (a , dtype = np .float64 )
323322 b = np .asarray (b , dtype = np .float64 )
323+ x_a = np .asarray (x_a , dtype = np .float64 )
324+ x_b = np .asarray (x_b , dtype = np .float64 )
325+
326+ assert (x_a .ndim == 1 or x_a .ndim == 2 and x_a .shape [1 ] == 1 ), \
327+ "emd_1d should only be used with monodimensional data"
328+ assert (x_b .ndim == 1 or x_b .ndim == 2 and x_b .shape [1 ] == 1 ), \
329+ "emd_1d should only be used with monodimensional data"
324330
325331 # if empty array given then use uniform distributions
326332 if len (a ) == 0 :
327333 a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
328334 if len (b ) == 0 :
329335 b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
330336
331- perm_a = np .argsort (x_a .reshape ((- 1 , )))
332- perm_b = np .argsort (x_b .reshape ((- 1 , )))
333- inv_perm_a = np .argsort (perm_a )
334- inv_perm_b = np .argsort (perm_b )
335-
336- G_sorted , cost = emd_1d_sorted (a , b , x_a [perm_a ], x_b [perm_b ],
337- metric = metric )
338- G = G_sorted [inv_perm_a , :][:, inv_perm_b ]
337+ x_a_1d = x_a .reshape ((- 1 , ))
338+ x_b_1d = x_b .reshape ((- 1 , ))
339+ perm_a = np .argsort (x_a_1d )
340+ perm_b = np .argsort (x_b_1d )
341+
342+ G_sorted , indices , cost = emd_1d_sorted (a , b ,
343+ x_a_1d [perm_a ], x_b_1d [perm_b ],
344+ metric = metric )
345+ G = coo_matrix ((G_sorted , (perm_a [indices [:, 0 ]], perm_b [indices [:, 1 ]])),
346+ shape = (a .shape [0 ], b .shape [0 ]))
347+ if dense :
348+ G = G .todense ()
339349 if log :
340350 log = {}
341351 log ['cost' ] = cost
342352 return G , log
343- return G
353+ return G
354+
355+
356+ def emd2_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , dense = True , log = False ):
357+ """Solves the Earth Movers distance problem between 1d measures and returns
358+ the loss
359+
360+ """
361+ # If we do not return G (log==False), then we should not to cast it to dense
362+ # (useless overhead)
363+ G , log_emd = emd_1d (a = a , b = b , x_a = x_a , x_b = x_b , metric = metric ,
364+ dense = dense and log , log = True )
365+ cost = log_emd ['cost' ]
366+ if log :
367+ log_emd = {'G' : G }
368+ return cost , log_emd
369+ return cost
0 commit comments