Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 24 additions & 63 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,11 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
``other_act = torch.tanh``.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

smooth_nr: a small constant added to the numerator to avoid zero.
smooth_dr: a small constant added to the denominator to avoid nan.
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, a Dice loss value is computed independently from each item in the batch
before any `reduction`.
weight: weights to apply to the voxels of each class. If None no weights are applied.
The input can be a single value (same weight for all classes), a sequence of values (the length
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.

Args follow standard MONAI DiceLoss with the addition of:
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient.
"""
super().__init__(reduction=LossReduction(reduction).value)
if other_act is not None and not callable(other_act):
Expand All @@ -126,29 +92,13 @@ def __init__(
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.soft_label = soft_label
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
AssertionError: When input and target (after one hot transform if set)
have different shapes.
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

Example:
>>> from monai.losses.dice import * # NOQA
>>> import torch
>>> from monai.losses.dice import DiceLoss
>>> B, C, H, W = 7, 5, 3, 2
>>> input = torch.rand(B, C, H, W)
>>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
>>> target = one_hot(target_idx[:, None, ...], num_classes=C)
>>> self = DiceLoss(reduction='none')
>>> loss = self(input, target)
>>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
"""
if self.sigmoid:
input = torch.sigmoid(input)
Expand All @@ -163,27 +113,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

# Create valid mask if ignore_index is specified and target is in index format
valid_mask = None
if self.ignore_index is not None and target.shape[1] == 1:
valid_mask = (target != self.ignore_index).to(input.dtype)

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

# Create valid mask if target was already one-hot but ignore_index channel is specified
if self.ignore_index is not None and valid_mask is None:
if 0 <= self.ignore_index < target.shape[1]:
valid_mask = torch.ones_like(target)
valid_mask[:, self.ignore_index] = 0.0

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
# if skipping background, removing first channel
target = target[:, 1:]
input = input[:, 1:]
if valid_mask is not None and valid_mask.shape[1] == n_pred_ch:
valid_mask = valid_mask[:, 1:]

if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

# reducing only spatial dimensions (not batch nor channels)
# Apply mask to both predictions and targets to exclude ignored regions
if valid_mask is not None:
input = input * valid_mask
target = target * valid_mask
Comment on lines +116 to +148

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find existing DiceLoss tests and any ignore_index coverage
fd -t f -i 'dice' tests test | head
rg -nP 'ignore_index' -g '*test*' -C2

Repository: Project-MONAI/MONAI

Length of output: 616


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Map relevant files and search for ignore_index coverage precisely.
ast-grep outline monai/losses/dice.py --view expanded || true
printf '\n--- test_dice_loss outline ---\n'
ast-grep outline tests/losses/test_dice_loss.py --view expanded || true
printf '\n--- ignore_index matches in loss tests ---\n'
rg -n "ignore_index" tests/losses
printf '\n--- relevant slices in test_dice_loss.py ---\n'
sed -n '1,260p' tests/losses/test_dice_loss.py

Repository: Project-MONAI/MONAI

Length of output: 2353


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '--- monai/losses/dice.py (constructor + forward slice) ---\n'
sed -n '55,190p' monai/losses/dice.py

printf '\n--- tests/losses/test_dice_loss.py (full relevant file) ---\n'
sed -n '1,260p' tests/losses/test_dice_loss.py

Repository: Project-MONAI/MONAI

Length of output: 248


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '%s\n' '--- monai/losses/dice.py (constructor + forward slice) ---'
sed -n '55,190p' monai/losses/dice.py

printf '\n%s\n' '--- tests/losses/test_dice_loss.py (full relevant file) ---'
sed -n '1,260p' tests/losses/test_dice_loss.py

Repository: Project-MONAI/MONAI

Length of output: 15319


Add ignore_index unit tests. Cover B1HW and one-hot BNHW targets, include_background=False, and out-of-range values; assert ignored regions don’t affect loss or gradients.

🧰 Tools
🪛 Ruff (0.15.20)

[error] 138-138: Variable input is shadowing a Python builtin

(A001)


[error] 147-147: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/dice.py` around lines 116 - 148, Add unit tests for DiceLoss
ignore_index handling in both index-format B1HW targets and one-hot BNHW
targets, including the include_background=False path and out-of-range
ignore_index values. Extend the existing DiceLoss test coverage to verify
ignored regions do not contribute to the computed loss or backpropagated
gradients, and reference the DiceLoss forward/masking logic when asserting the
expected behavior.

Source: Path instructions

Comment on lines +127 to +148

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '%s\n' '--- monai/losses/dice.py outline ---'
ast-grep outline monai/losses/dice.py --view expanded || true

printf '\n%s\n' '--- relevant ignore_index / valid_mask occurrences ---'
rg -n "ignore_index|valid_mask|include_background|one-hot|one hot" monai/losses/dice.py monai -g '!**/*.pyc' || true

printf '\n%s\n' '--- dice.py around the reviewed block ---'
sed -n '90,180p' monai/losses/dice.py

printf '\n%s\n' '--- nearby tests for DiceLoss / ignore_index ---'
rg -n "DiceLoss|ignore_index|include_background" tests monai -g '*test*.py' || true

Repository: Project-MONAI/MONAI

Length of output: 50375


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '%s\n' '--- ignore_index docs and related tests in monai/losses/dice.py ---'
sed -n '60,155p' monai/losses/dice.py

printf '\n%s\n' '--- tests mentioning DiceLoss ignore_index ---'
rg -n "DiceLoss|ignore_index|to_onehot_y" tests -g '*test*.py' | head -n 120

printf '\n%s\n' '--- focused search for one-hot ignore_index behavior ---'
rg -n "one-hot|one_hot|ignore_index" tests monai/losses -g '*test*.py' | head -n 160

Repository: Project-MONAI/MONAI

Length of output: 26324


Mask ignored pixels, not the whole class channel. In the one-hot path this zeros ignore_index everywhere, so ignored pixels still affect the other channels as false positives. Use a spatial mask from that channel so one-hot and index targets behave the same.

🧰 Tools
🪛 Ruff (0.15.20)

[error] 138-138: Variable input is shadowing a Python builtin

(A001)


[error] 147-147: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/dice.py` around lines 127 - 148, The ignore_index handling in
the Dice loss is masking an entire class channel in the one-hot path instead of
only the ignored pixels, so the mask logic in the Dice class should be changed
to produce a spatial valid mask from the ignored channel rather than zeroing the
full channel. Update the masking in the Dice loss flow so `valid_mask` reflects
ignored voxels/pixels per location and then apply it consistently to both
`input` and `target`, keeping one-hot and index targets equivalent while
preserving the existing `include_background` behavior.


reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

ord = 2 if self.squared_pred else 1
Expand All @@ -198,7 +163,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
# make sure the lengths of weights are equal to the number of classes
if self.class_weight.ndim == 0:
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
else:
Expand All @@ -209,16 +173,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
)
if self.class_weight.min() < 0:
raise ValueError("the value/values of the `weight` should be no less than 0.")
# apply class_weight to loss
f = f * self.class_weight.to(f)

if self.reduction == LossReduction.MEAN.value:
f = torch.mean(f) # the batch and channel average
f = torch.mean(f)
elif self.reduction == LossReduction.SUM.value:
f = torch.sum(f) # sum over the batch and channel dims
f = torch.sum(f)
elif self.reduction == LossReduction.NONE.value:
# If we are not computing voxelwise loss components at least
# make sure a none reduction maintains a broadcastable shape
broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
f = f.view(broadcast_shape)
else:
Expand Down
Loading