|
15 | 15 | from collections.abc import Iterable, Sequence |
16 | 16 | from functools import cache, partial |
17 | 17 | from types import ModuleType |
18 | | -from typing import Any |
| 18 | +from typing import Any, overload |
19 | 19 |
|
20 | 20 | import numpy as np |
21 | 21 | import torch |
|
51 | 51 | ] |
52 | 52 |
|
53 | 53 |
|
54 | | -def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: |
| 54 | +@overload |
| 55 | +def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: ... |
| 56 | + |
| 57 | + |
| 58 | +@overload |
| 59 | +def ignore_background(y_pred: NdarrayTensor, y: None = ...) -> tuple[NdarrayTensor, None]: ... |
| 60 | + |
| 61 | + |
| 62 | +def ignore_background( |
| 63 | + y_pred: NdarrayTensor, y: NdarrayTensor | None = None |
| 64 | +) -> tuple[NdarrayTensor, NdarrayTensor | None]: |
55 | 65 | """ |
56 | 66 | This function is used to remove background (the first channel) for `y_pred` and `y`. |
57 | 67 |
|
58 | 68 | Args: |
59 | 69 | y_pred: predictions. As for classification tasks, |
60 | 70 | `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, |
61 | 71 | the shape should be [BNHW] or [BNHWD]. |
62 | | - y: ground truth, the first dim is batch. |
| 72 | + y: ground truth, the first dim is batch. (Optional) |
| 73 | +
|
| 74 | + Returns: |
| 75 | + Tuple of background-removed predictions and ground truth. `y` is None if not provided. |
63 | 76 |
|
64 | 77 | """ |
65 | 78 |
|
66 | | - y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] |
| 79 | + if y is not None: |
| 80 | + y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] |
67 | 81 | y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment] |
| 82 | + |
68 | 83 | return y_pred, y |
69 | 84 |
|
70 | 85 |
|
|
0 commit comments