Skip to content

Commit cd3397f

Browse files
committed
integrate AG comments
1 parent 84adadd commit cd3397f

File tree

1 file changed

+120
-82
lines changed

1 file changed

+120
-82
lines changed

ot/da.py

Lines changed: 120 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -940,21 +940,23 @@ def predict(self, x):
940940

941941
class BaseTransport(BaseEstimator):
942942

943-
def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
944-
"""fit: estimates the optimal coupling
945-
946-
Parameters:
947-
-----------
948-
- Xs: source samples, (ns samples, d features) numpy-like array
949-
- ys: source labels
950-
- Xt: target samples (nt samples, d features) numpy-like array
951-
- yt: target labels
952-
- method: algorithm to use to compute optimal coupling
953-
(default: sinkhorn)
954-
955-
Returns:
956-
--------
957-
- self
943+
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
944+
"""Build a coupling matrix from source and target sets of samples
945+
(Xs, ys) and (Xt, yt)
946+
Parameters
947+
----------
948+
Xs : array-like of shape = [n_source_samples, n_features]
949+
The training input samples.
950+
ys : array-like, shape = [n_source_samples]
951+
The class labels
952+
Xt : array-like of shape = [n_target_samples, n_features]
953+
The training input samples.
954+
yt : array-like, shape = [n_labeled_target_samples]
955+
The class labels
956+
Returns
957+
-------
958+
self : object
959+
Returns self.
958960
"""
959961

960962
# pairwise distance
@@ -972,7 +974,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
972974
print("TODO: implement kernelized approach")
973975

974976
# coupling estimation
975-
if method == "sinkhorn":
977+
if self.method == "sinkhorn":
976978
self.gamma_ = sinkhorn(
977979
a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
978980
numItermax=self.max_iter, stopThr=self.tol,
@@ -983,36 +985,43 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
983985
return self
984986

985987
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
986-
"""fit_transform
987-
988-
Parameters:
989-
-----------
990-
- Xs: source samples, (ns samples, d features) numpy-like array
991-
- ys: source labels
992-
- Xt: target samples (nt samples, d features) numpy-like array
993-
- yt: target labels
994-
995-
Returns:
996-
--------
997-
- transp_Xt
988+
"""Build a coupling matrix from source and target sets of samples
989+
(Xs, ys) and (Xt, yt) and transports source samples Xs onto target
990+
ones Xt
991+
Parameters
992+
----------
993+
Xs : array-like of shape = [n_source_samples, n_features]
994+
The training input samples.
995+
ys : array-like, shape = [n_source_samples]
996+
The class labels
997+
Xt : array-like of shape = [n_target_samples, n_features]
998+
The training input samples.
999+
yt : array-like, shape = [n_labeled_target_samples]
1000+
The class labels
1001+
Returns
1002+
-------
1003+
transp_Xs : array-like of shape = [n_source_samples, n_features]
1004+
The transported source samples.
9981005
"""
9991006

1000-
return self.fit(Xs, ys, Xt, yt, self.method).transform(Xs, ys, Xt, yt)
1007+
return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
10011008

10021009
def transform(self, Xs=None, ys=None, Xt=None, yt=None):
1003-
"""transform: as a convention transports source samples
1004-
onto target samples
1005-
1006-
Parameters:
1007-
-----------
1008-
- Xs: source samples, (ns samples, d features) numpy-like array
1009-
- ys: source labels
1010-
- Xt: target samples (nt samples, d features) numpy-like array
1011-
- yt: target labels
1012-
1013-
Returns:
1014-
--------
1015-
- transp_Xt
1010+
"""Transports source samples Xs onto target ones Xt
1011+
Parameters
1012+
----------
1013+
Xs : array-like of shape = [n_source_samples, n_features]
1014+
The training input samples.
1015+
ys : array-like, shape = [n_source_samples]
1016+
The class labels
1017+
Xt : array-like of shape = [n_target_samples, n_features]
1018+
The training input samples.
1019+
yt : array-like, shape = [n_labeled_target_samples]
1020+
The class labels
1021+
Returns
1022+
-------
1023+
transp_Xs : array-like of shape = [n_source_samples, n_features]
1024+
The transported source samples.
10161025
"""
10171026

10181027
if self.mapping == "barycentric":
@@ -1027,19 +1036,21 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
10271036
return transp_Xs
10281037

10291038
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
1030-
"""inverse_transform: as a convention transports target samples
1031-
onto source samples
1032-
1033-
Parameters:
1034-
-----------
1035-
- Xs: source samples, (ns samples, d features) numpy-like array
1036-
- ys: source labels
1037-
- Xt: target samples (nt samples, d features) numpy-like array
1038-
- yt: target labels
1039-
1040-
Returns:
1041-
--------
1042-
- transp_Xt
1039+
"""Transports target samples Xt onto source samples Xs
1040+
Parameters
1041+
----------
1042+
Xs : array-like of shape = [n_source_samples, n_features]
1043+
The training input samples.
1044+
ys : array-like, shape = [n_source_samples]
1045+
The class labels
1046+
Xt : array-like of shape = [n_target_samples, n_features]
1047+
The training input samples.
1048+
yt : array-like, shape = [n_labeled_target_samples]
1049+
The class labels
1050+
Returns
1051+
-------
1052+
transp_Xt : array-like of shape = [n_target_samples, n_features]
1053+
The transported target samples.
10431054
"""
10441055

10451056
if self.mapping == "barycentric":
@@ -1057,22 +1068,48 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10571068

10581069

10591070
class SinkhornTransport(BaseTransport):
1060-
"""SinkhornTransport: class wrapper for optimal transport based on
1061-
Sinkhorn's algorithm
1071+
"""Domain Adaptation OT method based on Sinkhorn Algorithm
10621072
10631073
Parameters
10641074
----------
1065-
- reg_e : parameter for entropic regularization
1066-
- mode: unsupervised (default) or semi supervised: controls whether
1067-
labels are taken into accout to construct the optimal coupling
1068-
- max_iter : maximum number of iterations
1069-
- tol : precision
1070-
- verbose : control verbosity
1071-
- log : control log
1072-
1075+
reg_e : float, optional (default=1)
1076+
Entropic regularization parameter
1077+
mode : string, optional (default="unsupervised")
1078+
The DA mode. If "unsupervised" no target labels are taken into account
1079+
to modify the cost matrix. If "semisupervised" the target labels
1080+
are taken into account to set coefficients of the pairwise distance
1081+
matrix to 0 for row and columns indices that correspond to source and
1082+
target samples which share the same labels.
1083+
max_iter : int, float, optional (default=1000)
1084+
The maximum number of iterations before stopping the optimization
1085+
algorithm if it has not converged
1086+
tol : float, optional (default=10e-9)
1087+
The precision required to stop the optimization algorithm.
1088+
mapping : string, optional (default="barycentric")
1089+
The kind of mapping to apply to transport samples from a domain into
1090+
another one.
1091+
if "barycentric" only the samples used to estimate the coupling can
1092+
be transported from a domain to another one.
1093+
metric : string, optional (default="sqeuclidean")
1094+
The ground metric for the Wasserstein problem
1095+
distribution : string, optional (default="uniform")
1096+
The kind of distribution estimation to employ
1097+
verbose : int, optional (default=0)
1098+
Controls the verbosity of the optimization algorithm
1099+
log : int, optional (default=0)
1100+
Controls the logs of the optimization algorithm
10731101
Attributes
10741102
----------
1075-
- gamma_: optimal coupling estimated by the fit function
1103+
gamma_ : the optimal coupling
1104+
1105+
References
1106+
----------
1107+
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1108+
"Optimal Transport for Domain Adaptation," in IEEE Transactions
1109+
on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
1110+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
1111+
Transport, Advances in Neural Information Processing Systems (NIPS)
1112+
26, 2013
10761113
"""
10771114

10781115
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
@@ -1090,24 +1127,25 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
10901127
self.method = "sinkhorn"
10911128

10921129
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1093-
"""fit
1094-
1095-
Parameters:
1096-
-----------
1097-
- Xs: source samples, (ns samples, d features) numpy-like array
1098-
- ys: source labels
1099-
- Xt: target samples (nt samples, d features) numpy-like array
1100-
- yt: target labels
1101-
- method: algorithm to use to compute optimal coupling
1102-
(default: sinkhorn)
1103-
1104-
Returns:
1105-
--------
1106-
- self
1130+
"""Build a coupling matrix from source and target sets of samples
1131+
(Xs, ys) and (Xt, yt)
1132+
Parameters
1133+
----------
1134+
Xs : array-like of shape = [n_source_samples, n_features]
1135+
The training input samples.
1136+
ys : array-like, shape = [n_source_samples]
1137+
The class labels
1138+
Xt : array-like of shape = [n_target_samples, n_features]
1139+
The training input samples.
1140+
yt : array-like, shape = [n_labeled_target_samples]
1141+
The class labels
1142+
Returns
1143+
-------
1144+
self : object
1145+
Returns self.
11071146
"""
11081147

1109-
return super(SinkhornTransport, self).fit(
1110-
Xs, ys, Xt, yt, method=self.method)
1148+
return super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
11111149

11121150

11131151
if __name__ == "__main__":

0 commit comments

Comments
 (0)