Skip to content

Commit 553a456

Browse files
committed
remove linewidth error message
1 parent 7638d01 commit 553a456

File tree

1 file changed

+108
-42
lines changed

1 file changed

+108
-42
lines changed

ot/da.py

Lines changed: 108 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
from .optim import gcg
1818

1919

20-
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
20+
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
21+
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
22+
log=False):
2123
"""
22-
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
24+
Solve the entropic regularization optimal transport problem with nonconvex
25+
group lasso regularization
2326
2427
The function solves the following optimization problem:
2528
2629
.. math::
27-
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
30+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
31+
+ \eta \Omega_g(\gamma)
2832
2933
s.t. \gamma 1 = a
3034
@@ -34,11 +38,16 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
3438
where :
3539
3640
- M is the (ns,nt) metric cost matrix
37-
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
38-
- :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
41+
- :math:`\Omega_e` is the entropic regularization term
42+
:math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
43+
- :math:`\Omega_g` is the group lasso regularization term
44+
:math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
45+
where :math:`\mathcal{I}_c` are the index of samples from class c
46+
in the source domain.
3947
- a and b are source and target weights (sum to 1)
4048
41-
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
49+
The algorithm used for solving the problem is the generalised conditional
50+
gradient as proposed in [5]_ [7]_
4251
4352
4453
Parameters
@@ -78,8 +87,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
7887
References
7988
----------
8089
81-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
82-
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
90+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
91+
"Optimal Transport for Domain Adaptation," in IEEE
92+
Transactions on Pattern Analysis and Machine Intelligence ,
93+
vol.PP, no.99, pp.1-1
94+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
95+
Generalized conditional gradient: analysis of convergence
96+
and applications. arXiv preprint arXiv:1510.06567.
8397
8498
See Also
8599
--------
@@ -114,14 +128,18 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
114128
return transp
115129

116130

117-
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
131+
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
132+
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
133+
log=False):
118134
"""
119-
Solve the entropic regularization optimal transport problem with group lasso regularization
135+
Solve the entropic regularization optimal transport problem with group
136+
lasso regularization
120137
121138
The function solves the following optimization problem:
122139
123140
.. math::
124-
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
141+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+
142+
\eta \Omega_g(\gamma)
125143
126144
s.t. \gamma 1 = a
127145
@@ -131,11 +149,16 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
131149
where :
132150
133151
- M is the (ns,nt) metric cost matrix
134-
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
135-
- :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
152+
- :math:`\Omega_e` is the entropic regularization term
153+
:math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
154+
- :math:`\Omega_g` is the group lasso regularization term
155+
:math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
156+
where :math:`\mathcal{I}_c` are the index of samples from class
157+
c in the source domain.
136158
- a and b are source and target weights (sum to 1)
137159
138-
The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
160+
The algorithm used for solving the problem is the generalised conditional
161+
gradient as proposed in [5]_ [7]_
139162
140163
141164
Parameters
@@ -175,8 +198,12 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
175198
References
176199
----------
177200
178-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
179-
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
201+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
202+
"Optimal Transport for Domain Adaptation," in IEEE Transactions
203+
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
204+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
205+
Generalized conditional gradient: analysis of convergence and
206+
applications. arXiv preprint arXiv:1510.06567.
180207
181208
See Also
182209
--------
@@ -203,16 +230,22 @@ def df(G):
203230
W[labels_a == lab, i] = temp / n
204231
return W
205232

206-
return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, numInnerItermax=numInnerItermax, stopThr=stopInnerThr, verbose=verbose, log=log)
233+
return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
234+
numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
235+
verbose=verbose, log=log)
207236

208237

209-
def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
238+
def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
239+
verbose2=False, numItermax=100, numInnerItermax=10,
240+
stopInnerThr=1e-6, stopThr=1e-5, log=False,
241+
**kwargs):
210242
"""Joint OT and linear mapping estimation as proposed in [8]
211243
212244
The function solves the following optimization problem:
213245
214246
.. math::
215-
\min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L -I\|^2_F
247+
\min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
248+
\mu<\gamma,M>_F + \eta \|L -I\|^2_F
216249
217250
s.t. \gamma 1 = a
218251
@@ -221,8 +254,10 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
221254
\gamma\geq 0
222255
where :
223256
224-
- M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
225-
- :math:`L` is a dxd linear operator that approximates the barycentric mapping
257+
- M is the (ns,nt) squared euclidean cost matrix between samples in
258+
Xs and Xt (scaled by ns)
259+
- :math:`L` is a dxd linear operator that approximates the barycentric
260+
mapping
226261
- :math:`I` is the identity matrix (neutral linear mapping)
227262
- a and b are uniform source and target weights
228263
@@ -277,7 +312,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
277312
References
278313
----------
279314
280-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
315+
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
316+
"Mapping estimation for discrete optimal transport",
317+
Neural Information Processing Systems (NIPS), 2016.
281318
282319
See Also
283320
--------
@@ -384,13 +421,18 @@ def df(G):
384421
return G, L
385422

386423

387-
def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigma=1, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
424+
def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
425+
sigma=1, bias=False, verbose=False, verbose2=False,
426+
numItermax=100, numInnerItermax=10,
427+
stopInnerThr=1e-6, stopThr=1e-5, log=False,
428+
**kwargs):
388429
"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
389430
390431
The function solves the following optimization problem:
391432
392433
.. math::
393-
\min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
434+
\min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
435+
n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
394436
395437
s.t. \gamma 1 = a
396438
@@ -399,8 +441,10 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
399441
\gamma\geq 0
400442
where :
401443
402-
- M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
403-
- :math:`L` is a ns x d linear operator on a kernel matrix that approximates the barycentric mapping
444+
- M is the (ns,nt) squared euclidean cost matrix between samples in
445+
Xs and Xt (scaled by ns)
446+
- :math:`L` is a ns x d linear operator on a kernel matrix that
447+
approximates the barycentric mapping
404448
- a and b are uniform source and target weights
405449
406450
The problem consists in solving jointly an optimal transport matrix
@@ -458,7 +502,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
458502
References
459503
----------
460504
461-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
505+
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
506+
"Mapping estimation for discrete optimal transport",
507+
Neural Information Processing Systems (NIPS), 2016.
462508
463509
See Also
464510
--------
@@ -593,7 +639,9 @@ class OTDA(object):
593639
References
594640
----------
595641
596-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
642+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
643+
"Optimal Transport for Domain Adaptation," in IEEE Transactions on
644+
Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
597645
598646
"""
599647

@@ -606,7 +654,8 @@ def __init__(self, metric='sqeuclidean'):
606654
self.computed = False
607655

608656
def fit(self, xs, xt, ws=None, wt=None, norm=None):
609-
""" Fit domain adaptation between samples is xs and xt (with optional weights)"""
657+
"""Fit domain adaptation between samples is xs and xt
658+
(with optional weights)"""
610659
self.xs = xs
611660
self.xt = xt
612661

@@ -669,7 +718,9 @@ def predict(self, x, direction=1):
669718
References
670719
----------
671720
672-
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
721+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
722+
Regularized discrete optimal transport. SIAM Journal on Imaging
723+
Sciences, 7(3), 1853-1882.
673724
674725
"""
675726
if direction > 0: # >0 then source to target
@@ -708,10 +759,12 @@ def normalizeM(self, norm):
708759

709760
class OTDA_sinkhorn(OTDA):
710761

711-
"""Class for domain adaptation with optimal transport with entropic regularization"""
762+
"""Class for domain adaptation with optimal transport with entropic
763+
regularization"""
712764

713765
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
714-
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
766+
"""Fit regularized domain adaptation between samples is xs and xt
767+
(with optional weights)"""
715768
self.xs = xs
716769
self.xt = xt
717770

@@ -731,10 +784,14 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
731784

732785
class OTDA_lpl1(OTDA):
733786

734-
"""Class for domain adaptation with optimal transport with entropic and group regularization"""
787+
"""Class for domain adaptation with optimal transport with entropic and
788+
group regularization"""
735789

736-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
737-
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
790+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
791+
**kwargs):
792+
"""Fit regularized domain adaptation between samples is xs and xt
793+
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
794+
parameters"""
738795
self.xs = xs
739796
self.xt = xt
740797

@@ -754,10 +811,14 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
754811

755812
class OTDA_l1l2(OTDA):
756813

757-
"""Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
814+
"""Class for domain adaptation with optimal transport with entropic
815+
and group lasso regularization"""
758816

759-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
760-
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
817+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
818+
**kwargs):
819+
"""Fit regularized domain adaptation between samples is xs and xt
820+
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
821+
parameters"""
761822
self.xs = xs
762823
self.xt = xt
763824

@@ -777,7 +838,9 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
777838

778839
class OTDA_mapping_linear(OTDA):
779840

780-
"""Class for optimal transport with joint linear mapping estimation as in [8]"""
841+
"""Class for optimal transport with joint linear mapping estimation as in
842+
[8]
843+
"""
781844

782845
def __init__(self):
783846
""" Class initialization"""
@@ -820,9 +883,11 @@ def predict(self, x):
820883

821884
class OTDA_mapping_kernel(OTDA_mapping_linear):
822885

823-
"""Class for optimal transport with joint nonlinear mapping estimation as in [8]"""
886+
"""Class for optimal transport with joint nonlinear mapping
887+
estimation as in [8]"""
824888

825-
def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian', sigma=1, **kwargs):
889+
def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
890+
sigma=1, **kwargs):
826891
""" Fit domain adaptation between samples is xs and xt """
827892
self.xs = xs
828893
self.xt = xt
@@ -843,7 +908,8 @@ def predict(self, x):
843908

844909
if self.computed:
845910
K = kernel(
846-
x, self.xs, method=self.kernel, sigma=self.sigma, **self.kwargs)
911+
x, self.xs, method=self.kernel, sigma=self.sigma,
912+
**self.kwargs)
847913
if self.bias:
848914
K = np.hstack((K, np.ones((x.shape[0], 1))))
849915
return K.dot(self.L)

0 commit comments

Comments
 (0)