Skip to content

Commit 84adadd

Browse files
committed
Small modifications according to NG's proposals
1 parent ca9c9d6 commit 84adadd

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

ot/da.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,8 @@ def predict(self, x):
926926

927927
"""
928928
- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929-
- ref: is the entropic reg parameter
930-
- eta: is the second reg parameter
929+
- reg_e: is the entropic reg parameter
930+
- reg_cl: is the second reg parameter
931931
- gamma_: is the optimal coupling
932932
- mapping barycentric for the moment
933933
@@ -940,7 +940,7 @@ def predict(self, x):
940940

941941
class BaseTransport(BaseEstimator):
942942

943-
def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
943+
def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
944944
"""fit: estimates the optimal coupling
945945
946946
Parameters:
@@ -964,13 +964,17 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
964964
print("TODO: modify cost matrix accordingly")
965965
pass
966966

967-
# distribution estimation: should we change it ?
968-
mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
969-
mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
967+
# distribution estimation
968+
if self.distribution == "uniform":
969+
mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
970+
mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
971+
else:
972+
print("TODO: implement kernelized approach")
970973

974+
# coupling estimation
971975
if method == "sinkhorn":
972976
self.gamma_ = sinkhorn(
973-
a=mu_s, b=mu_t, M=Cost, reg=self.reg,
977+
a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
974978
numItermax=self.max_iter, stopThr=self.tol,
975979
verbose=self.verbose, log=self.log)
976980
else:
@@ -1058,7 +1062,7 @@ class SinkhornTransport(BaseTransport):
10581062
10591063
Parameters
10601064
----------
1061-
- reg : parameter for entropic regularization
1065+
- reg_e : parameter for entropic regularization
10621066
- mode: unsupervised (default) or semi supervised: controls whether
10631067
labels are taken into account to construct the optimal coupling
10641068
- max_iter : maximum number of iterations
@@ -1071,22 +1075,37 @@ class SinkhornTransport(BaseTransport):
10711075
- gamma_: optimal coupling estimated by the fit function
10721076
"""
10731077

1074-
def __init__(self, reg=1., mode="unsupervised", max_iter=1000,
1078+
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
10751079
tol=10e-9, verbose=False, log=False, mapping="barycentric",
1076-
metric="sqeuclidean"):
1077-
self.reg = reg
1080+
metric="sqeuclidean", distribution="uniform"):
1081+
self.reg_e = reg_e
10781082
self.mode = mode
10791083
self.max_iter = max_iter
10801084
self.tol = tol
10811085
self.verbose = verbose
10821086
self.log = log
10831087
self.mapping = mapping
10841088
self.metric = metric
1089+
self.distribution = distribution
10851090
self.method = "sinkhorn"
10861091

10871092
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1088-
"""_fit
1093+
"""fit
1094+
1095+
Parameters:
1096+
-----------
1097+
- Xs: source samples, (ns samples, d features) numpy-like array
1098+
- ys: source labels
1099+
- Xt: target samples (nt samples, d features) numpy-like array
1100+
- yt: target labels
1101+
- method: algorithm used to compute the optimal coupling; not an argument
1102+
of this function, it is read from self.method (default: sinkhorn)
1103+
1104+
Returns:
1105+
--------
1106+
- self
10891107
"""
1108+
10901109
return super(SinkhornTransport, self).fit(
10911110
Xs, ys, Xt, yt, method=self.method)
10921111

0 commit comments

Comments
 (0)