Skip to content

Commit 635b467

Browse files
fix CC with categorical labels
1 parent 9a406f4 commit 635b467

3 files changed

Lines changed: 38 additions & 10 deletions

File tree

mlquantify/adjust_counting/_adjustment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def _adjust(self, predictions, train_y_scores, train_y_values):
105105
thresholds, tprs, fprs = evaluate_thresholds(train_y_values, positive_scores)
106106
threshold, tpr, fpr = self.get_best_threshold(thresholds, tprs, fprs)
107107

108-
cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)[1]
108+
cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)
109+
cc_predictions = list(cc_predictions.values())[1]
109110

110111
if tpr - fpr == 0:
111112
prevalence = cc_predictions
@@ -609,7 +610,7 @@ def _adjust(self, predictions, train_y_scores, train_y_values):
609610
prevs = []
610611
for thr, tpr, fpr in zip(thresholds, tprs, fprs):
611612
cc_predictions = CC(threshold=thr).aggregate(predictions, train_y_values)
612-
cc_predictions = cc_predictions[1]
613+
cc_predictions = list(cc_predictions.values())[1]
613614

614615
if tpr - fpr == 0:
615616
prevalence = cc_predictions

mlquantify/adjust_counting/_counting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,15 @@ def __init__(self, learner=None, threshold=0.5):
7676
self.threshold = threshold
7777

7878
def aggregate(self, predictions, train_y_values=None):
    """Count crisp predictions per class and return their relative frequencies.

    Parameters
    ----------
    predictions : array-like
        Raw model outputs; converted to crisp labels (using
        ``self.threshold`` and, when given, the training label set).
    train_y_values : array-like, optional
        Training labels whose unique values define the class set. When
        None, the class set is taken from the predictions themselves.

    Returns
    -------
    Validated prevalence estimates, one per class in ``self.classes_``.
    """
    predictions = validate_predictions(self, predictions, self.threshold, train_y_values)

    # Without training labels, fall back to the label set observed in the
    # crisp predictions.
    labels = train_y_values if train_y_values is not None else np.unique(predictions)
    self.classes_ = check_classes_attribute(self, np.unique(labels))

    counts = np.array([np.count_nonzero(predictions == cls) for cls in self.classes_])
    prevalences = counts / len(predictions)

    return validate_prevalences(self, prevalences, self.classes_)
8990

mlquantify/utils/_validation.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,49 @@ def validate_y(quantifier: Any, y: np.ndarray) -> None:
9696

9797
def _get_valid_crisp_predictions(predictions, threshold=0.5):
9898
predictions = np.asarray(predictions)
99-
10099
dimensions = predictions.ndim
101100

101+
if train_y_values is not None:
102+
classes = np.unique(train_y_values)
103+
else:
104+
classes = None
105+
102106
if dimensions > 2:
103-
predictions = np.argmax(predictions, axis=1)
107+
# Assuming the last dimension contains class probabilities
108+
crisp_indices = np.argmax(predictions, axis=-1)
109+
if classes is not None:
110+
predictions = classes[crisp_indices]
111+
else:
112+
predictions = crisp_indices
104113
elif dimensions == 2:
105-
predictions = (predictions[:, 1] >= threshold).astype(int)
114+
# Binary or multi-class probabilities (N, C)
115+
if classes is not None and len(classes) == 2:
116+
# Binary case with explicit classes
117+
predictions = np.where(predictions[:, 1] >= threshold, classes[1], classes[0])
118+
elif classes is not None and len(classes) > 2:
119+
# Multi-class case with explicit classes
120+
crisp_indices = np.argmax(predictions, axis=1)
121+
predictions = classes[crisp_indices]
122+
else:
123+
# Default binary (0 or 1) or multi-class (0 to C-1)
124+
if predictions.shape[1] == 2:
125+
predictions = (predictions[:, 1] >= threshold).astype(int)
126+
else:
127+
predictions = np.argmax(predictions, axis=1)
106128
elif dimensions == 1:
129+
# 1D probabilities (e.g., probability of positive class)
107130
if np.issubdtype(predictions.dtype, np.floating):
108-
predictions = (predictions >= threshold).astype(int)
131+
if classes is not None and len(classes) == 2:
132+
predictions = np.where(predictions >= threshold, classes[1], classes[0])
133+
else:
134+
predictions = (predictions >= threshold).astype(int)
109135
else:
110136
raise ValueError(f"Predictions array has an invalid number of dimensions. Expected 1 or more dimensions, got {predictions.ndim}.")
111137

112138
return predictions
113139

114140

115-
def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
141+
def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5, train_y_values=None) -> np.ndarray:
116142
"""
117143
Validate predictions using the quantifier's declared output tags.
118144
Raises InputValidationError if inconsistent with tags.
@@ -132,7 +158,7 @@ def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: fl
132158
f"Soft predictions for {quantifier.__class__.__name__} must be float, got dtype {predictions.dtype}."
133159
)
134160
elif estimator_type == "crisp" and np.issubdtype(predictions.dtype, np.floating):
135-
predictions = _get_valid_crisp_predictions(predictions, threshold)
161+
predictions = _get_valid_crisp_predictions(predictions, train_y_values, threshold)
136162
return predictions
137163

138164

0 commit comments

Comments
 (0)