@@ -96,23 +96,49 @@ def validate_y(quantifier: Any, y: np.ndarray) -> None:
9696
def _get_valid_crisp_predictions(predictions, train_y_values=None, threshold=0.5):
    """Convert soft (score/probability) predictions into crisp class labels.

    Parameters
    ----------
    predictions : array-like
        Prediction scores. Handled by number of dimensions:

        * 1-D floating array: each value is compared against ``threshold``.
          Non-floating 1-D arrays are returned unchanged.
        * 2-D array: column 1 is thresholded in the binary case, otherwise
          the arg-max over axis 1 is taken.
        * More than 2 dimensions: arg-max over the last axis.
    train_y_values : array-like, optional
        Labels seen during training. When given, the sorted unique labels
        (``np.unique``) are used to map positions back to labels, so the
        output contains original labels instead of ``0..C-1`` indices.
    threshold : float, default=0.5
        Decision threshold for the binary case (``score >= threshold``
        selects the second / positive class).

    Returns
    -------
    np.ndarray
        Crisp predictions (original labels when ``train_y_values`` allows
        the mapping, integer indices otherwise).

    Raises
    ------
    ValueError
        If ``predictions`` is 0-dimensional.
    """
    predictions = np.asarray(predictions)
    dimensions = predictions.ndim

    # Reject scalars up front so no further work is done on invalid input.
    if dimensions == 0:
        raise ValueError(
            "Predictions array has an invalid number of dimensions. "
            f"Expected 1 or more dimensions, got {predictions.ndim}."
        )

    classes = None if train_y_values is None else np.unique(train_y_values)
    n_classes = 0 if classes is None else len(classes)

    def _threshold_to_labels(scores):
        # Binary decision: map to the two known labels when exactly two
        # classes are available, otherwise fall back to 0/1 integers.
        positive = scores >= threshold
        if n_classes == 2:
            return np.where(positive, classes[1], classes[0])
        return positive.astype(int)

    if dimensions == 1:
        # Already-crisp (non-float) 1-D predictions pass through untouched.
        if not np.issubdtype(predictions.dtype, np.floating):
            return predictions
        return _threshold_to_labels(predictions)

    if dimensions == 2:
        # Binary when two training classes are known, or when no usable
        # class information exists and the array has exactly two columns.
        is_binary = n_classes == 2 or (n_classes < 2 and predictions.shape[1] == 2)
        if is_binary:
            return _threshold_to_labels(predictions[:, 1])
        crisp_indices = np.argmax(predictions, axis=1)
        return classes[crisp_indices] if n_classes > 2 else crisp_indices

    # More than two dimensions: the last axis is assumed to hold the
    # per-class scores.
    crisp_indices = np.argmax(predictions, axis=-1)
    return crisp_indices if classes is None else classes[crisp_indices]
113139
114140
115- def validate_predictions (quantifier : Any , predictions : np .ndarray , threshold : float = 0.5 ) -> np .ndarray :
141+ def validate_predictions (quantifier : Any , predictions : np .ndarray , threshold : float = 0.5 , train_y_values = None ) -> np .ndarray :
116142 """
117143 Validate predictions using the quantifier's declared output tags.
118144 Raises InputValidationError if inconsistent with tags.
@@ -132,7 +158,7 @@ def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: fl
132158 f"Soft predictions for { quantifier .__class__ .__name__ } must be float, got dtype { predictions .dtype } ."
133159 )
134160 elif estimator_type == "crisp" and np .issubdtype (predictions .dtype , np .floating ):
135- predictions = _get_valid_crisp_predictions (predictions , threshold )
161+ predictions = _get_valid_crisp_predictions (predictions , train_y_values , threshold )
136162 return predictions
137163
138164
0 commit comments