File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -293,7 +293,7 @@ def precision_score(
293293 y_true_series , y_pred_series = utils .batch_convert_to_series (y_true , y_pred )
294294
295295 if average is None :
296- return _precision_score_per_class (y_true_series , y_pred_series )
296+ return _precision_score_per_label (y_true_series , y_pred_series )
297297
298298 if average == "binary" :
299299 return _precision_score_binary_pos_only (y_true_series , y_pred_series , pos_label )
@@ -308,7 +308,7 @@ def precision_score(
308308)
309309
310310
311- def _precision_score_per_class (y_true : bpd .Series , y_pred : bpd .Series ) -> pd .Series :
311+ def _precision_score_per_label (y_true : bpd .Series , y_pred : bpd .Series ) -> pd .Series :
312312 is_accurate = y_true == y_pred
313313 unique_labels = (
314314 bpd .concat ([y_true , y_pred ], join = "outer" )
@@ -338,7 +338,7 @@ def _precision_score_binary_pos_only(
338338 "Target is multiclass but average='binary'. Please choose another average setting."
339339 )
340340
341- if pos_label not in unique_labels :
341+ if not ( unique_labels == pos_label ). any () :
342342 raise ValueError (
343343 f"pos_labe={ pos_label } is not a valid label. It should be one of { unique_labels .to_list ()} "
344344 )
You can’t perform that action at this time.
0 commit comments