@@ -940,21 +940,23 @@ def predict(self, x):
940940
941941class BaseTransport (BaseEstimator ):
942942
943- def fit (self , Xs = None , ys = None , Xt = None , yt = None , method = None ):
944- """fit: estimates the optimal coupling
945-
946- Parameters:
947- -----------
948- - Xs: source samples, (ns samples, d features) numpy-like array
949- - ys: source labels
950- - Xt: target samples (nt samples, d features) numpy-like array
951- - yt: target labels
952- - method: algorithm to use to compute optimal coupling
953- (default: sinkhorn)
954-
955- Returns:
956- --------
957- - self
943+ def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
944+ """Build a coupling matrix from source and target sets of samples
945+ (Xs, ys) and (Xt, yt)
946+ Parameters
947+ ----------
948+ Xs : array-like of shape = [n_source_samples, n_features]
949+ The training input samples.
950+ ys : array-like, shape = [n_source_samples]
951+ The class labels
952+ Xt : array-like of shape = [n_target_samples, n_features]
953+ The training input samples.
954+ yt : array-like, shape = [n_labeled_target_samples]
955+ The class labels
956+ Returns
957+ -------
958+ self : object
959+ Returns self.
958960 """
959961
960962 # pairwise distance
@@ -972,7 +974,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
972974 print ("TODO: implement kernelized approach" )
973975
974976 # coupling estimation
975- if method == "sinkhorn" :
977+ if self . method == "sinkhorn" :
976978 self .gamma_ = sinkhorn (
977979 a = mu_s , b = mu_t , M = Cost , reg = self .reg_e ,
978980 numItermax = self .max_iter , stopThr = self .tol ,
@@ -983,36 +985,43 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
983985 return self
984986
985987 def fit_transform (self , Xs = None , ys = None , Xt = None , yt = None ):
986- """fit_transform
987-
988- Parameters:
989- -----------
990- - Xs: source samples, (ns samples, d features) numpy-like array
991- - ys: source labels
992- - Xt: target samples (nt samples, d features) numpy-like array
993- - yt: target labels
994-
995- Returns:
996- --------
997- - transp_Xt
988+ """Build a coupling matrix from source and target sets of samples
989+ (Xs, ys) and (Xt, yt) and transports source samples Xs onto target
990+ ones Xt
991+ Parameters
992+ ----------
993+ Xs : array-like of shape = [n_source_samples, n_features]
994+ The training input samples.
995+ ys : array-like, shape = [n_source_samples]
996+ The class labels
997+ Xt : array-like of shape = [n_target_samples, n_features]
998+ The training input samples.
999+ yt : array-like, shape = [n_labeled_target_samples]
1000+ The class labels
1001+ Returns
1002+ -------
1003+ transp_Xs : array-like of shape = [n_source_samples, n_features]
1004+ The transported source samples.
9981005 """
9991006
1000- return self .fit (Xs , ys , Xt , yt , self . method ).transform (Xs , ys , Xt , yt )
1007+ return self .fit (Xs , ys , Xt , yt ).transform (Xs , ys , Xt , yt )
10011008
10021009 def transform (self , Xs = None , ys = None , Xt = None , yt = None ):
1003- """transform: as a convention transports source samples
1004- onto target samples
1005-
1006- Parameters:
1007- -----------
1008- - Xs: source samples, (ns samples, d features) numpy-like array
1009- - ys: source labels
1010- - Xt: target samples (nt samples, d features) numpy-like array
1011- - yt: target labels
1012-
1013- Returns:
1014- --------
1015- - transp_Xt
1010+ """Transports source samples Xs onto target ones Xt
1011+ Parameters
1012+ ----------
1013+ Xs : array-like of shape = [n_source_samples, n_features]
1014+ The training input samples.
1015+ ys : array-like, shape = [n_source_samples]
1016+ The class labels
1017+ Xt : array-like of shape = [n_target_samples, n_features]
1018+ The training input samples.
1019+ yt : array-like, shape = [n_labeled_target_samples]
1020+ The class labels
1021+ Returns
1022+ -------
1023+ transp_Xs : array-like of shape = [n_source_samples, n_features]
1024+ The transported source samples.
10161025 """
10171026
10181027 if self .mapping == "barycentric" :
@@ -1027,19 +1036,21 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
10271036 return transp_Xs
10281037
10291038 def inverse_transform (self , Xs = None , ys = None , Xt = None , yt = None ):
1030- """inverse_transform: as a convention transports target samples
1031- onto source samples
1032-
1033- Parameters:
1034- -----------
1035- - Xs: source samples, (ns samples, d features) numpy-like array
1036- - ys: source labels
1037- - Xt: target samples (nt samples, d features) numpy-like array
1038- - yt: target labels
1039-
1040- Returns:
1041- --------
1042- - transp_Xt
1039+ """Transports target samples Xt onto source samples Xs
1040+ Parameters
1041+ ----------
1042+ Xs : array-like of shape = [n_source_samples, n_features]
1043+ The training input samples.
1044+ ys : array-like, shape = [n_source_samples]
1045+ The class labels
1046+ Xt : array-like of shape = [n_target_samples, n_features]
1047+ The training input samples.
1048+ yt : array-like, shape = [n_labeled_target_samples]
1049+ The class labels
1050+ Returns
1051+ -------
1052+ transp_Xt : array-like of shape = [n_target_samples, n_features]
1053+ The transported target samples.
10431054 """
10441055
10451056 if self .mapping == "barycentric" :
@@ -1057,22 +1068,48 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10571068
10581069
10591070class SinkhornTransport (BaseTransport ):
1060- """SinkhornTransport: class wrapper for optimal transport based on
1061- Sinkhorn's algorithm
1071+ """Domain Adaptation OT method based on Sinkhorn Algorithm
10621072
10631073 Parameters
10641074 ----------
1065- - reg_e : parameter for entropic regularization
1066- - mode: unsupervised (default) or semi supervised: controls whether
1067- labels are taken into accout to construct the optimal coupling
1068- - max_iter : maximum number of iterations
1069- - tol : precision
1070- - verbose : control verbosity
1071- - log : control log
1072-
1075+ reg_e : float, optional (default=1)
1076+ Entropic regularization parameter
1077+ mode : string, optional (default="unsupervised")
1078+ The DA mode. If "unsupervised" no target labels are taken into account
1079+ to modify the cost matrix. If "semisupervised" the target labels
1080+ are taken into account to set coefficients of the pairwise distance
1081+ matrix to 0 for row and columns indices that correspond to source and
1082+ target samples which share the same labels.
1083+ max_iter : int, float, optional (default=1000)
1084+ The maximum number of iterations before stopping the optimization
1085+ algorithm if it has not converged
1086+ tol : float, optional (default=10e-9)
1087+ The precision required to stop the optimization algorithm.
1088+ mapping : string, optional (default="barycentric")
1089+ The kind of mapping to apply to transport samples from a domain into
1090+ another one.
1091+ if "barycentric" only the samples used to estimate the coupling can
1092+ be transported from a domain to another one.
1093+ metric : string, optional (default="sqeuclidean")
1094+ The ground metric for the Wasserstein problem
1095+ distribution : string, optional (default="uniform")
1096+ The kind of distribution estimation to employ
1097+ verbose : int, optional (default=0)
1098+ Controls the verbosity of the optimization algorithm
1099+ log : int, optional (default=0)
1100+ Controls the logs of the optimization algorithm
10731101 Attributes
10741102 ----------
1075- - gamma_: optimal coupling estimated by the fit function
1103+ gamma_ : the optimal coupling
1104+
1105+ References
1106+ ----------
1107+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1108+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
1109+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
1110+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
1111+ Transport, Advances in Neural Information Processing Systems (NIPS)
1112+ 26, 2013
10761113 """
10771114
10781115 def __init__ (self , reg_e = 1. , mode = "unsupervised" , max_iter = 1000 ,
@@ -1090,24 +1127,25 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
10901127 self .method = "sinkhorn"
10911128
10921129 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
1093- """fit
1094-
1095- Parameters:
1096- -----------
1097- - Xs: source samples, (ns samples, d features) numpy-like array
1098- - ys: source labels
1099- - Xt: target samples (nt samples, d features) numpy-like array
1100- - yt: target labels
1101- - method: algorithm to use to compute optimal coupling
1102- (default: sinkhorn)
1103-
1104- Returns:
1105- --------
1106- - self
1130+ """Build a coupling matrix from source and target sets of samples
1131+ (Xs, ys) and (Xt, yt)
1132+ Parameters
1133+ ----------
1134+ Xs : array-like of shape = [n_source_samples, n_features]
1135+ The training input samples.
1136+ ys : array-like, shape = [n_source_samples]
1137+ The class labels
1138+ Xt : array-like of shape = [n_target_samples, n_features]
1139+ The training input samples.
1140+ yt : array-like, shape = [n_labeled_target_samples]
1141+ The class labels
1142+ Returns
1143+ -------
1144+ self : object
1145+ Returns self.
11071146 """
11081147
1109- return super (SinkhornTransport , self ).fit (
1110- Xs , ys , Xt , yt , method = self .method )
1148+ return super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
11111149
11121150
11131151if __name__ == "__main__" :
0 commit comments