Skip to content

Commit 96758ff

Browse files
committed
use concat before checking unique labels
1 parent a9943bd commit 96758ff

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

bigframes/ml/metrics/_metrics.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -331,24 +331,16 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
331331
def _precision_score_binary_pos_only(
332332
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
333333
) -> float:
334-
y_true_classes = y_true.unique(keep_order=False)
335-
y_pred_classes = y_pred.unique(keep_order=False)
334+
unique_labels = bpd.concat([y_true, y_pred]).unique(keep_order=False)
336335

337-
if y_true_classes.count() != 2 or y_pred_classes.count() != 2:
336+
if unique_labels.count() != 2:
338337
raise ValueError(
339338
"Target is multiclass but average='binary'. Please choose another average setting."
340339
)
341340

342-
total_labels = set(y_true_classes.to_list() + y_pred_classes.to_list())
343-
344-
if len(total_labels) != 2:
345-
raise ValueError(
346-
"Target is multiclass but average='binary'. Please choose another average setting."
347-
)
348-
349-
if pos_label not in total_labels:
341+
if pos_label not in unique_labels:
350342
raise ValueError(
351-
f"pos_labe={pos_label} is not a valid label. It should be one of {list(total_labels)}"
343+
f"pos_labe={pos_label} is not a valid label. It should be one of {unique_labels.to_list()}"
352344
)
353345

354346
target_elem_idx = y_pred == pos_label

0 commit comments

Comments
 (0)