@@ -17,7 +17,12 @@ class _SeriesType:
1717 pass
1818
1919
20- from ..utils ._scikit import SKClassifierMixin , SKRegressorMixin
20+ from ..utils ._scikit import (
21+ SKBaseEstimator ,
22+ SKClassifierMixin ,
23+ SKNotFittedError ,
24+ SKRegressorMixin ,
25+ )
2126from ..api .base import LocalExplainer , GlobalExplainer
2227from ..api .templates import FeatureValueExplanation
2328from ..utils ._clean_simple import clean_dimensions
@@ -46,7 +51,11 @@ def __init__(self, *args, **kwargs):
4651
4752
4853class APLRRegressor (
49- SKRegressorMixin , LocalExplainer , GlobalExplainer , APLRRegressorNative
54+ SKRegressorMixin ,
55+ LocalExplainer ,
56+ GlobalExplainer ,
57+ SKBaseEstimator ,
58+ APLRRegressorNative ,
5059):
5160 """APLR Regressor."""
5261
@@ -60,13 +69,35 @@ def __init__(self, **kwargs):
6069 # TODO: add feature_names and feature_types to conform to glassbox API
6170 super ().__init__ (** kwargs )
6271
72+ def get_params (self , deep = True ):
73+ return APLRRegressorNative .get_params (self )
74+
75+ def set_params (self , ** params ):
76+ APLRRegressorNative .set_params (self , ** params )
77+ return self
78+
79+ def __sklearn_tags__ (self ):
80+ tags = super ().__sklearn_tags__ ()
81+ tags .non_deterministic = True
82+ tags .target_tags .required = True
83+ return tags
84+
85+ def predict (self , X ):
86+ """Predicts target values."""
87+ if not hasattr (self , "n_features_in_" ):
88+ raise SKNotFittedError (
89+ "This model has not been fitted yet. Call 'fit' first."
90+ )
91+ return super ().predict (X )
92+
6393 def fit (self , X , y , ** kwargs ):
6494 """Fits model."""
6595 X_names = kwargs .get ("X_names" )
6696
67- self .bin_counts , self .bin_edges = calculate_densities (X )
97+ self .bin_counts_ , self .bin_edges_ = calculate_densities (X )
6898 self .unique_values_in_ = calculate_unique_values (X )
6999 self .feature_names_in_ = define_feature_names (X , X_names = X_names )
100+ self .n_features_in_ = len (self .feature_names_in_ )
70101
71102 super ().fit (
72103 X ,
@@ -107,8 +138,8 @@ def explain_global(self, name: Optional[str] = None):
107138 is_two_way_interaction : bool = len (predictor_indexes_used ) == 2
108139 if is_main_effect :
109140 density_dict = {
110- "names" : self .bin_edges [predictor_indexes_used [0 ]],
111- "scores" : self .bin_counts [predictor_indexes_used [0 ]],
141+ "names" : self .bin_edges_ [predictor_indexes_used [0 ]],
142+ "scores" : self .bin_counts_ [predictor_indexes_used [0 ]],
112143 }
113144 feature_dict = {
114145 "type" : "univariate" ,
@@ -282,7 +313,23 @@ def calculate_densities(X: FloatMatrix) -> Tuple[List[List[int]], List[List[floa
282313
283314
284315def convert_to_numpy_matrix (X : FloatMatrix ) -> np .ndarray :
316+ try :
317+ from scipy import sparse as _sparse
318+
319+ if _sparse .issparse (X ):
320+ raise TypeError (
321+ "Sparse input is not supported. Please convert X to a dense array."
322+ )
323+ except ImportError :
324+ pass
325+
285326 if isinstance (X , np .ndarray ):
327+ if X .dtype == object :
328+ try :
329+ return X .astype (np .float64 )
330+ except (ValueError , TypeError ):
331+ msg = "argument must be a float64 convertible type"
332+ raise TypeError (msg )
286333 if not np .issubdtype (X .dtype , np .number ):
287334 msg = f"If X is a numpy array, it must contain only numeric values, but got dtype '{ X .dtype } '."
288335 raise TypeError (msg )
@@ -341,7 +388,11 @@ def __init__(self, *args, **kwargs):
341388
342389
343390class APLRClassifier (
344- SKClassifierMixin , LocalExplainer , GlobalExplainer , APLRClassifierNative
391+ SKClassifierMixin ,
392+ LocalExplainer ,
393+ GlobalExplainer ,
394+ SKBaseEstimator ,
395+ APLRClassifierNative ,
345396):
346397 """APLR Classifier."""
347398
@@ -355,25 +406,63 @@ def __init__(self, **kwargs):
355406 # TODO: add feature_names and feature_types to conform to glassbox API
356407 super ().__init__ (** kwargs )
357408
409+ def get_params (self , deep = True ):
410+ return APLRClassifierNative .get_params (self )
411+
412+ def set_params (self , ** params ):
413+ APLRClassifierNative .set_params (self , ** params )
414+ return self
415+
416+ def __sklearn_tags__ (self ):
417+ tags = super ().__sklearn_tags__ ()
418+ tags .non_deterministic = True
419+ tags .target_tags .required = True
420+ return tags
421+
422+ def predict (self , X ):
423+ """Predicts class labels."""
424+ if not hasattr (self , "n_features_in_" ):
425+ raise SKNotFittedError (
426+ "This model has not been fitted yet. Call 'fit' first."
427+ )
428+ str_preds = super ().predict (X )
429+ return np .array (
430+ [self ._str_to_label_ [s ] for s in str_preds ], dtype = self .classes_ .dtype
431+ )
432+
433+ def predict_proba (self , X ):
434+ """Predicts class probabilities."""
435+ if not hasattr (self , "n_features_in_" ):
436+ raise SKNotFittedError (
437+ "This model has not been fitted yet. Call 'fit' first."
438+ )
439+ return self .predict_class_probabilities (X )
440+
358441 def fit (self , X , y , ** kwargs ):
359442 """Fits model."""
360443 X_names = kwargs .get ("X_names" )
361444
362- self .bin_counts , self .bin_edges = calculate_densities (X )
445+ self .bin_counts_ , self .bin_edges_ = calculate_densities (X )
363446 self .unique_values_in_ = calculate_unique_values (X )
364447 self .feature_names_in_ = define_feature_names (X , X_names = X_names )
448+ self .n_features_in_ = len (self .feature_names_in_ )
365449
366- if not all (isinstance (val , str ) for val in y ):
367- y = [str (val ) for val in y ]
368- if isinstance (y , _SeriesType ):
369- y = y .to_numpy ()
450+ y_arr = np .asarray (y )
451+ y_str = [str (val ) for val in y_arr ]
370452
371453 super ().fit (
372454 X ,
373- y ,
455+ y_str ,
374456 ** kwargs ,
375457 )
376- self .classes_ = self .classes_
458+
459+ categories = self .get_categories ()
460+ unique_orig = {}
461+ for val , s in zip (y_arr , y_str ):
462+ if s not in unique_orig :
463+ unique_orig [s ] = val
464+ self .classes_ = np .array ([unique_orig [c ] for c in categories ])
465+ self ._str_to_label_ = {c : unique_orig [c ] for c in categories }
377466 return self
378467
379468 def explain_global (self , name : Optional [str ] = None ):
@@ -413,8 +502,8 @@ def explain_global(self, name: Optional[str] = None):
413502 is_two_way_interaction : bool = len (predictor_indexes_used ) == 2
414503 if is_main_effect :
415504 density_dict = {
416- "names" : self .bin_edges [predictor_indexes_used [0 ]],
417- "scores" : self .bin_counts [predictor_indexes_used [0 ]],
505+ "names" : self .bin_edges_ [predictor_indexes_used [0 ]],
506+ "scores" : self .bin_counts_ [predictor_indexes_used [0 ]],
418507 }
419508 feature_dict = {
420509 "type" : "univariate" ,
@@ -518,7 +607,7 @@ def explain_local(
518607 for each instance as horizontal bar charts.
519608 """
520609
521- pred = self .predict (X )
610+ pred = APLRClassifierNative .predict (self , X )
522611 pred_proba = self .predict_class_probabilities (X )
523612 pred_max_prob = np .max (pred_proba , axis = 1 )
524613 term_names = self .get_unique_term_affiliations ()
0 commit comments