@@ -15,13 +15,14 @@ cimport cython
1515
1616
1717cdef extern from " EMD.h" :
18- void EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * cost)
18+ int EMD_wrap(int n1,int n2, double * X, double * Y,double * D, double * G, double * cost, int numItermax)
19+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
1920
2021
2122
2223@ cython.boundscheck (False )
2324@ cython.wraparound (False )
24- def emd_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M):
25+ def emd_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int numItermax ):
2526 """
2627 Solves the Earth Movers distance problem and returns the optimal transport matrix
2728
@@ -48,6 +49,8 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
4849 target histogram
4950 M : (ns,nt) ndarray, float64
5051 loss matrix
52+ numItermax : int
53+ Maximum number of iterations made by the LP solver.
5154
5255
5356 Returns
@@ -69,13 +72,18 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
6972 b= np.ones((n2,))/ n2
7073
7174 # calling the function
72- EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data,< double * > & cost)
75+ cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data,< double * > & cost, numItermax)
76+ if resultSolver != OPTIMAL:
77+ if resultSolver == INFEASIBLE:
78+ print (" Problem infeasible. Try to inscrease numItermax." )
79+ elif resultSolver == UNBOUNDED:
80+ print (" Problem unbounded" )
7381
7482 return G
7583
7684@ cython.boundscheck (False )
7785@ cython.wraparound (False )
78- def emd2_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M):
86+ def emd2_c ( np.ndarray[double , ndim = 1 , mode = " c" ] a,np.ndarray[double , ndim = 1 , mode = " c" ] b,np.ndarray[double , ndim = 2 , mode = " c" ] M, int numItermax ):
7987 """
8088 Solves the Earth Movers distance problem and returns the optimal transport loss
8189
@@ -102,6 +110,8 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
102110 target histogram
103111 M : (ns,nt) ndarray, float64
104112 loss matrix
113+ numItermax : int
114+ Maximum number of iterations made by the LP solver.
105115
106116
107117 Returns
@@ -123,7 +133,12 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
123133 b= np.ones((n2,))/ n2
124134
125135 # calling the function
126- EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data,< double * > & cost)
136+ cdef int resultSolver = EMD_wrap(n1,n2,< double * > a.data,< double * > b.data,< double * > M.data,< double * > G.data,< double * > & cost, numItermax)
137+ if resultSolver != OPTIMAL:
138+ if resultSolver == INFEASIBLE:
139+ print (" Problem infeasible. Try to inscrease numItermax." )
140+ elif resultSolver == UNBOUNDED:
141+ print (" Problem unbounded" )
127142
128143 cost= 0
129144 for i in range (n1):
0 commit comments