-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Feat: add ignore_index support to DiceLoss #8969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,45 +67,11 @@ def __init__( | |
| batch: bool = False, | ||
| weight: Sequence[float] | float | int | torch.Tensor | None = None, | ||
| soft_label: bool = False, | ||
| ignore_index: int | None = None, | ||
| ) -> None: | ||
| """ | ||
| Args: | ||
| include_background: if False, channel index 0 (background category) is excluded from the calculation. | ||
| if the non-background segmentations are small compared to the total image size they can get overwhelmed | ||
| by the signal from the background so excluding it in such cases helps convergence. | ||
| to_onehot_y: whether to convert the ``target`` into the one-hot format, | ||
| using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. | ||
| sigmoid: if True, apply a sigmoid function to the prediction. | ||
| softmax: if True, apply a softmax function to the prediction. | ||
| other_act: callable function to execute other activation layers, Defaults to ``None``. for example: | ||
| ``other_act = torch.tanh``. | ||
| squared_pred: use squared versions of targets and predictions in the denominator or not. | ||
| jaccard: compute Jaccard Index (soft IoU) instead of dice or not. | ||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||
| Specifies the reduction to apply to the output. Defaults to ``"mean"``. | ||
|
|
||
| - ``"none"``: no reduction will be applied. | ||
| - ``"mean"``: the sum of the output will be divided by the number of elements in the output. | ||
| - ``"sum"``: the output will be summed. | ||
|
|
||
| smooth_nr: a small constant added to the numerator to avoid zero. | ||
| smooth_dr: a small constant added to the denominator to avoid nan. | ||
| batch: whether to sum the intersection and union areas over the batch dimension before the dividing. | ||
| Defaults to False, a Dice loss value is computed independently from each item in the batch | ||
| before any `reduction`. | ||
| weight: weights to apply to the voxels of each class. If None no weights are applied. | ||
| The input can be a single value (same weight for all classes), a sequence of values (the length | ||
| of the sequence should be the same as the number of classes. If not ``include_background``, | ||
| the number of classes should not include the background category class 0). | ||
| The value/values should be no less than 0. Defaults to None. | ||
| soft_label: whether the target contains non-binary values (soft labels) or not. | ||
| If True a soft label formulation of the loss will be used. | ||
|
|
||
| Raises: | ||
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | ||
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | ||
| Incompatible values. | ||
|
|
||
| Args follow standard MONAI DiceLoss with the addition of: | ||
| ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| if other_act is not None and not callable(other_act): | ||
|
|
@@ -126,29 +92,13 @@ def __init__( | |
| self.register_buffer("class_weight", weight) | ||
| self.class_weight: None | torch.Tensor | ||
| self.soft_label = soft_label | ||
| self.ignore_index = ignore_index | ||
|
|
||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| input: the shape should be BNH[WD], where N is the number of classes. | ||
| target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. | ||
|
|
||
| Raises: | ||
| AssertionError: When input and target (after one hot transform if set) | ||
| have different shapes. | ||
| ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. | ||
|
|
||
| Example: | ||
| >>> from monai.losses.dice import * # NOQA | ||
| >>> import torch | ||
| >>> from monai.losses.dice import DiceLoss | ||
| >>> B, C, H, W = 7, 5, 3, 2 | ||
| >>> input = torch.rand(B, C, H, W) | ||
| >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() | ||
| >>> target = one_hot(target_idx[:, None, ...], num_classes=C) | ||
| >>> self = DiceLoss(reduction='none') | ||
| >>> loss = self(input, target) | ||
| >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape | ||
| """ | ||
| if self.sigmoid: | ||
| input = torch.sigmoid(input) | ||
|
|
@@ -163,27 +113,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| if self.other_act is not None: | ||
| input = self.other_act(input) | ||
|
|
||
| # Create valid mask if ignore_index is specified and target is in index format | ||
| valid_mask = None | ||
| if self.ignore_index is not None and target.shape[1] == 1: | ||
| valid_mask = (target != self.ignore_index).to(input.dtype) | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
|
||
| # Create valid mask if target was already one-hot but ignore_index channel is specified | ||
| if self.ignore_index is not None and valid_mask is None: | ||
| if 0 <= self.ignore_index < target.shape[1]: | ||
| valid_mask = torch.ones_like(target) | ||
| valid_mask[:, self.ignore_index] = 0.0 | ||
|
|
||
| if not self.include_background: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) | ||
| else: | ||
| # if skipping background, removing first channel | ||
| target = target[:, 1:] | ||
| input = input[:, 1:] | ||
| if valid_mask is not None and valid_mask.shape[1] == n_pred_ch: | ||
| valid_mask = valid_mask[:, 1:] | ||
|
|
||
| if target.shape != input.shape: | ||
| raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") | ||
|
|
||
| # reducing only spatial dimensions (not batch nor channels) | ||
| # Apply mask to both predictions and targets to exclude ignored regions | ||
| if valid_mask is not None: | ||
| input = input * valid_mask | ||
| target = target * valid_mask | ||
|
Comment on lines
+127
to
+148
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win 🧩 Analysis chain🏁 Script executed: #!/bin/bash
set -euo pipefail
printf '%s\n' '--- monai/losses/dice.py outline ---'
ast-grep outline monai/losses/dice.py --view expanded || true
printf '\n%s\n' '--- relevant ignore_index / valid_mask occurrences ---'
rg -n "ignore_index|valid_mask|include_background|one-hot|one hot" monai/losses/dice.py monai -g '!**/*.pyc' || true
printf '\n%s\n' '--- dice.py around the reviewed block ---'
sed -n '90,180p' monai/losses/dice.py
printf '\n%s\n' '--- nearby tests for DiceLoss / ignore_index ---'
rg -n "DiceLoss|ignore_index|include_background" tests monai -g '*test*.py' || trueRepository: Project-MONAI/MONAI Length of output: 50375 🏁 Script executed: #!/bin/bash
set -euo pipefail
printf '%s\n' '--- ignore_index docs and related tests in monai/losses/dice.py ---'
sed -n '60,155p' monai/losses/dice.py
printf '\n%s\n' '--- tests mentioning DiceLoss ignore_index ---'
rg -n "DiceLoss|ignore_index|to_onehot_y" tests -g '*test*.py' | head -n 120
printf '\n%s\n' '--- focused search for one-hot ignore_index behavior ---'
rg -n "one-hot|one_hot|ignore_index" tests monai/losses -g '*test*.py' | head -n 160Repository: Project-MONAI/MONAI Length of output: 26324 Mask ignored pixels, not the whole class channel. In the one-hot path this zeros 🧰 Tools🪛 Ruff (0.15.20)[error] 138-138: Variable (A001) [error] 147-147: Variable (A001) 🤖 Prompt for AI Agents |
||
|
|
||
| reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() | ||
| if self.batch: | ||
| # reducing spatial dimensions and batch | ||
| reduce_axis = [0] + reduce_axis | ||
|
|
||
| ord = 2 if self.squared_pred else 1 | ||
|
|
@@ -198,7 +163,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
|
|
||
| num_of_classes = target.shape[1] | ||
| if self.class_weight is not None and num_of_classes != 1: | ||
| # make sure the lengths of weights are equal to the number of classes | ||
| if self.class_weight.ndim == 0: | ||
| self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) | ||
| else: | ||
|
|
@@ -209,16 +173,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| ) | ||
| if self.class_weight.min() < 0: | ||
| raise ValueError("the value/values of the `weight` should be no less than 0.") | ||
| # apply class_weight to loss | ||
| f = f * self.class_weight.to(f) | ||
|
|
||
| if self.reduction == LossReduction.MEAN.value: | ||
| f = torch.mean(f) # the batch and channel average | ||
| f = torch.mean(f) | ||
| elif self.reduction == LossReduction.SUM.value: | ||
| f = torch.sum(f) # sum over the batch and channel dims | ||
| f = torch.sum(f) | ||
| elif self.reduction == LossReduction.NONE.value: | ||
| # If we are not computing voxelwise loss components at least | ||
| # make sure a none reduction maintains a broadcastable shape | ||
| broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) | ||
| f = f.view(broadcast_shape) | ||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
Repository: Project-MONAI/MONAI
Length of output: 616
🏁 Script executed:
Repository: Project-MONAI/MONAI
Length of output: 2353
🏁 Script executed:
Repository: Project-MONAI/MONAI
Length of output: 248
🏁 Script executed:
Repository: Project-MONAI/MONAI
Length of output: 15319
Add
ignore_indexunit tests. CoverB1HWand one-hotBNHWtargets,include_background=False, and out-of-range values; assert ignored regions don’t affect loss or gradients.🧰 Tools
🪛 Ruff (0.15.20)
[error] 138-138: Variable
inputis shadowing a Python builtin(A001)
[error] 147-147: Variable
inputis shadowing a Python builtin(A001)
🤖 Prompt for AI Agents
Source: Path instructions