1515from .utils import unif , dist , kernel
1616from .optim import cg
1717from .optim import gcg
18+ import warnings
1819
1920
2021def sinkhorn_lpl1_mm (a , labels_a , b , M , reg , eta = 0.1 , numItermax = 10 ,
@@ -921,15 +922,8 @@ def predict(self, x):
921922# proposal
922923##############################################################################
923924
924- # from sklearn.base import BaseEstimator
925- # from sklearn.metrics import pairwise_distances
926-
927- ##############################################################################
928- # adapted from scikit-learn
929-
930- import warnings
931- # from .externals.six import string_types, iteritems
932925
926+ # adapted from scikit-learn
933927
934928class BaseEstimator (object ):
935929 """Base class for all estimators in scikit-learn
@@ -1067,7 +1061,7 @@ def distribution_estimation_uniform(X):
10671061 The uniform distribution estimated from X
10681062 """
10691063
1070- return np . ones ( X . shape [ 0 ]) / float (X .shape [0 ])
1064+ return unif (X .shape [0 ])
10711065
10721066
10731067class BaseTransport (BaseEstimator ):
@@ -1092,29 +1086,20 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10921086 """
10931087
10941088 # pairwise distance
1095- Cost = dist (Xs , Xt , metric = self .metric )
1089+ self . Cost = dist (Xs , Xt , metric = self .metric )
10961090
10971091 if self .mode == "semisupervised" :
10981092 print ("TODO: modify cost matrix accordingly" )
10991093 pass
11001094
11011095 # distribution estimation
1102- mu_s = self .distribution_estimation (Xs )
1103- mu_t = self .distribution_estimation (Xt )
1096+ self . mu_s = self .distribution_estimation (Xs )
1097+ self . mu_t = self .distribution_estimation (Xt )
11041098
11051099 # store arrays of samples
11061100 self .Xs = Xs
11071101 self .Xt = Xt
11081102
1109- # coupling estimation
1110- if self .method == "sinkhorn" :
1111- self .gamma_ = sinkhorn (
1112- a = mu_s , b = mu_t , M = Cost , reg = self .reg_e ,
1113- numItermax = self .max_iter , stopThr = self .tol ,
1114- verbose = self .verbose , log = self .log )
1115- else :
1116- print ("TODO: implement the other methods" )
1117-
11181103 return self
11191104
11201105 def fit_transform (self , Xs = None , ys = None , Xt = None , yt = None ):
@@ -1157,8 +1142,7 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11571142 The transport source samples.
11581143 """
11591144
1160- # TODO: check whether Xs is new or not
1161- if self .Xs == Xs :
1145+ if np .array_equal (self .Xs , Xs ):
11621146 # perform standard barycentric mapping
11631147 transp = self .gamma_ / np .sum (self .gamma_ , 1 )[:, None ]
11641148
@@ -1169,7 +1153,9 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11691153 transp_Xs = np .dot (transp , self .Xt )
11701154 else :
11711155 # perform out of sample mapping
1172- print ("out of sample mapping not yet implemented" )
1156+ print ("Warning: out of sample mapping not yet implemented" )
1157+ print ("input data will be returned" )
1158+ transp_Xs = Xs
11731159
11741160 return transp_Xs
11751161
@@ -1191,8 +1177,7 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
11911177 The transported target samples.
11921178 """
11931179
1194- # TODO: check whether Xt is new or not
1195- if self .Xt == Xt :
1180+ if np .array_equal (self .Xt , Xt ):
11961181 # perform standard barycentric mapping
11971182 transp_ = self .gamma_ .T / np .sum (self .gamma_ , 0 )[:, None ]
11981183
@@ -1203,7 +1188,9 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
12031188 transp_Xt = np .dot (transp_ , self .Xs )
12041189 else :
12051190 # perform out of sample mapping
1206- print ("out of sample mapping not yet implemented" )
1191+ print ("Warning: out of sample mapping not yet implemented" )
1192+ print ("input data will be returned" )
1193+ transp_Xt = Xt
12071194
12081195 return transp_Xt
12091196
@@ -1254,7 +1241,7 @@ class SinkhornTransport(BaseTransport):
12541241 """
12551242
12561243 def __init__ (self , reg_e = 1. , mode = "unsupervised" , max_iter = 1000 ,
1257- tol = 10e-9 , verbose = False , log = False , mapping = "barycentric" ,
1244+ tol = 10e-9 , verbose = False , log = False ,
12581245 metric = "sqeuclidean" ,
12591246 distribution_estimation = distribution_estimation_uniform ,
12601247 out_of_sample_map = 'ferradans' ):
@@ -1265,7 +1252,6 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
12651252 self .tol = tol
12661253 self .verbose = verbose
12671254 self .log = log
1268- self .mapping = mapping
12691255 self .metric = metric
12701256 self .distribution_estimation = distribution_estimation
12711257 self .method = "sinkhorn"
@@ -1290,10 +1276,10 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12901276 Returns self.
12911277 """
12921278
1293- return super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
1294-
1279+ self = super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
12951280
1296- if __name__ == "__main__" :
1297- print ("Small test" )
1298-
1299- st = SinkhornTransport ()
1281+ # coupling estimation
1282+ self .gamma_ = sinkhorn (
1283+ a = self .mu_s , b = self .mu_t , M = self .Cost , reg = self .reg_e ,
1284+ numItermax = self .max_iter , stopThr = self .tol ,
1285+ verbose = self .verbose , log = self .log )
0 commit comments