Skip to content

Commit 06392d2

Browse files
committed
use unique(keep_order=False) to count unique items
1 parent 95a005c commit 06392d2

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

bigframes/ml/metrics/_metrics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,17 @@ def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
331331
def _precision_score_binary_pos_only(
332332
y_true: bpd.Series, y_pred: bpd.Series, pos_label: int | float | bool | str
333333
) -> float:
334-
if y_true.drop_duplicates().count() != 2 or y_pred.drop_duplicates().count() != 2:
334+
if (
335+
y_true.unique(keep_order=False).count() != 2
336+
or y_pred.unique(keep_order=False).count() != 2
337+
):
335338
raise ValueError(
336339
"Target is multiclass but average='binary'. Please choose another average setting."
337340
)
338341

339342
total_labels = set(
340-
y_true.drop_duplicates().to_list() + y_pred.drop_duplicates().to_list()
343+
y_true.unique(keep_order=False).to_list()
344+
+ y_pred.unique(keep_order=False).to_list()
341345
)
342346

343347
if len(total_labels) != 2:

0 commit comments

Comments
 (0)