Skip to content

Commit 798494b

Browse files
committed
add scikit-learn compliance to the APLR classes
1 parent 3e71fed commit 798494b

File tree

2 files changed

+161
-18
lines changed

2 files changed

+161
-18
lines changed

python/interpret-core/interpret/glassbox/_aplr.py

Lines changed: 105 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ class _SeriesType:
1717
pass
1818

1919

20-
from ..utils._scikit import SKClassifierMixin, SKRegressorMixin
20+
from ..utils._scikit import (
21+
SKBaseEstimator,
22+
SKClassifierMixin,
23+
SKNotFittedError,
24+
SKRegressorMixin,
25+
)
2126
from ..api.base import LocalExplainer, GlobalExplainer
2227
from ..api.templates import FeatureValueExplanation
2328
from ..utils._clean_simple import clean_dimensions
@@ -46,7 +51,11 @@ def __init__(self, *args, **kwargs):
4651

4752

4853
class APLRRegressor(
49-
SKRegressorMixin, LocalExplainer, GlobalExplainer, APLRRegressorNative
54+
SKRegressorMixin,
55+
LocalExplainer,
56+
GlobalExplainer,
57+
SKBaseEstimator,
58+
APLRRegressorNative,
5059
):
5160
"""APLR Regressor."""
5261

@@ -60,13 +69,35 @@ def __init__(self, **kwargs):
6069
# TODO: add feature_names and feature_types to conform to glassbox API
6170
super().__init__(**kwargs)
6271

72+
def get_params(self, deep=True):
73+
return APLRRegressorNative.get_params(self)
74+
75+
def set_params(self, **params):
76+
APLRRegressorNative.set_params(self, **params)
77+
return self
78+
79+
def __sklearn_tags__(self):
80+
tags = super().__sklearn_tags__()
81+
tags.non_deterministic = True
82+
tags.target_tags.required = True
83+
return tags
84+
85+
def predict(self, X):
86+
"""Predicts target values."""
87+
if not hasattr(self, "n_features_in_"):
88+
raise SKNotFittedError(
89+
"This model has not been fitted yet. Call 'fit' first."
90+
)
91+
return super().predict(X)
92+
6393
def fit(self, X, y, **kwargs):
6494
"""Fits model."""
6595
X_names = kwargs.get("X_names")
6696

67-
self.bin_counts, self.bin_edges = calculate_densities(X)
97+
self.bin_counts_, self.bin_edges_ = calculate_densities(X)
6898
self.unique_values_in_ = calculate_unique_values(X)
6999
self.feature_names_in_ = define_feature_names(X, X_names=X_names)
100+
self.n_features_in_ = len(self.feature_names_in_)
70101

71102
super().fit(
72103
X,
@@ -107,8 +138,8 @@ def explain_global(self, name: Optional[str] = None):
107138
is_two_way_interaction: bool = len(predictor_indexes_used) == 2
108139
if is_main_effect:
109140
density_dict = {
110-
"names": self.bin_edges[predictor_indexes_used[0]],
111-
"scores": self.bin_counts[predictor_indexes_used[0]],
141+
"names": self.bin_edges_[predictor_indexes_used[0]],
142+
"scores": self.bin_counts_[predictor_indexes_used[0]],
112143
}
113144
feature_dict = {
114145
"type": "univariate",
@@ -282,7 +313,23 @@ def calculate_densities(X: FloatMatrix) -> Tuple[List[List[int]], List[List[floa
282313

283314

284315
def convert_to_numpy_matrix(X: FloatMatrix) -> np.ndarray:
316+
try:
317+
from scipy import sparse as _sparse
318+
319+
if _sparse.issparse(X):
320+
raise TypeError(
321+
"Sparse input is not supported. Please convert X to a dense array."
322+
)
323+
except ImportError:
324+
pass
325+
285326
if isinstance(X, np.ndarray):
327+
if X.dtype == object:
328+
try:
329+
return X.astype(np.float64)
330+
except (ValueError, TypeError):
331+
msg = "argument must be a float64 convertible type"
332+
raise TypeError(msg)
286333
if not np.issubdtype(X.dtype, np.number):
287334
msg = f"If X is a numpy array, it must contain only numeric values, but got dtype '{X.dtype}'."
288335
raise TypeError(msg)
@@ -341,7 +388,11 @@ def __init__(self, *args, **kwargs):
341388

342389

343390
class APLRClassifier(
344-
SKClassifierMixin, LocalExplainer, GlobalExplainer, APLRClassifierNative
391+
SKClassifierMixin,
392+
LocalExplainer,
393+
GlobalExplainer,
394+
SKBaseEstimator,
395+
APLRClassifierNative,
345396
):
346397
"""APLR Classifier."""
347398

@@ -355,25 +406,63 @@ def __init__(self, **kwargs):
355406
# TODO: add feature_names and feature_types to conform to glassbox API
356407
super().__init__(**kwargs)
357408

409+
def get_params(self, deep=True):
410+
return APLRClassifierNative.get_params(self)
411+
412+
def set_params(self, **params):
413+
APLRClassifierNative.set_params(self, **params)
414+
return self
415+
416+
def __sklearn_tags__(self):
417+
tags = super().__sklearn_tags__()
418+
tags.non_deterministic = True
419+
tags.target_tags.required = True
420+
return tags
421+
422+
def predict(self, X):
423+
"""Predicts class labels."""
424+
if not hasattr(self, "n_features_in_"):
425+
raise SKNotFittedError(
426+
"This model has not been fitted yet. Call 'fit' first."
427+
)
428+
str_preds = super().predict(X)
429+
return np.array(
430+
[self._str_to_label_[s] for s in str_preds], dtype=self.classes_.dtype
431+
)
432+
433+
def predict_proba(self, X):
434+
"""Predicts class probabilities."""
435+
if not hasattr(self, "n_features_in_"):
436+
raise SKNotFittedError(
437+
"This model has not been fitted yet. Call 'fit' first."
438+
)
439+
return self.predict_class_probabilities(X)
440+
358441
def fit(self, X, y, **kwargs):
359442
"""Fits model."""
360443
X_names = kwargs.get("X_names")
361444

362-
self.bin_counts, self.bin_edges = calculate_densities(X)
445+
self.bin_counts_, self.bin_edges_ = calculate_densities(X)
363446
self.unique_values_in_ = calculate_unique_values(X)
364447
self.feature_names_in_ = define_feature_names(X, X_names=X_names)
448+
self.n_features_in_ = len(self.feature_names_in_)
365449

366-
if not all(isinstance(val, str) for val in y):
367-
y = [str(val) for val in y]
368-
if isinstance(y, _SeriesType):
369-
y = y.to_numpy()
450+
y_arr = np.asarray(y)
451+
y_str = [str(val) for val in y_arr]
370452

371453
super().fit(
372454
X,
373-
y,
455+
y_str,
374456
**kwargs,
375457
)
376-
self.classes_ = self.classes_
458+
459+
categories = self.get_categories()
460+
unique_orig = {}
461+
for val, s in zip(y_arr, y_str):
462+
if s not in unique_orig:
463+
unique_orig[s] = val
464+
self.classes_ = np.array([unique_orig[c] for c in categories])
465+
self._str_to_label_ = {c: unique_orig[c] for c in categories}
377466
return self
378467

379468
def explain_global(self, name: Optional[str] = None):
@@ -413,8 +502,8 @@ def explain_global(self, name: Optional[str] = None):
413502
is_two_way_interaction: bool = len(predictor_indexes_used) == 2
414503
if is_main_effect:
415504
density_dict = {
416-
"names": self.bin_edges[predictor_indexes_used[0]],
417-
"scores": self.bin_counts[predictor_indexes_used[0]],
505+
"names": self.bin_edges_[predictor_indexes_used[0]],
506+
"scores": self.bin_counts_[predictor_indexes_used[0]],
418507
}
419508
feature_dict = {
420509
"type": "univariate",
@@ -518,7 +607,7 @@ def explain_local(
518607
for each instance as horizontal bar charts.
519608
"""
520609

521-
pred = self.predict(X)
610+
pred = APLRClassifierNative.predict(self, X)
522611
pred_proba = self.predict_class_probabilities(X)
523612
pred_max_prob = np.max(pred_proba, axis=1)
524613
term_names = self.get_unique_term_affiliations()

python/interpret-core/tests/glassbox/test_aplr.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
# Distributed under the MIT software license
33

44
import numpy as np
5+
import pytest
6+
import warnings
57
from aplr import APLRClassifier as APLRClassifierNative
68
from aplr import APLRRegressor as APLRRegressorNative
79
from interpret.glassbox import APLRClassifier, APLRRegressor
810
from sklearn.datasets import load_breast_cancer, load_diabetes
9-
import warnings
11+
from sklearn.utils import estimator_checks
1012

1113

1214
def test_regression():
@@ -85,7 +87,7 @@ def test_classification():
8587

8688
native_pred = native.predict(X)
8789
our_pred = our_aplr.predict(X)
88-
assert native_pred == our_pred
90+
assert [str(v) for v in our_pred] == list(native_pred)
8991

9092
# With response
9193
local_expl = our_aplr.explain_local(X[:5], y[:5])
@@ -106,3 +108,55 @@ def test_classification():
106108
global_expl = our_aplr.explain_global()
107109
global_viz = global_expl.visualize()
108110
assert global_viz is not None
111+
112+
113+
@pytest.fixture
114+
def skip_sklearn() -> set:
115+
"""Tests which we do not adhere to."""
116+
# TODO: whittle these down to the minimum
117+
return {
118+
"check_do_not_raise_errors_in_init_or_set_params", # native APLR validates params eagerly in __init__/set_params
119+
"check_no_attributes_set_in_init", # native APLR sets attributes in __init__
120+
"check_fit1d", # interpret accepts 1d X for single feature
121+
"check_fit2d_predict1d", # interpret accepts 1d for predict
122+
"check_supervised_y_2d", # interpret deliberately supports y.shape = (nsamples, 1)
123+
"check_classifiers_regression_target", # interpret is more permissive with y values
124+
"check_n_features_in_after_fitting", # interpret uses a different error message format
125+
"check_complex_data", # interpret uses a different error message for complex data
126+
"check_estimators_nan_inf", # interpret treats NaN as missing data, not as NaN/inf validation error
127+
"check_requires_y_none", # interpret uses a different error message for y=None
128+
# native APLR raises RuntimeError instead of ValueError for invalid inputs
129+
"check_regressors_train", # native APLR raises RuntimeError for mismatched X/y lengths
130+
"check_regressor_data_not_an_array", # native APLR raises RuntimeError for mismatched X/y lengths
131+
"check_classifier_data_not_an_array", # native APLR raises RuntimeError for mismatched X/y lengths
132+
"check_classifiers_train", # native APLR raises RuntimeError for mismatched X/y lengths
133+
"check_classifiers_classes", # native APLR raises RuntimeError for mismatched X/y lengths
134+
"check_regressors_no_decision_function", # native APLR raises RuntimeError for mismatched X/y lengths
135+
"check_supervised_y_no_nan", # native APLR raises RuntimeError instead of ValueError for NaN y
136+
"check_estimators_empty_data_messages", # native APLR raises RuntimeError for empty data
137+
"check_fit2d_1sample", # native APLR requires more than 1 sample for CV folds
138+
# native APLR classifier-specific limitations
139+
"check_classifiers_one_label", # native APLR requires at least 2 categories
140+
"check_classifiers_one_label_sample_weights", # native APLR requires at least 2 categories
141+
"check_fit_idempotent", # native APLR classifier fitting twice produces different results
142+
"check_sample_weight_equivalence_on_dense_data", # algorithmic difference
143+
"check_sample_weight_equivalence_on_sparse_data", # algorithmic difference
144+
}
145+
146+
147+
@estimator_checks.parametrize_with_checks(
148+
[
149+
APLRRegressor(cv_folds=2),
150+
APLRClassifier(cv_folds=2),
151+
]
152+
)
153+
def test_sklearn_estimator(estimator, check, skip_sklearn):
154+
if check.func.__name__ in skip_sklearn:
155+
pytest.skip("Deliberate deviation from scikit-learn.")
156+
with warnings.catch_warnings():
157+
warnings.filterwarnings(
158+
"ignore",
159+
"Casting complex values to real discards the imaginary part",
160+
category=np.exceptions.ComplexWarning,
161+
)
162+
check(estimator)

0 commit comments

Comments
 (0)