@@ -10,6 +10,8 @@ Cython linker with C solver
1010import numpy as np
1111cimport numpy as np
1212
13+ from ..utils import dist
14+
1315cimport cython
1416
1517import warnings
@@ -99,7 +101,9 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
99101@ cython.wraparound (False )
100102def emd_1d_sorted (np.ndarray[double , ndim = 1 , mode = " c" ] u_weights,
101103 np.ndarray[double , ndim = 1 , mode = " c" ] v_weights,
102- np.ndarray[double , ndim = 2 , mode = " c" ] M):
104+ np.ndarray[double , ndim = 2 , mode = " c" ] u,
105+ np.ndarray[double , ndim = 2 , mode = " c" ] v,
106+ str metric = ' sqeuclidean' ):
103107 r """
104108 Roro's stuff
105109 """
@@ -112,17 +116,21 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
112116 cdef int j = 0
113117 cdef double w_j = v_weights[0 ]
114118
119+ cdef double m_ij = 0.
120+
115121 cdef np.ndarray[double , ndim= 2 , mode= " c" ] G = np.zeros((n, m),
116122 dtype = np.float64)
117123 while i < n and j < m:
124+ m_ij = dist(u[i].reshape((1 , 1 )), v[j].reshape((1 , 1 )),
125+ metric = metric)[0 , 0 ]
118126 if w_i < w_j or j == m - 1 :
119- cost += M[i, j] * w_i
127+ cost += m_ij * w_i
120128 G[i, j] = w_i
121129 i += 1
122130 w_j -= w_i
123131 w_i = u_weights[i]
124132 else :
125- cost += M[i, j] * w_j
133+ cost += m_ij * w_j
126134 G[i, j] = w_j
127135 j += 1
128136 w_i -= w_j
0 commit comments