Feat: add ignore_index support to DiceLoss#8969
Conversation
Signed-off-by: qwepablo12 <grazava6@gmail.com>
📝 WalkthroughWalkthroughDiceLoss gains a new Changes
Estimated code review effort: 3 (Moderate) | ~25 minutes Sequence Diagram(s)sequenceDiagram
participant Caller
participant DiceLoss
Caller->>DiceLoss: forward(input, target)
alt ignore_index set
DiceLoss->>DiceLoss: build valid_mask from target
DiceLoss->>DiceLoss: apply mask to input and target
end
DiceLoss->>DiceLoss: to_onehot_y conversion
DiceLoss->>DiceLoss: include_background handling
DiceLoss->>DiceLoss: compute per-class dice loss
alt class_weight set
DiceLoss->>DiceLoss: apply class_weight to loss
end
DiceLoss->>DiceLoss: apply reduction
DiceLoss-->>Caller: return loss
Related Issues: Not specified in the provided information. Related PRs: Not specified in the provided information. Suggested labels: enhancement, losses Suggested reviewers: Not specified in the provided information. 🐰 A dice loss learned to look away, 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
monai/losses/dice.py (2)
98-102: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueDocument raised exceptions in
forward.
forwardraisesAssertionError(Line 143) andValueError(Line 186) but the trimmed docstring only hasArgs. Add aRaises:section.As per path instructions, docstrings should describe each raised exception in the appropriate Google-style section.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/losses/dice.py` around lines 98 - 102, The forward docstring in DiceLoss currently documents only Args, but the method can raise AssertionError and ValueError. Update the docstring in DiceLoss.forward to add a Google-style Raises section that names both exceptions and briefly states the conditions that trigger them, keeping the docs aligned with the existing forward validation logic.Source: Path instructions
72-75: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueConstructor docstring isn't Google-style.
Only the new
ignore_indexis documented; the rest defers to "standard MONAI DiceLoss". Per Google style each arg should be listed underArgs:.As per path instructions, docstrings should describe each variable in the appropriate Google-style section.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/losses/dice.py` around lines 72 - 75, The constructor docstring for DiceLoss is not in Google style and only mentions ignore_index, so update the docstring in the DiceLoss initializer to use an Args: section that explicitly documents every parameter instead of referring to “standard MONAI DiceLoss.” Make sure each argument used by the DiceLoss constructor is listed with a brief description, and include ignore_index in the same section alongside the existing parameters.Source: Path instructions
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/losses/dice.py`:
- Around line 127-148: The ignore_index handling in the Dice loss is masking an
entire class channel in the one-hot path instead of only the ignored pixels, so
the mask logic in the Dice class should be changed to produce a spatial valid
mask from the ignored channel rather than zeroing the full channel. Update the
masking in the Dice loss flow so `valid_mask` reflects ignored voxels/pixels per
location and then apply it consistently to both `input` and `target`, keeping
one-hot and index targets equivalent while preserving the existing
`include_background` behavior.
- Around line 116-148: Add unit tests for DiceLoss ignore_index handling in both
index-format B1HW targets and one-hot BNHW targets, including the
include_background=False path and out-of-range ignore_index values. Extend the
existing DiceLoss test coverage to verify ignored regions do not contribute to
the computed loss or backpropagated gradients, and reference the DiceLoss
forward/masking logic when asserting the expected behavior.
---
Nitpick comments:
In `@monai/losses/dice.py`:
- Around line 98-102: The forward docstring in DiceLoss currently documents only
Args, but the method can raise AssertionError and ValueError. Update the
docstring in DiceLoss.forward to add a Google-style Raises section that names
both exceptions and briefly states the conditions that trigger them, keeping the
docs aligned with the existing forward validation logic.
- Around line 72-75: The constructor docstring for DiceLoss is not in Google
style and only mentions ignore_index, so update the docstring in the DiceLoss
initializer to use an Args: section that explicitly documents every parameter
instead of referring to “standard MONAI DiceLoss.” Make sure each argument used
by the DiceLoss constructor is listed with a brief description, and include
ignore_index in the same section alongside the existing parameters.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ed022536-cc87-49a6-861c-d8f8312c39e6
📒 Files selected for processing (1)
monai/losses/dice.py
| # Create valid mask if ignore_index is specified and target is in index format | ||
| valid_mask = None | ||
| if self.ignore_index is not None and target.shape[1] == 1: | ||
| valid_mask = (target != self.ignore_index).to(input.dtype) | ||
|
|
||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) | ||
| else: | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
|
||
| # Create valid mask if target was already one-hot but ignore_index channel is specified | ||
| if self.ignore_index is not None and valid_mask is None: | ||
| if 0 <= self.ignore_index < target.shape[1]: | ||
| valid_mask = torch.ones_like(target) | ||
| valid_mask[:, self.ignore_index] = 0.0 | ||
|
|
||
| if not self.include_background: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) | ||
| else: | ||
| # if skipping background, removing first channel | ||
| target = target[:, 1:] | ||
| input = input[:, 1:] | ||
| if valid_mask is not None and valid_mask.shape[1] == n_pred_ch: | ||
| valid_mask = valid_mask[:, 1:] | ||
|
|
||
| if target.shape != input.shape: | ||
| raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") | ||
|
|
||
| # reducing only spatial dimensions (not batch nor channels) | ||
| # Apply mask to both predictions and targets to exclude ignored regions | ||
| if valid_mask is not None: | ||
| input = input * valid_mask | ||
| target = target * valid_mask |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find existing DiceLoss tests and any ignore_index coverage
fd -t f -i 'dice' tests test | head
rg -nP 'ignore_index' -g '*test*' -C2Repository: Project-MONAI/MONAI
Length of output: 616
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Map relevant files and search for ignore_index coverage precisely.
ast-grep outline monai/losses/dice.py --view expanded || true
printf '\n--- test_dice_loss outline ---\n'
ast-grep outline tests/losses/test_dice_loss.py --view expanded || true
printf '\n--- ignore_index matches in loss tests ---\n'
rg -n "ignore_index" tests/losses
printf '\n--- relevant slices in test_dice_loss.py ---\n'
sed -n '1,260p' tests/losses/test_dice_loss.pyRepository: Project-MONAI/MONAI
Length of output: 2353
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '--- monai/losses/dice.py (constructor + forward slice) ---\n'
sed -n '55,190p' monai/losses/dice.py
printf '\n--- tests/losses/test_dice_loss.py (full relevant file) ---\n'
sed -n '1,260p' tests/losses/test_dice_loss.pyRepository: Project-MONAI/MONAI
Length of output: 248
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '%s\n' '--- monai/losses/dice.py (constructor + forward slice) ---'
sed -n '55,190p' monai/losses/dice.py
printf '\n%s\n' '--- tests/losses/test_dice_loss.py (full relevant file) ---'
sed -n '1,260p' tests/losses/test_dice_loss.pyRepository: Project-MONAI/MONAI
Length of output: 15319
Add ignore_index unit tests. Cover B1HW and one-hot BNHW targets, include_background=False, and out-of-range values; assert ignored regions don’t affect loss or gradients.
🧰 Tools
🪛 Ruff (0.15.20)
[error] 138-138: Variable input is shadowing a Python builtin
(A001)
[error] 147-147: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/dice.py` around lines 116 - 148, Add unit tests for DiceLoss
ignore_index handling in both index-format B1HW targets and one-hot BNHW
targets, including the include_background=False path and out-of-range
ignore_index values. Extend the existing DiceLoss test coverage to verify
ignored regions do not contribute to the computed loss or backpropagated
gradients, and reference the DiceLoss forward/masking logic when asserting the
expected behavior.
Source: Path instructions
| # Create valid mask if target was already one-hot but ignore_index channel is specified | ||
| if self.ignore_index is not None and valid_mask is None: | ||
| if 0 <= self.ignore_index < target.shape[1]: | ||
| valid_mask = torch.ones_like(target) | ||
| valid_mask[:, self.ignore_index] = 0.0 | ||
|
|
||
| if not self.include_background: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) | ||
| else: | ||
| # if skipping background, removing first channel | ||
| target = target[:, 1:] | ||
| input = input[:, 1:] | ||
| if valid_mask is not None and valid_mask.shape[1] == n_pred_ch: | ||
| valid_mask = valid_mask[:, 1:] | ||
|
|
||
| if target.shape != input.shape: | ||
| raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") | ||
|
|
||
| # reducing only spatial dimensions (not batch nor channels) | ||
| # Apply mask to both predictions and targets to exclude ignored regions | ||
| if valid_mask is not None: | ||
| input = input * valid_mask | ||
| target = target * valid_mask |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '%s\n' '--- monai/losses/dice.py outline ---'
ast-grep outline monai/losses/dice.py --view expanded || true
printf '\n%s\n' '--- relevant ignore_index / valid_mask occurrences ---'
rg -n "ignore_index|valid_mask|include_background|one-hot|one hot" monai/losses/dice.py monai -g '!**/*.pyc' || true
printf '\n%s\n' '--- dice.py around the reviewed block ---'
sed -n '90,180p' monai/losses/dice.py
printf '\n%s\n' '--- nearby tests for DiceLoss / ignore_index ---'
rg -n "DiceLoss|ignore_index|include_background" tests monai -g '*test*.py' || trueRepository: Project-MONAI/MONAI
Length of output: 50375
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '%s\n' '--- ignore_index docs and related tests in monai/losses/dice.py ---'
sed -n '60,155p' monai/losses/dice.py
printf '\n%s\n' '--- tests mentioning DiceLoss ignore_index ---'
rg -n "DiceLoss|ignore_index|to_onehot_y" tests -g '*test*.py' | head -n 120
printf '\n%s\n' '--- focused search for one-hot ignore_index behavior ---'
rg -n "one-hot|one_hot|ignore_index" tests monai/losses -g '*test*.py' | head -n 160Repository: Project-MONAI/MONAI
Length of output: 26324
Mask ignored pixels, not the whole class channel. In the one-hot path this zeros ignore_index everywhere, so ignored pixels still affect the other channels as false positives. Use a spatial mask from that channel so one-hot and index targets behave the same.
🧰 Tools
🪛 Ruff (0.15.20)
[error] 138-138: Variable input is shadowing a Python builtin
(A001)
[error] 147-147: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/dice.py` around lines 127 - 148, The ignore_index handling in
the Dice loss is masking an entire class channel in the one-hot path instead of
only the ignored pixels, so the mask logic in the Dice class should be changed
to produce a spatial valid mask from the ignored channel rather than zeroing the
full channel. Update the masking in the Dice loss flow so `valid_mask` reflects
ignored voxels/pixels per location and then apply it consistently to both
`input` and `target`, keeping one-hot and index targets equivalent while
preserving the existing `include_background` behavior.
Description
Addresses #8734.
This PR introduces
ignore_indexsupport forDiceLoss. Whenignore_indexis specified, the target regions matching this index are masked out and completely excluded from the Dice coefficient computation (both numerator and denominator). This ensures that ignored regions/classes do not affect the loss values or gradients during training.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.