|
10 | 10 | # License: MIT License |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | | -import warnings |
14 | 13 |
|
15 | 14 | from .bregman import sinkhorn |
16 | 15 | from .lp import emd |
17 | 16 | from .utils import unif, dist, kernel |
| 17 | +from .utils import deprecated, BaseEstimator |
18 | 18 | from .optim import cg |
19 | 19 | from .optim import gcg |
20 | | -from .deprecation import deprecated |
21 | 20 |
|
22 | 21 |
|
23 | 22 | def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, |
@@ -936,139 +935,6 @@ def predict(self, x): |
936 | 935 | print("Warning, model not fitted yet, returning None") |
937 | 936 | return None |
938 | 937 |
|
939 | | -############################################################################## |
940 | | -# proposal |
941 | | -############################################################################## |
942 | | - |
943 | | - |
944 | | -# adapted from sklearn |
945 | | - |
946 | | -class BaseEstimator(object): |
947 | | - """Base class for all estimators in scikit-learn |
948 | | - Notes |
949 | | - ----- |
950 | | - All estimators should specify all the parameters that can be set |
951 | | - at the class level in their ``__init__`` as explicit keyword |
952 | | - arguments (no ``*args`` or ``**kwargs``). |
953 | | - """ |
954 | | - |
955 | | - @classmethod |
956 | | - def _get_param_names(cls): |
957 | | - """Get parameter names for the estimator""" |
958 | | - try: |
959 | | - from inspect import signature |
960 | | - except ImportError: |
961 | | - from .externals.funcsigs import signature |
962 | | - # fetch the constructor or the original constructor before |
963 | | - # deprecation wrapping if any |
964 | | - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) |
965 | | - if init is object.__init__: |
966 | | - # No explicit constructor to introspect |
967 | | - return [] |
968 | | - |
969 | | - # introspect the constructor arguments to find the model parameters |
970 | | - # to represent |
971 | | - init_signature = signature(init) |
972 | | - # Consider the constructor parameters excluding 'self' |
973 | | - parameters = [p for p in init_signature.parameters.values() |
974 | | - if p.name != 'self' and p.kind != p.VAR_KEYWORD] |
975 | | - for p in parameters: |
976 | | - if p.kind == p.VAR_POSITIONAL: |
977 | | - raise RuntimeError("scikit-learn estimators should always " |
978 | | - "specify their parameters in the signature" |
979 | | - " of their __init__ (no varargs)." |
980 | | - " %s with constructor %s doesn't " |
981 | | - " follow this convention." |
982 | | - % (cls, init_signature)) |
983 | | - # Extract and sort argument names excluding 'self' |
984 | | - return sorted([p.name for p in parameters]) |
985 | | - |
986 | | - def get_params(self, deep=True): |
987 | | - """Get parameters for this estimator. |
988 | | -
|
989 | | - Parameters |
990 | | - ---------- |
991 | | - deep : boolean, optional |
992 | | - If True, will return the parameters for this estimator and |
993 | | - contained subobjects that are estimators. |
994 | | -
|
995 | | - Returns |
996 | | - ------- |
997 | | - params : mapping of string to any |
998 | | - Parameter names mapped to their values. |
999 | | - """ |
1000 | | - out = dict() |
1001 | | - for key in self._get_param_names(): |
1002 | | - # We need deprecation warnings to always be on in order to |
1003 | | - # catch deprecated param values. |
1004 | | - # This is set in utils/__init__.py but it gets overwritten |
1005 | | - # when running under python3 somehow. |
1006 | | - warnings.simplefilter("always", DeprecationWarning) |
1007 | | - try: |
1008 | | - with warnings.catch_warnings(record=True) as w: |
1009 | | - value = getattr(self, key, None) |
1010 | | - if len(w) and w[0].category == DeprecationWarning: |
1011 | | - # if the parameter is deprecated, don't show it |
1012 | | - continue |
1013 | | - finally: |
1014 | | - warnings.filters.pop(0) |
1015 | | - |
1016 | | - # XXX: should we rather test if instance of estimator? |
1017 | | - if deep and hasattr(value, 'get_params'): |
1018 | | - deep_items = value.get_params().items() |
1019 | | - out.update((key + '__' + k, val) for k, val in deep_items) |
1020 | | - out[key] = value |
1021 | | - return out |
1022 | | - |
1023 | | - def set_params(self, **params): |
1024 | | - """Set the parameters of this estimator. |
1025 | | -
|
1026 | | - The method works on simple estimators as well as on nested objects |
1027 | | - (such as pipelines). The latter have parameters of the form |
1028 | | - ``<component>__<parameter>`` so that it's possible to update each |
1029 | | - component of a nested object. |
1030 | | -
|
1031 | | - Returns |
1032 | | - ------- |
1033 | | - self |
1034 | | - """ |
1035 | | - if not params: |
1036 | | - # Simple optimisation to gain speed (inspect is slow) |
1037 | | - return self |
1038 | | - valid_params = self.get_params(deep=True) |
1039 | | - # for key, value in iteritems(params): |
1040 | | - for key, value in params.items(): |
1041 | | - split = key.split('__', 1) |
1042 | | - if len(split) > 1: |
1043 | | - # nested objects case |
1044 | | - name, sub_name = split |
1045 | | - if name not in valid_params: |
1046 | | - raise ValueError('Invalid parameter %s for estimator %s. ' |
1047 | | - 'Check the list of available parameters ' |
1048 | | - 'with `estimator.get_params().keys()`.' % |
1049 | | - (name, self)) |
1050 | | - sub_object = valid_params[name] |
1051 | | - sub_object.set_params(**{sub_name: value}) |
1052 | | - else: |
1053 | | - # simple objects case |
1054 | | - if key not in valid_params: |
1055 | | - raise ValueError('Invalid parameter %s for estimator %s. ' |
1056 | | - 'Check the list of available parameters ' |
1057 | | - 'with `estimator.get_params().keys()`.' % |
1058 | | - (key, self.__class__.__name__)) |
1059 | | - setattr(self, key, value) |
1060 | | - return self |
1061 | | - |
1062 | | - def __repr__(self): |
1063 | | - from sklearn.base import _pprint |
1064 | | - class_name = self.__class__.__name__ |
1065 | | - return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), |
1066 | | - offset=len(class_name),),) |
1067 | | - |
1068 | | - # __getstate__ and __setstate__ are omitted because they only contain |
1069 | | - # conditionals that are not satisfied by our objects (e.g., |
1070 | | - # ``if type(self).__module__.startswith('sklearn.')``. |
1071 | | - |
1072 | 938 |
|
1073 | 939 | def distribution_estimation_uniform(X): |
1074 | 940 | """estimates a uniform distribution from an array of samples X |
|
0 commit comments