Skip to content

Commit 3e71fed

Browse files
committed
rename the scikit-learn imports with SK prefix to make them public when scikit-learn not installed
1 parent 02610a5 commit 3e71fed

8 files changed

Lines changed: 76 additions & 76 deletions

File tree

python/interpret-core/interpret/glassbox/_aplr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class _SeriesType:
1717
pass
1818

1919

20-
from ..utils._scikit import _ClassifierMixin, _RegressorMixin
20+
from ..utils._scikit import SKClassifierMixin, SKRegressorMixin
2121
from ..api.base import LocalExplainer, GlobalExplainer
2222
from ..api.templates import FeatureValueExplanation
2323
from ..utils._clean_simple import clean_dimensions
@@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs):
4646

4747

4848
class APLRRegressor(
49-
_RegressorMixin, LocalExplainer, GlobalExplainer, APLRRegressorNative
49+
SKRegressorMixin, LocalExplainer, GlobalExplainer, APLRRegressorNative
5050
):
5151
"""APLR Regressor."""
5252

@@ -341,7 +341,7 @@ def __init__(self, *args, **kwargs):
341341

342342

343343
class APLRClassifier(
344-
_ClassifierMixin, LocalExplainer, GlobalExplainer, APLRClassifierNative
344+
SKClassifierMixin, LocalExplainer, GlobalExplainer, APLRClassifierNative
345345
):
346346
"""APLR Classifier."""
347347

python/interpret-core/interpret/glassbox/_decisiontree.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
import numpy as np
99
from ..utils._scikit import (
10-
_ClassifierMixin,
11-
_RegressorMixin,
12-
_BaseEstimator,
13-
_NotFittedError,
10+
SKClassifierMixin,
11+
SKRegressorMixin,
12+
SKBaseEstimator,
13+
SKNotFittedError,
1414
_is_classifier,
1515
)
1616
from ..api.base import LocalExplainer, GlobalExplainer, BaseExplanation
@@ -222,7 +222,7 @@ def _weight_nodes_feature(self, nodes, feature_name):
222222
return new_nodes
223223

224224

225-
class BaseShallowDecisionTree(LocalExplainer, GlobalExplainer, _BaseEstimator):
225+
class BaseShallowDecisionTree(LocalExplainer, GlobalExplainer, SKBaseEstimator):
226226
"""Shallow Decision Tree (low depth).
227227
228228
Currently wrapper around DecisionTreeClassifier or DecisionTreeRegressor in scikit-learn.
@@ -318,7 +318,7 @@ def predict(self, X):
318318
"""
319319

320320
if not hasattr(self, "n_features_in_"):
321-
raise _NotFittedError(
321+
raise SKNotFittedError(
322322
"This model has not been fitted yet. Call 'fit' first."
323323
)
324324

@@ -347,7 +347,7 @@ def explain_global(self, name=None):
347347
"""
348348

349349
if not hasattr(self, "n_features_in_"):
350-
raise _NotFittedError(
350+
raise SKNotFittedError(
351351
"This model has not been fitted yet. Call 'fit' first."
352352
)
353353

@@ -398,7 +398,7 @@ def explain_local(self, X, y=None, name=None):
398398
"""
399399

400400
if not hasattr(self, "n_features_in_"):
401-
raise _NotFittedError(
401+
raise SKNotFittedError(
402402
"This model has not been fitted yet. Call 'fit' first."
403403
)
404404

@@ -570,7 +570,7 @@ def __sklearn_tags__(self):
570570
return tags
571571

572572

573-
class RegressionTree(_RegressorMixin, BaseShallowDecisionTree):
573+
class RegressionTree(SKRegressorMixin, BaseShallowDecisionTree):
574574
"""Regression tree with shallow depth."""
575575

576576
def __init__(self, feature_names=None, feature_types=None, max_depth=3, **kwargs):
@@ -619,7 +619,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
619619
)
620620

621621

622-
class ClassificationTree(_ClassifierMixin, BaseShallowDecisionTree):
622+
class ClassificationTree(SKClassifierMixin, BaseShallowDecisionTree):
623623
"""Classification tree with shallow depth."""
624624

625625
def __init__(self, feature_names=None, feature_types=None, max_depth=3, **kwargs):
@@ -678,7 +678,7 @@ def predict_proba(self, X):
678678
"""
679679

680680
if not hasattr(self, "n_features_in_"):
681-
raise _NotFittedError(
681+
raise SKNotFittedError(
682682
"This model has not been fitted yet. Call 'fit' first."
683683
)
684684

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import numpy as np
1919
from joblib import Parallel, delayed
2020
from ...utils._scikit import (
21-
_BaseEstimator,
22-
_ClassifierMixin,
23-
_RegressorMixin,
24-
_NotFittedError,
21+
SKBaseEstimator,
22+
SKClassifierMixin,
23+
SKRegressorMixin,
24+
SKNotFittedError,
2525
_is_classifier,
2626
_is_regressor,
2727
)
@@ -283,7 +283,7 @@ def clean_interactions(interactions, n_features_in):
283283
return interactions
284284

285285

286-
class BaseEBM(LocalExplainer, GlobalExplainer, _BaseEstimator):
286+
class BaseEBM(LocalExplainer, GlobalExplainer, SKBaseEstimator):
287287
"""Base class for all EBMs. Do not instantiate directly."""
288288

289289
n_features_in_: int
@@ -1713,7 +1713,7 @@ def to_jsonable(self, detail="all"):
17131713
17141714
"""
17151715
if not hasattr(self, "bins_"):
1716-
raise _NotFittedError(
1716+
raise SKNotFittedError(
17171717
"This model has not been fitted yet. Call 'fit' first."
17181718
)
17191719

@@ -1735,7 +1735,7 @@ def to_json(self, file, detail="all", indent=2):
17351735
17361736
"""
17371737
if not hasattr(self, "bins_"):
1738-
raise _NotFittedError(
1738+
raise SKNotFittedError(
17391739
"This model has not been fitted yet. Call 'fit' first."
17401740
)
17411741

@@ -1796,7 +1796,7 @@ def to_excel_exportable(self, file):
17961796
"""
17971797

17981798
if not hasattr(self, "bins_"):
1799-
raise _NotFittedError(
1799+
raise SKNotFittedError(
18001800
"This model has not been fitted yet. Call 'fit' first."
18011801
)
18021802

@@ -1817,7 +1817,7 @@ def to_excel(self, file):
18171817
"""
18181818

18191819
if not hasattr(self, "bins_"):
1820-
raise _NotFittedError(
1820+
raise SKNotFittedError(
18211821
"This model has not been fitted yet. Call 'fit' first."
18221822
)
18231823

@@ -1837,7 +1837,7 @@ def _predict_score(self, X, init_score=None):
18371837
18381838
"""
18391839
if not hasattr(self, "bins_"):
1840-
raise _NotFittedError(
1840+
raise SKNotFittedError(
18411841
"This model has not been fitted yet. Call 'fit' first."
18421842
)
18431843

@@ -1872,7 +1872,7 @@ def eval_terms(self, X):
18721872
18731873
"""
18741874
if not hasattr(self, "bins_"):
1875-
raise _NotFittedError(
1875+
raise SKNotFittedError(
18761876
"This model has not been fitted yet. Call 'fit' first."
18771877
)
18781878

@@ -1901,7 +1901,7 @@ def explain_global(self, name=None):
19011901
name = gen_name_from_class(self)
19021902

19031903
if not hasattr(self, "bins_"):
1904-
raise _NotFittedError(
1904+
raise SKNotFittedError(
19051905
"This model has not been fitted yet. Call 'fit' first."
19061906
)
19071907

@@ -2187,7 +2187,7 @@ def explain_local(self, X, y=None, name=None, init_score=None):
21872187
# Values are the model graph score per respective term.
21882188

21892189
if not hasattr(self, "bins_"):
2190-
raise _NotFittedError(
2190+
raise SKNotFittedError(
21912191
"This model has not been fitted yet. Call 'fit' first."
21922192
)
21932193

@@ -2329,7 +2329,7 @@ def term_importances(self, importance_type="avg_weight"):
23292329
23302330
"""
23312331
if not hasattr(self, "bins_"):
2332-
raise _NotFittedError(
2332+
raise SKNotFittedError(
23332333
"This model has not been fitted yet. Call 'fit' first."
23342334
)
23352335

@@ -2385,7 +2385,7 @@ def monotonize(self, term, increasing="auto", passthrough=0.0):
23852385
23862386
"""
23872387
if not hasattr(self, "bins_"):
2388-
raise _NotFittedError(
2388+
raise SKNotFittedError(
23892389
"This model has not been fitted yet. Call 'fit' first."
23902390
)
23912391

@@ -2486,7 +2486,7 @@ def remove_terms(self, terms):
24862486
24872487
"""
24882488
if not hasattr(self, "bins_"):
2489-
raise _NotFittedError(
2489+
raise SKNotFittedError(
24902490
"This model has not been fitted yet. Call 'fit' first."
24912491
)
24922492

@@ -2540,7 +2540,7 @@ def remove_features(self, features):
25402540
25412541
"""
25422542
if not hasattr(self, "bins_"):
2543-
raise _NotFittedError(
2543+
raise SKNotFittedError(
25442544
"This model has not been fitted yet. Call 'fit' first."
25452545
)
25462546

@@ -2597,7 +2597,7 @@ def sweep(self, terms=True, bins=True, features=False):
25972597
25982598
"""
25992599
if not hasattr(self, "bins_"):
2600-
raise _NotFittedError(
2600+
raise SKNotFittedError(
26012601
"This model has not been fitted yet. Call 'fit' first."
26022602
)
26032603

@@ -2652,7 +2652,7 @@ def scale(self, term, factor):
26522652
26532653
"""
26542654
if not hasattr(self, "bins_"):
2655-
raise _NotFittedError(
2655+
raise SKNotFittedError(
26562656
"This model has not been fitted yet. Call 'fit' first."
26572657
)
26582658

@@ -2688,7 +2688,7 @@ def predict_with_uncertainty(self, X, init_score=None):
26882688
"""
26892689

26902690
if not hasattr(self, "bins_"):
2691-
raise _NotFittedError(
2691+
raise SKNotFittedError(
26922692
"This model has not been fitted yet. Call 'fit' first."
26932693
)
26942694

@@ -2723,7 +2723,7 @@ def predict_with_uncertainty(self, X, init_score=None):
27232723

27242724
def _multinomialize(self, passthrough=0.0):
27252725
if not hasattr(self, "bins_"):
2726-
raise _NotFittedError(
2726+
raise SKNotFittedError(
27272727
"This model has not been fitted yet. Call 'fit' first."
27282728
)
27292729

@@ -2776,7 +2776,7 @@ def _multinomialize(self, passthrough=0.0):
27762776

27772777
def _ovrize(self, passthrough=0.0):
27782778
if not hasattr(self, "bins_"):
2779-
raise _NotFittedError(
2779+
raise SKNotFittedError(
27802780
"This model has not been fitted yet. Call 'fit' first."
27812781
)
27822782

@@ -2829,7 +2829,7 @@ def _ovrize(self, passthrough=0.0):
28292829

28302830
def _binarize(self, passthrough=0.0):
28312831
if not hasattr(self, "bins_"):
2832-
raise _NotFittedError(
2832+
raise SKNotFittedError(
28332833
"This model has not been fitted yet. Call 'fit' first."
28342834
)
28352835

@@ -2878,7 +2878,7 @@ def __sklearn_tags__(self):
28782878
return tags
28792879

28802880

2881-
class EBMClassifierMixin(_ClassifierMixin):
2881+
class EBMClassifierMixin(SKClassifierMixin):
28822882
"""Mixin class for EBM classifiers.
28832883
28842884
Provides predict, predict_proba, decision_function, and reorder_classes methods.
@@ -2902,7 +2902,7 @@ def predict_proba(self, X, init_score=None):
29022902
"""
29032903

29042904
if not hasattr(self, "bins_"):
2905-
raise _NotFittedError(
2905+
raise SKNotFittedError(
29062906
"This model has not been fitted yet. Call 'fit' first."
29072907
)
29082908

@@ -2940,7 +2940,7 @@ def decision_function(self, X, init_score=None):
29402940
29412941
"""
29422942
if not hasattr(self, "bins_"):
2943-
raise _NotFittedError(
2943+
raise SKNotFittedError(
29442944
"This model has not been fitted yet. Call 'fit' first."
29452945
)
29462946

@@ -2974,7 +2974,7 @@ def predict(self, X, init_score=None):
29742974
29752975
"""
29762976
if not hasattr(self, "bins_"):
2977-
raise _NotFittedError(
2977+
raise SKNotFittedError(
29782978
"This model has not been fitted yet. Call 'fit' first."
29792979
)
29802980

@@ -3013,7 +3013,7 @@ def reorder_classes(self, classes):
30133013
30143014
"""
30153015
if not hasattr(self, "bins_"):
3016-
raise _NotFittedError(
3016+
raise SKNotFittedError(
30173017
"This model has not been fitted yet. Call 'fit' first."
30183018
)
30193019

@@ -3063,7 +3063,7 @@ def reorder_classes(self, classes):
30633063
return self
30643064

30653065

3066-
class EBMRegressorMixin(_RegressorMixin):
3066+
class EBMRegressorMixin(SKRegressorMixin):
30673067
"""Mixin class for EBM regressors.
30683068
30693069
Provides the regression predict method.
@@ -3088,7 +3088,7 @@ def predict(self, X, init_score=None):
30883088
"""
30893089

30903090
if not hasattr(self, "bins_"):
3091-
raise _NotFittedError(
3091+
raise SKNotFittedError(
30923092
"This model has not been fitted yet. Call 'fit' first."
30933093
)
30943094

python/interpret-core/interpret/glassbox/_ebm/_research/_group_importance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010
import plotly.express as px
11-
from ....utils._scikit import _NotFittedError
11+
from ....utils._scikit import SKNotFittedError
1212

1313

1414
def compute_group_importance(term_list, ebm, X, contributions=None):
@@ -24,7 +24,7 @@ def compute_group_importance(term_list, ebm, X, contributions=None):
2424
float: term_list's group importance
2525
"""
2626
if not hasattr(ebm, "bins_"):
27-
raise _NotFittedError("This model has not been fitted yet. Call 'fit' first.")
27+
raise SKNotFittedError("This model has not been fitted yet. Call 'fit' first.")
2828

2929
if contributions is None:
3030
contributions = ebm.eval_terms(X)
@@ -116,7 +116,7 @@ def append_group_importance(
116116
EBMExplanation: A global explanation with the group importance appended to it
117117
"""
118118
if not hasattr(ebm, "bins_"):
119-
raise _NotFittedError("This model has not been fitted yet. Call 'fit' first.")
119+
raise SKNotFittedError("This model has not been fitted yet. Call 'fit' first.")
120120

121121
if global_exp is not None:
122122
if global_exp.explanation_type != "global":

0 commit comments

Comments
 (0)