Skip to content

Commit e1c032b

Browse files
committed
fix test
1 parent 96758ff commit e1c032b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bigframes/ml/metrics/_metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def precision_score(
293293
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
294294

295295
if average is None:
296-
return _precision_score_per_class(y_true_series, y_pred_series)
296+
return _precision_score_per_label(y_true_series, y_pred_series)
297297

298298
if average == "binary":
299299
return _precision_score_binary_pos_only(y_true_series, y_pred_series, pos_label)
@@ -308,7 +308,7 @@ def precision_score(
308308
)
309309

310310

311-
def _precision_score_per_class(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
311+
def _precision_score_per_label(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Series:
312312
is_accurate = y_true == y_pred
313313
unique_labels = (
314314
bpd.concat([y_true, y_pred], join="outer")
@@ -338,7 +338,7 @@ def _precision_score_binary_pos_only(
338338
"Target is multiclass but average='binary'. Please choose another average setting."
339339
)
340340

341-
if pos_label not in unique_labels:
341+
if not (unique_labels == pos_label).any():
342342
raise ValueError(
343343
f"pos_labe={pos_label} is not a valid label. It should be one of {unique_labels.to_list()}"
344344
)

0 commit comments

Comments
 (0)