Skip to content

Commit ca9c9d6

Browse files
committed
first proposal for OT wrappers
1 parent 553a456 commit ca9c9d6

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed

ot/da.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,3 +916,182 @@ def predict(self, x):
916916
else:
917917
print("Warning, model not fitted yet, returning None")
918918
return None
919+
920+
##############################################################################
921+
# proposal
922+
##############################################################################
923+
924+
from sklearn.base import BaseEstimator
925+
from sklearn.metrics import pairwise_distances
926+
927+
"""
928+
- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929+
- reg: is the entropic regularization parameter
930+
- eta: is the second reg parameter
931+
- gamma_: is the optimal coupling
932+
- mapping barycentric for the moment
933+
934+
Questions:
935+
- Cost matrix estimation: from sklearn or from internal function ?
936+
- distribution estimation ? Look at Nathalie's approach
937+
- should everything be done in the fit method of BaseTransport ?
938+
"""
939+
940+
941+
class BaseTransport(BaseEstimator):
    """Base class of the optimal transport wrappers.

    The methods below read the following attributes, which subclasses are
    expected to set in their ``__init__``: ``metric``, ``mode``, ``mapping``,
    ``reg``, ``max_iter``, ``tol``, ``verbose`` and ``log``.

    Attributes
    ----------
    - gamma_: optimal coupling estimated by the fit function
    - log_: solver log, only set by fit when ``self.log`` is true
    """

    def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
        """fit: estimates the optimal coupling

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels (currently unused)
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels (currently unused)
        - method: algorithm to use to compute optimal coupling
                  (default: sinkhorn)

        Returns:
        --------
        - self

        Raises:
        -------
        - ValueError: if Xs or Xt is missing
        - NotImplementedError: if method is not "sinkhorn"
        """

        if Xs is None or Xt is None:
            raise ValueError("Xs and Xt are both required to fit the coupling")

        # pairwise distance
        Cost = pairwise_distances(Xs, Xt, metric=self.metric)

        if self.mode == "semisupervised":
            print("TODO: modify cost matrix accordingly")

        # distribution estimation: should we change it ?
        # Sizes are read from the cost matrix rather than from Xs/Xt so that
        # plain nested lists (which have no .shape) are accepted too.
        ns, nt = Cost.shape
        mu_s = np.ones(ns) / float(ns)
        mu_t = np.ones(nt) / float(nt)

        if method != "sinkhorn":
            # Fail loudly instead of returning an estimator without gamma_.
            raise NotImplementedError(
                "method '%s' is not implemented yet" % (method,))

        returned_ = sinkhorn(
            a=mu_s, b=mu_t, M=Cost, reg=self.reg,
            numItermax=self.max_iter, stopThr=self.tol,
            verbose=self.verbose, log=self.log)

        # NOTE(review): assumes sinkhorn returns (coupling, log) when log is
        # true and the bare coupling otherwise, as documented for POT's
        # sinkhorn -- confirm against the solver imported in this module.
        if self.log:
            self.gamma_, self.log_ = returned_
        else:
            self.gamma_ = returned_

        return self

    def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """fit_transform: fits the coupling, then transports the source
        samples onto the target samples

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xs: transported source samples
        """

        # The solver is not forwarded here: subclasses that override fit
        # without a `method` parameter would otherwise raise a TypeError.
        return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)

    def _barycentric_map(self, coupling, X):
        """Maps each row of coupling to the weighted mean of the rows of X."""

        # Rows without mass produce 0/0; the warnings are silenced because
        # the resulting non-finite weights are explicitly zeroed just below.
        with np.errstate(divide="ignore", invalid="ignore"):
            weights = coupling / np.sum(coupling, 1)[:, None]

        # set nans to 0
        weights[~ np.isfinite(weights)] = 0

        # compute transported samples
        return np.dot(weights, X)

    def transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """transform: as a convention transports source samples
        onto target samples

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xs: transported source samples

        Raises:
        -------
        - NotImplementedError: if self.mapping is not "barycentric"
        """

        if self.mapping != "barycentric":
            raise NotImplementedError("mapping not yet implemented")

        return self._barycentric_map(self.gamma_, Xt)

    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """inverse_transform: as a convention transports target samples
        onto source samples

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xt: transported target samples

        Raises:
        -------
        - NotImplementedError: if self.mapping is not "barycentric"
        """

        if self.mapping != "barycentric":
            raise NotImplementedError("mapping not yet implemented")

        # Transposing the coupling swaps the roles of source and target.
        return self._barycentric_map(self.gamma_.T, Xs)
1053+
1054+
1055+
class SinkhornTransport(BaseTransport):
    """SinkhornTransport: class wrapper for optimal transport based on
    Sinkhorn's algorithm

    Parameters
    ----------
    - reg : parameter for entropic regularization
    - mode: unsupervised (default) or semi supervised: controls whether
            labels are taken into account to construct the optimal coupling
    - max_iter : maximum number of iterations
    - tol : precision (the default 10e-9 is equal to 1e-8)
    - verbose : control verbosity
    - log : control log
    - mapping : how samples are transported (default: barycentric)
    - metric : metric used to build the cost matrix (default: sqeuclidean)

    Attributes
    ----------
    - gamma_: optimal coupling estimated by the fit function
    """

    def __init__(self, reg=1., mode="unsupervised", max_iter=1000,
                 tol=10e-9, verbose=False, log=False, mapping="barycentric",
                 metric="sqeuclidean"):
        self.reg = reg
        self.mode = mode
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.log = log
        self.mapping = mapping
        self.metric = metric
        # Not a constructor argument: this wrapper always uses Sinkhorn.
        self.method = "sinkhorn"

    def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
        """fit: estimates the optimal coupling with the parent class fit

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels
        - method: optional, defaults to self.method. Accepted so that this
                  override stays call-compatible with BaseTransport.fit,
                  which is also invoked with a fifth argument.

        Returns:
        --------
        - self
        """
        if method is None:
            method = self.method

        return super(SinkhornTransport, self).fit(
            Xs, ys, Xt, yt, method=method)
1092+
1093+
1094+
if __name__ == "__main__":
    # Smoke test run only when the module is executed as a script: it just
    # checks that the wrapper can be built with its default parameters.
    # No coupling is fitted here.
    print("Small test")

    st = SinkhornTransport()

0 commit comments

Comments
 (0)