Skip to content

Commit 3954f5c

Browse files
committed
Make 'y' optional in ignore_background
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 57fdd59 commit 3954f5c

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

monai/metrics/active_learning_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def compute_variance(
130130

131131
if not include_background:
132132
y = y_pred
133-
# TODO If this utils is made to be optional for 'y' it would be nice
134133
y_pred, y = ignore_background(y_pred=y_pred, y=y)
135134

136135
# Set any values below 0 to threshold

monai/metrics/utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from collections.abc import Iterable, Sequence
1616
from functools import cache, partial
1717
from types import ModuleType
18-
from typing import Any
18+
from typing import Any, overload
1919

2020
import numpy as np
2121
import torch
@@ -51,20 +51,35 @@
5151
]
5252

5353

54-
def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]:
54+
@overload
55+
def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: ...
56+
57+
58+
@overload
59+
def ignore_background(y_pred: NdarrayTensor, y: None = ...) -> tuple[NdarrayTensor, None]: ...
60+
61+
62+
def ignore_background(
63+
y_pred: NdarrayTensor, y: NdarrayTensor | None = None
64+
) -> tuple[NdarrayTensor, NdarrayTensor | None]:
5565
"""
5666
This function is used to remove background (the first channel) for `y_pred` and `y`.
5767
5868
Args:
5969
y_pred: predictions. As for classification tasks,
6070
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
6171
the shape should be [BNHW] or [BNHWD].
62-
y: ground truth, the first dim is batch.
72+
y: ground truth, the first dim is batch. (Optional)
73+
74+
Returns:
75+
Tuple of background-removed predictions and ground truth. `y` is None if not provided.
6376
6477
"""
6578

66-
y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment]
79+
if y is not None:
80+
y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment]
6781
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment]
82+
6883
return y_pred, y
6984

7085

0 commit comments

Comments
 (0)