You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
32
+
- :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
33
+
- a and b are source and target weights (sum to 1)
34
+
35
+
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
36
+
37
+
38
+
Parameters
39
+
----------
40
+
a : np.ndarray (ns,)
41
+
samples weights in the source domain
42
+
labels_a : np.ndarray (ns,)
43
+
labels of samples in the source domain
44
+
b : np.ndarray (nt,)
45
+
samples in the target domain
46
+
M : np.ndarray (ns,nt)
47
+
loss matrix
48
+
reg: float
49
+
Regularization term for entropic regularization >0
50
+
eta: float, optional
51
+
Regularization term for group lasso regularization >0
52
+
numItermax: int, optional
53
+
Max number of iterations
54
+
numInnerItermax: int, optional
55
+
Max number of iterations (inner sinkhorn solver)
56
+
stopInnerThr: float, optional
57
+
Stop threshold on error (inner sinkhorn solver) (>0)
58
+
verbose : bool, optional
59
+
Print information along iterations
60
+
log : bool, optional
61
+
record log if True
62
+
63
+
64
+
Returns
65
+
-------
66
+
gamma: (ns x nt) ndarray
67
+
Optimal transportation matrix for the given parameters
68
+
log: dict
69
+
log dictionary return only if log==True in parameters
70
+
71
+
Examples
72
+
--------
73
+
74
+
>>> a=[.5,.5]
75
+
>>> b=[.5,.5]
76
+
>>> M=[[0.,1.],[1.,0.]]
77
+
>>> ot.sinkhorn(a,b,M,1)
78
+
array([[ 0.36552929, 0.13447071],
79
+
[ 0.13447071, 0.36552929]])
80
+
81
+
82
+
References
83
+
----------
84
+
85
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
86
+
87
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
88
+
89
+
See Also
90
+
--------
91
+
ot.lp.emd : Unregularized OT
92
+
ot.bregman.sinkhorn : Entropic regularized OT
93
+
ot.optim.cg : General regularized OT
94
+
95
+
"""
13
96
p=0.5
14
97
epsilon=1e-3
15
98
@@ -25,9 +108,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1):
Solvers for the original linear program OT problem
3
4
"""
4
5
6
+
importnumpyasnp
5
7
# import compiled emd
6
8
from .emdimportemd_c
7
-
importnumpyasnp
8
9
9
-
defemd(a,b,M):
10
-
"""
11
-
Solves the Earth Movers distance problem and returns the optimal transport matrix
12
-
13
-
10
+
11
+
defemd(a, b, M):
12
+
"""Solves the Earth Movers distance problem and returns the OT matrix
13
+
14
+
14
15
.. math::
15
-
\gamma = arg\min_\gamma <\gamma,M>_F
16
-
16
+
\gamma = arg\min_\gamma <\gamma,M>_F
17
+
17
18
s.t. \gamma 1 = a
18
-
19
-
\gamma^T 1= b
20
-
19
+
\gamma^T 1= b
21
20
\gamma\geq 0
22
21
where :
23
-
22
+
24
23
- M is the metric cost matrix
25
24
- a and b are the sample weights
26
-
25
+
27
26
Uses the algorithm proposed in [1]_
28
-
27
+
29
28
Parameters
30
29
----------
31
30
a : (ns,) ndarray, float64
32
31
Source histogram (uniform weigth if empty list)
33
32
b : (nt,) ndarray, float64
34
33
Target histogram (uniform weigth if empty list)
35
34
M : (ns,nt) ndarray, float64
36
-
loss matrix
37
-
35
+
loss matrix
36
+
38
37
Returns
39
38
-------
40
39
gamma: (ns x nt) ndarray
41
40
Optimal transportation matrix for the given parameters
42
-
43
-
41
+
42
+
44
43
Examples
45
44
--------
46
-
45
+
47
46
Simple example with obvious solution. The function emd accepts lists and
48
-
perform automatic conversion to numpy arrays
49
-
47
+
perform automatic conversion to numpy arrays
48
+
50
49
>>> a=[.5,.5]
51
50
>>> b=[.5,.5]
52
51
>>> M=[[0.,1.],[1.,0.]]
53
52
>>> ot.emd(a,b,M)
54
53
array([[ 0.5, 0. ],
55
54
[ 0. , 0.5]])
56
-
55
+
57
56
References
58
57
----------
59
-
60
-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
61
-
58
+
59
+
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
60
+
(2011, December). Displacement interpolation using Lagrangian mass
61
+
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
0 commit comments