Skip to content

Commit 3cecc18

Browse files
committed
Changes to LP solver:
- Allow modifying the maximum number of iterations
- Display an error message in the Python console if the solver encounters an issue
1 parent a2ec6e5 commit 3cecc18

File tree

5 files changed

+41
-18
lines changed

5 files changed

+41
-18
lines changed

ot/da.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def __init__(self, metric='sqeuclidean'):
658658
self.metric = metric
659659
self.computed = False
660660

661-
def fit(self, xs, xt, ws=None, wt=None, norm=None):
661+
def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
662662
"""Fit domain adaptation between samples is xs and xt
663663
(with optional weights)"""
664664
self.xs = xs
@@ -674,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None):
674674

675675
self.M = dist(xs, xt, metric=self.metric)
676676
self.normalizeM(norm)
677-
self.G = emd(ws, wt, self.M)
677+
self.G = emd(ws, wt, self.M, numItermax)
678678
self.computed = True
679679

680680
def interp(self, direction=1):

ot/lp/EMD.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
using namespace lemon;
2424
typedef unsigned int node_id_type;
2525

26+
enum ProblemType {
27+
INFEASIBLE,
28+
OPTIMAL,
29+
UNBOUNDED
30+
};
2631

27-
void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax);
2833

2934
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
#include "EMD.h"
1616

1717

18-
void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) {
18+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax) {
1919
// beware M and C anre strored in row major C style!!!
2020
int n, m, i,cur;
2121
double max;
22-
int max_iter=10000;
2322

2423
typedef FullBipartiteDigraph Digraph;
2524
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
@@ -46,7 +45,7 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *
4645
std::vector<int> indI(n), indJ(m);
4746
std::vector<double> weights1(n), weights2(m);
4847
Digraph di(n, m);
49-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m,max_iter);
48+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, numItermax);
5049

5150
// Set supply and demand, don't account for 0 values (faster)
5251

@@ -116,5 +115,5 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *
116115
};
117116

118117

119-
118+
return ret;
120119
}

ot/lp/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717

18-
def emd(a, b, M):
18+
def emd(a, b, M, numItermax=10000):
1919
"""Solves the Earth Movers distance problem and returns the OT matrix
2020
2121
@@ -40,6 +40,8 @@ def emd(a, b, M):
4040
Target histogram (uniform weigth if empty list)
4141
M : (ns,nt) ndarray, float64
4242
loss matrix
43+
numItermax : int
44+
Maximum number of iterations made by the LP solver.
4345
4446
Returns
4547
-------
@@ -84,9 +86,9 @@ def emd(a, b, M):
8486
if len(b) == 0:
8587
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
8688

87-
return emd_c(a, b, M)
89+
return emd_c(a, b, M, numItermax)
8890

89-
def emd2(a, b, M,processes=multiprocessing.cpu_count()):
91+
def emd2(a, b, M, numItermax=10000, processes=multiprocessing.cpu_count()):
9092
"""Solves the Earth Movers distance problem and returns the loss
9193
9294
.. math::
@@ -110,6 +112,8 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
110112
Target histogram (uniform weigth if empty list)
111113
M : (ns,nt) ndarray, float64
112114
loss matrix
115+
numItermax : int
116+
Maximum number of iterations made by the LP solver.
113117
114118
Returns
115119
-------
@@ -155,12 +159,12 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
155159
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
156160

157161
if len(b.shape)==1:
158-
return emd2_c(a, b, M)
162+
return emd2_c(a, b, M, numItermax)
159163
else:
160164
nb=b.shape[1]
161-
#res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
165+
#res=[emd2_c(a,b[:,i].copy(),M, numItermax) for i in range(nb)]
162166
def f(b):
163-
return emd2_c(a,b,M)
167+
return emd2_c(a,b,M, numItermax)
164168
res= parmap(f, [b[:,i] for i in range(nb)],processes)
165169
return np.array(res)
166170

ot/lp/emd_wrap.pyx

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ cimport cython
1515

1616

1717
cdef extern from "EMD.h":
18-
void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost)
18+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax)
19+
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
1920

2021

2122

2223
@cython.boundscheck(False)
2324
@cython.wraparound(False)
24-
def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
25+
def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
2526
"""
2627
Solves the Earth Movers distance problem and returns the optimal transport matrix
2728
@@ -48,6 +49,8 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
4849
target histogram
4950
M : (ns,nt) ndarray, float64
5051
loss matrix
52+
numItermax : int
53+
Maximum number of iterations made by the LP solver.
5154
5255
5356
Returns
@@ -69,13 +72,18 @@ def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mod
6972
b=np.ones((n2,))/n2
7073

7174
# calling the function
72-
EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
75+
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, numItermax)
76+
if resultSolver != OPTIMAL:
77+
if resultSolver == INFEASIBLE:
78+
print("Problem infeasible. Try to inscrease numItermax.")
79+
elif resultSolver == UNBOUNDED:
80+
print("Problem unbounded")
7381

7482
return G
7583

7684
@cython.boundscheck(False)
7785
@cython.wraparound(False)
78-
def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
86+
def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int numItermax):
7987
"""
8088
Solves the Earth Movers distance problem and returns the optimal transport loss
8189
@@ -102,6 +110,8 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
102110
target histogram
103111
M : (ns,nt) ndarray, float64
104112
loss matrix
113+
numItermax : int
114+
Maximum number of iterations made by the LP solver.
105115
106116
107117
Returns
@@ -123,7 +133,12 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
123133
b=np.ones((n2,))/n2
124134

125135
# calling the function
126-
EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost)
136+
cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, numItermax)
137+
if resultSolver != OPTIMAL:
138+
if resultSolver == INFEASIBLE:
139+
print("Problem infeasible. Try to inscrease numItermax.")
140+
elif resultSolver == UNBOUNDED:
141+
print("Problem unbounded")
127142

128143
cost=0
129144
for i in range(n1):

0 commit comments

Comments
 (0)