diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index 05b0baad..34a31f01 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -39,6 +39,11 @@ - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall. - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification). sample_weight (`list` of `float`): Sample weights Defaults to None. + zero_division (`int` or `"warn"`, optional): Passed directly to sklearn's `f1_score`. Controls behavior when a label has no predicted or true samples. Use `0`, `1`, or `"warn"` (default sklearn behavior). + + - 0: Returns 0 when there is a zero division. + - 1: Returns 1 when there is a zero division. + - `'warn'`: Raises a warning and then returns 0 when there is a zero division. Returns: f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better. @@ -84,6 +89,13 @@ >>> results = f1_metric.compute(predictions=[[0, 1, 1], [1, 1, 0]], references=[[0, 1, 1], [0, 1, 0]], average="macro") >>> print(round(results['f1'], 2)) 0.67 + + Example 6-The same multiclass example as in Example 4, but with `zero_division` set to `1` for labels with no predicted or true samples. + >>> predictions = [0, 0, 0, 0, 0] + >>> references = [0, 1, 0, 1, 2] + >>> results = f1_metric.compute(predictions=predictions, references=references, average=None, labels=[0, 1, 2, 3], zero_division=1) + >>> print([round(res, 2) for res in results['f1']]) + [0.57, 0.0, 0.0, 1.0] """ @@ -123,8 +135,17 @@ def _info(self): reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"], ) - def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None): + def _compute( + self, + predictions, + references, + labels=None, + **kwargs, + ): score = f1_score( - references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight + references, + predictions, + labels=labels, + **kwargs, ) return {"f1": score if getattr(score, "size", 1) > 1 else float(score)}