Skip to content

Commit bd7c7d2

Browse files
committed
own BaseEstimator class written + rflamary comments addressed
1 parent cd3397f commit bd7c7d2

File tree

1 file changed

+172
-27
lines changed

1 file changed

+172
-27
lines changed

ot/da.py

Lines changed: 172 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -921,21 +921,153 @@ def predict(self, x):
921921
# proposal
922922
##############################################################################
923923

924-
from sklearn.base import BaseEstimator
925-
from sklearn.metrics import pairwise_distances
924+
# from sklearn.base import BaseEstimator
925+
# from sklearn.metrics import pairwise_distances
926+
927+
##############################################################################
928+
# adapted from scikit-learn
929+
930+
import warnings
931+
# from .externals.six import string_types, iteritems
926932

927-
"""
928-
- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929-
- reg_e: is the entropic reg parameter
930-
- reg_cl: is the second reg parameter
931-
- gamma_: is the optimal coupling
932-
- mapping barycentric for the moment
933-
934-
Questions:
935-
- Cost matrix estimation: from sklearn or from internal function ?
936-
- distribution estimation ? Look at Nathalie's approach
937-
- should everything been done into the fit from BaseTransport ?
938-
"""
933+
934+
class BaseEstimator(object):
    """Base class for all estimators in scikit-learn

    Provides ``get_params``/``set_params`` introspection based on the
    subclass's ``__init__`` signature (adapted from scikit-learn).

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        try:
            from inspect import signature
        except ImportError:
            # Python 2 fallback shipped with the package — TODO confirm
            # ot/externals/funcsigs is vendored before relying on it.
            from .externals.funcsigs import signature
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            # We need deprecation warnings to always be on in order to
            # catch deprecated param values.
            # This is set in utils/__init__.py but it gets overwritten
            # when running under python3 somehow.
            warnings.simplefilter("always", DeprecationWarning)
            try:
                with warnings.catch_warnings(record=True) as w:
                    value = getattr(self, key, None)
                if len(w) and w[0].category == DeprecationWarning:
                    # if the parameter is deprecated, don't show it
                    continue
            finally:
                # pop the filter pushed by simplefilter above so global
                # warning state is restored for every key
                warnings.filters.pop(0)

            # XXX: should we rather test if instance of estimator?
            if deep and hasattr(value, 'get_params'):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self
        """
        if not params:
            # Simple optimisation to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)
        # for key, value in iteritems(params):
        for key, value in params.items():
            split = key.split('__', 1)
            if len(split) > 1:
                # nested objects case
                name, sub_name = split
                if name not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (name, self))
                sub_object = valid_params[name]
                sub_object.set_params(**{sub_name: value})
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (key, self.__class__.__name__))
                setattr(self, key, value)
        return self

    def __repr__(self):
        # NOTE(review): relies on the private sklearn helper `_pprint`,
        # imported lazily so sklearn is only required when repr() is called —
        # consider vendoring it to drop the sklearn dependency.
        from sklearn.base import _pprint
        class_name = self.__class__.__name__
        return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
                                               offset=len(class_name),),)

    # __getstate__ and __setstate__ are omitted because they only contain
    # conditionals that are not satisfied by our objects (e.g.,
    # ``if type(self).__module__.startswith('sklearn.')``.
1055+
1056+
1057+
def distribution_estimation_uniform(X):
    """estimates a uniform distribution from an array of samples X

    Parameters
    ----------
    X : array-like of shape = [n_samples, n_features]
        The array of samples
    Returns
    -------
    mu : array-like, shape = [n_samples,]
        The uniform distribution estimated from X
    """
    n_samples = X.shape[0]
    # every sample gets equal mass 1/n, so the weights sum to one
    return np.full(n_samples, 1.0 / n_samples)
9391071

9401072

9411073
class BaseTransport(BaseEstimator):
@@ -960,18 +1092,19 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
9601092
"""
9611093

9621094
# pairwise distance
963-
Cost = pairwise_distances(Xs, Xt, metric=self.metric)
1095+
Cost = dist(Xs, Xt, metric=self.metric)
9641096

9651097
if self.mode == "semisupervised":
9661098
print("TODO: modify cost matrix accordingly")
9671099
pass
9681100

9691101
# distribution estimation
970-
if self.distribution == "uniform":
971-
mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
972-
mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])
973-
else:
974-
print("TODO: implement kernelized approach")
1102+
mu_s = self.distribution_estimation(Xs)
1103+
mu_t = self.distribution_estimation(Xt)
1104+
1105+
# store arrays of samples
1106+
self.Xs = Xs
1107+
self.Xt = Xt
9751108

9761109
# coupling estimation
9771110
if self.method == "sinkhorn":
@@ -1024,14 +1157,19 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
10241157
The transport source samples.
10251158
"""
10261159

1027-
if self.mapping == "barycentric":
1160+
# TODO: check whether Xs is new or not
1161+
if self.Xs == Xs:
1162+
# perform standard barycentric mapping
10281163
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
10291164

10301165
# set nans to 0
10311166
transp[~ np.isfinite(transp)] = 0
10321167

10331168
# compute transported samples
1034-
transp_Xs = np.dot(transp, Xt)
1169+
transp_Xs = np.dot(transp, self.Xt)
1170+
else:
1171+
# perform out of sample mapping
1172+
print("out of sample mapping not yet implemented")
10351173

10361174
return transp_Xs
10371175

@@ -1053,16 +1191,19 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
10531191
The transported target samples.
10541192
"""
10551193

1056-
if self.mapping == "barycentric":
1194+
# TODO: check whether Xt is new or not
1195+
if self.Xt == Xt:
1196+
# perform standard barycentric mapping
10571197
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
10581198

10591199
# set nans to 0
10601200
transp_[~ np.isfinite(transp_)] = 0
10611201

10621202
# compute transported samples
1063-
transp_Xt = np.dot(transp_, Xs)
1203+
transp_Xt = np.dot(transp_, self.Xs)
10641204
else:
1065-
print("mapping not yet implemented")
1205+
# perform out of sample mapping
1206+
print("out of sample mapping not yet implemented")
10661207

10671208
return transp_Xt
10681209

@@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport):
11141255

11151256
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
11161257
tol=10e-9, verbose=False, log=False, mapping="barycentric",
1117-
metric="sqeuclidean", distribution="uniform"):
1258+
metric="sqeuclidean",
1259+
distribution_estimation=distribution_estimation_uniform,
1260+
out_of_sample_map='ferradans'):
1261+
11181262
self.reg_e = reg_e
11191263
self.mode = mode
11201264
self.max_iter = max_iter
@@ -1123,8 +1267,9 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
11231267
self.log = log
11241268
self.mapping = mapping
11251269
self.metric = metric
1126-
self.distribution = distribution
1270+
self.distribution_estimation = distribution_estimation
11271271
self.method = "sinkhorn"
1272+
self.out_of_sample_map = out_of_sample_map
11281273

11291274
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
11301275
"""Build a coupling matrix from source and target sets of samples

0 commit comments

Comments
 (0)