
Commit 4efdda7

add documentation
1 parent 57330c5 commit 4efdda7

File tree: 2 files changed, +89 −3 lines


ot/da.py

Lines changed: 10 additions & 3 deletions
@@ -670,10 +670,16 @@ def predict(self,x,direction=1):
         return xf[idx,:]+x-x0[idx,:]  # apply the delta to the interpolation
 
     def normalizeM(self, norm):
+        """ Apply normalization to the loss matrix
+
+
+        Parameters
+        ----------
+        norm : str
+            type of normalization from 'median', 'max', 'log', 'loglog'
+
        """
-        It may help to normalize the cost matrix self.M if there are numerical
-        errors during the sinkhorn based algorithms.
-        """
+
        if norm == "median":
            self.M /= float(np.median(self.M))
        elif norm == "max":
@@ -682,6 +688,7 @@ def normalizeM(self, norm):
            self.M = np.log(1 + self.M)
        elif norm == "loglog":
            self.M = np.log(1 + np.log(1 + self.M))
+

 
class OTDA_sinkhorn(OTDA):
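
For reference, the normalization options documented above act on the cost matrix as in the minimal NumPy sketch below. This is a standalone illustration, not the class method itself; the function name normalize_cost is made up, and the body of the 'max' branch is an assumption since it is not shown in this hunk.

import numpy as np

def normalize_cost(M, norm):
    # Return a normalized copy of a cost matrix, mirroring the
    # 'median', 'max', 'log' and 'loglog' options described above.
    M = np.asarray(M, dtype=float).copy()
    if norm == "median":
        M /= float(np.median(M))         # rescale by the median cost
    elif norm == "max":
        M /= float(np.max(M))            # assumed: rescale to [0, 1]
    elif norm == "log":
        M = np.log(1 + M)                # damp large costs
    elif norm == "loglog":
        M = np.log(1 + np.log(1 + M))    # damp them twice
    return M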

ot/gpu/da.py

Lines changed: 79 additions & 0 deletions
@@ -70,6 +70,76 @@ def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False):
def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
                     numInnerItermax=200, stopInnerThr=1e-9,
                     verbose=False, log=False):
+    """
+    Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = arg\min_\gamma \quad <\gamma, M>_F + reg\cdot\Omega_e(\gamma) + \eta\,\Omega_g(\gamma)
+
+        s.t. \quad \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+    where:
+
+    - M is the (ns, nt) metric cost matrix
+    - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`, where :math:`\mathcal{I}_c` are the indices of the samples from class c in the source domain
+    - a and b are the source and target weights (each summing to 1)
+
+    The algorithm used for solving the problem is the generalised conditional gradient proposed in [5]_ [7]_.
+
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    labels_a : np.ndarray (ns,)
+        labels of the samples in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    M_GPU : cudamat.CUDAMatrix (ns,nt)
+        loss matrix
+    reg : float
+        Regularization term for entropic regularization (>0)
+    eta : float, optional
+        Regularization term for group lasso regularization (>0)
+    numItermax : int, optional
+        Max number of iterations
+    numInnerItermax : int, optional
+        Max number of iterations (inner sinkhorn solver)
+    stopInnerThr : float, optional
+        Stop threshold on error (inner sinkhorn solver) (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [5] N. Courty, R. Flamary, D. Tuia and A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1.
+    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.bregman.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+
+    """
    p = 0.5
    epsilon = 1e-3
    Nfin = len(b)
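
To make the group lasso term above concrete, the sketch below evaluates Omega_g(gamma) for a coupling and a vector of source labels with plain NumPy, reading gamma_{i, I_c} as the mass a target sample i receives from the source samples of class c. The function name group_lasso_penalty is illustrative and not part of the module.

import numpy as np

def group_lasso_penalty(gamma, labels_a):
    # Omega_g(gamma) = sum over target samples i and classes c of
    # ||gamma_{i, I_c}||_1^{1/2}, where I_c indexes the source samples
    # (rows of gamma) belonging to class c.
    penalty = 0.0
    for c in np.unique(labels_a):
        rows_c = (labels_a == c)                        # I_c
        mass_per_target = gamma[rows_c, :].sum(axis=0)  # l1 norm per target sample
        penalty += np.sqrt(mass_per_target).sum()       # sum of the square roots
    return penalty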
@@ -111,6 +181,15 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
 
class OTDA_GPU(OTDA):
    def normalizeM(self, norm):
+        """ Apply normalization to the loss matrix
+
+
+        Parameters
+        ----------
+        norm : str
+            type of normalization from 'median', 'max', 'log', 'loglog'
+
+        """
        if norm == "median":
            self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
        elif norm == "max":
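
A rough end-to-end call of the GPU solver might look like the sketch below. It relies only on the signatures shown in this diff (pairwiseEuclideanGPU and sinkhorn_lpl1_mm) plus cudamat; the toy data, the regularization values, the cm.cublas_init() call, and the assumption that pairwiseEuclideanGPU accepts NumPy arrays are all illustrative, not taken from the library.

import numpy as np
import cudamat as cm
from ot.gpu.da import pairwiseEuclideanGPU, sinkhorn_lpl1_mm

cm.cublas_init()                          # cudamat context (assumed to be needed here)

xs = np.random.randn(50, 2)               # ns source samples
labels_a = np.random.randint(0, 3, 50)    # source class labels
xt = np.random.randn(60, 2)               # nt target samples
a = np.ones(50) / 50                      # uniform source weights
b = np.ones(60) / 60                      # uniform target weights

# squared Euclidean cost kept on the GPU, as expected by sinkhorn_lpl1_mm
M_GPU = pairwiseEuclideanGPU(xs, xt, returnAsGPU=True, squared=True)
gamma = sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg=1e-1, eta=1e-1)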
