@@ -921,21 +921,153 @@ def predict(self, x):
# proposal
##############################################################################
923923
924- from sklearn .base import BaseEstimator
925- from sklearn .metrics import pairwise_distances
924+ # from sklearn.base import BaseEstimator
925+ # from sklearn.metrics import pairwise_distances
926+
927+ ##############################################################################
928+ # adapted from scikit-learn
929+
930+ import warnings
931+ # from .externals.six import string_types, iteritems
926932
927- """
928- - all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929- - reg_e: is the entropic reg parameter
930- - reg_cl: is the second reg parameter
931- - gamma_: is the optimal coupling
932- - mapping barycentric for the moment
933-
934- Questions:
935- - Cost matrix estimation: from sklearn or from internal function ?
936- - distribution estimation ? Look at Nathalie's approach
937- - should everything been done into the fit from BaseTransport ?
938- """
933+
class BaseEstimator(object):
    """Base class for all estimators in scikit-learn.

    Vendored (adapted) from scikit-learn so parameter introspection works
    without importing sklearn at module load time.

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator."""
        try:
            from inspect import signature
        except ImportError:
            # Python 2 fallback bundled with the package
            from .externals.funcsigs import signature
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            # We need deprecation warnings to always be on in order to
            # catch deprecated param values.
            # This is set in utils/__init__.py but it gets overwritten
            # when running under python3 somehow.
            warnings.simplefilter("always", DeprecationWarning)
            try:
                with warnings.catch_warnings(record=True) as w:
                    value = getattr(self, key, None)
                if len(w) and w[0].category == DeprecationWarning:
                    # if the parameter is deprecated, don't show it
                    continue
            finally:
                # undo the simplefilter installed above (it is prepended
                # to warnings.filters, hence pop(0))
                warnings.filters.pop(0)

            # XXX: should we rather test if instance of estimator?
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self
        """
        if not params:
            # Simple optimisation to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)
        for key, value in params.items():
            split = key.split('__', 1)
            if len(split) > 1:
                # nested objects case
                name, sub_name = split
                if name not in valid_params:
                    # report the class name, consistent with the simple case
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (name, self.__class__.__name__))
                sub_object = valid_params[name]
                sub_object.set_params(**{sub_name: value})
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (key, self.__class__.__name__))
                setattr(self, key, value)
        return self

    def __repr__(self):
        # sklearn is only needed for pretty-printing; import it lazily
        from sklearn.base import _pprint
        class_name = self.__class__.__name__
        return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
                                               offset=len(class_name),),)

    # __getstate__ and __setstate__ are omitted because they only contain
    # conditionals that are not satisfied by our objects (e.g.,
    # ``if type(self).__module__.startswith('sklearn.')``.
1056+
def distribution_estimation_uniform(X):
    """Return the empirical uniform distribution over the samples of X.

    Parameters
    ----------
    X : array-like of shape = [n_samples, n_features]
        The array of samples.

    Returns
    -------
    mu : array-like, shape = [n_samples,]
        Uniform weights: every sample receives mass 1 / n_samples.
    """
    n_samples = X.shape[0]
    return np.ones(n_samples) / float(n_samples)
9391071
9401072
9411073class BaseTransport (BaseEstimator ):
@@ -960,18 +1092,19 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
9601092 """
9611093
9621094 # pairwise distance
963- Cost = pairwise_distances (Xs , Xt , metric = self .metric )
1095+ Cost = dist (Xs , Xt , metric = self .metric )
9641096
9651097 if self .mode == "semisupervised" :
9661098 print ("TODO: modify cost matrix accordingly" )
9671099 pass
9681100
9691101 # distribution estimation
970- if self .distribution == "uniform" :
971- mu_s = np .ones (Xs .shape [0 ]) / float (Xs .shape [0 ])
972- mu_t = np .ones (Xt .shape [0 ]) / float (Xt .shape [0 ])
973- else :
974- print ("TODO: implement kernelized approach" )
1102+ mu_s = self .distribution_estimation (Xs )
1103+ mu_t = self .distribution_estimation (Xt )
1104+
1105+ # store arrays of samples
1106+ self .Xs = Xs
1107+ self .Xt = Xt
9751108
9761109 # coupling estimation
9771110 if self .method == "sinkhorn" :
def transform(self, Xs=None, ys=None, Xt=None, yt=None):
    """Transport the source samples Xs onto the target domain.

    Parameters
    ----------
    Xs : array-like of shape = [n_source_samples, n_features]
        The source samples to transport.
    ys : array-like, shape = [n_source_samples,], optional
        Source labels (unused here).
    Xt : array-like of shape = [n_target_samples, n_features], optional
        Target samples (unused here; the samples stored at fit time are used).
    yt : array-like, shape = [n_target_samples,], optional
        Target labels (unused here).

    Returns
    -------
    transp_Xs : array-like of shape = [n_source_samples, n_features]
        The transported source samples, or None when Xs differs from the
        samples seen at fit time (out of sample mapping not implemented yet).
    """
    # TODO: check whether Xs is new or not.
    # NOTE: `self.Xs == Xs` yields an element-wise boolean array and makes
    # `if` raise ValueError for inputs with more than one element, so
    # compare with np.array_equal instead.
    if np.array_equal(self.Xs, Xs):
        # perform standard barycentric mapping
        transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]

        # set nans to 0
        transp[~np.isfinite(transp)] = 0

        # compute transported samples
        transp_Xs = np.dot(transp, self.Xt)
    else:
        # perform out of sample mapping
        print("out of sample mapping not yet implemented")
        # the previous code fell through and returned an unbound local
        # (NameError); return None explicitly instead
        transp_Xs = None

    return transp_Xs
10371175
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
    """Transport the target samples Xt back onto the source domain.

    Parameters
    ----------
    Xs : array-like of shape = [n_source_samples, n_features], optional
        Source samples (unused here; the samples stored at fit time are used).
    ys : array-like, shape = [n_source_samples,], optional
        Source labels (unused here).
    Xt : array-like of shape = [n_target_samples, n_features]
        The target samples to transport back.
    yt : array-like, shape = [n_target_samples,], optional
        Target labels (unused here).

    Returns
    -------
    transp_Xt : array-like of shape = [n_target_samples, n_features]
        The transported target samples, or None when Xt differs from the
        samples seen at fit time (out of sample mapping not implemented yet).
    """
    # TODO: check whether Xt is new or not.
    # NOTE: `self.Xt == Xt` yields an element-wise boolean array and makes
    # `if` raise ValueError for inputs with more than one element, so
    # compare with np.array_equal instead.
    if np.array_equal(self.Xt, Xt):
        # perform standard barycentric mapping
        transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]

        # set nans to 0
        transp_[~np.isfinite(transp_)] = 0

        # compute transported samples
        transp_Xt = np.dot(transp_, self.Xs)
    else:
        # perform out of sample mapping
        print("out of sample mapping not yet implemented")
        # the previous code fell through and returned an unbound local
        # (NameError); return None explicitly instead
        transp_Xt = None

    return transp_Xt
10681209
@@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport):
11141255
11151256 def __init__ (self , reg_e = 1. , mode = "unsupervised" , max_iter = 1000 ,
11161257 tol = 10e-9 , verbose = False , log = False , mapping = "barycentric" ,
1117- metric = "sqeuclidean" , distribution = "uniform" ):
1258+ metric = "sqeuclidean" ,
1259+ distribution_estimation = distribution_estimation_uniform ,
1260+ out_of_sample_map = 'ferradans' ):
1261+
11181262 self .reg_e = reg_e
11191263 self .mode = mode
11201264 self .max_iter = max_iter
@@ -1123,8 +1267,9 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
11231267 self .log = log
11241268 self .mapping = mapping
11251269 self .metric = metric
1126- self .distribution = distribution
1270+ self .distribution_estimation = distribution_estimation
11271271 self .method = "sinkhorn"
1272+ self .out_of_sample_map = out_of_sample_map
11281273
11291274 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
11301275 """Build a coupling matrix from source and target sets of samples
0 commit comments