Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8e5fe51
feat: implement ignore_index support in metrics and losses with dedic…
Rusheel86 Feb 28, 2026
a1b0a4f
test: run compute tests, format and lint loss/metric updates
Rusheel86 Mar 1, 2026
f2caaf8
feat: implement ignore_index support for losses and metrics
Rusheel86 Mar 9, 2026
d075009
chore: trigger CI rerun
Rusheel86 Mar 9, 2026
941a73b
chore: trigger CI rerun
Rusheel86 Mar 9, 2026
0f6e05a
DCO Remediation Commit for Rusheel Sharma <rusheelhere@gmail.com>
Rusheel86 Mar 9, 2026
a1f6ef4
fix: revert GWDL reduction handling and apply black formatting
Rusheel86 Mar 9, 2026
f01cbc4
fix: resolve shape issues and CI fails
Rusheel86 Mar 10, 2026
af83422
style: reformat with black 25.11.0
Rusheel86 Mar 10, 2026
187da14
fix: resolve mypy type error in utils.py
Rusheel86 Mar 11, 2026
780b567
fix: complete ignore_index implementation with proper one-hot masking
Rusheel86 Mar 11, 2026
91bb2e5
fix: resolve mypy union-attr error in unified_focal_loss
Rusheel86 Mar 12, 2026
57c2f78
chore: trigger CI with fresh runner
Rusheel86 Mar 12, 2026
9863a93
chore: retrigger CI (previous runs had disk space errors)
Rusheel86 Mar 12, 2026
170f34a
fix: address CodeRabbit minor issues
Rusheel86 Mar 12, 2026
c80eeeb
fix: address CodeRabbit critical and major issues
Rusheel86 Mar 12, 2026
1114907
fix: resolve all mypy and CodeRabbit issues
Rusheel86 Mar 13, 2026
3bd76e7
fix: CodeRabbit issues
Rusheel86 Mar 13, 2026
610756d
refactor: centralize ignore_index masking into create_ignore_mask helper
Rusheel86 Mar 19, 2026
754407b
Fix docstring indentation in create_ignore_mask
Rusheel86 Mar 19, 2026
b6d2362
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
edc9e4e
Fix NoneType error in AsymmetricFocalTverskyLoss for None ignore_index
Rusheel86 Mar 19, 2026
4457e37
Merge branch 'feat-ignore-index-support' of https://github.com/Rushee…
Rusheel86 Mar 19, 2026
c2612ea
style: fix import sorting with isort
Rusheel86 Mar 19, 2026
cfc54ec
fix: CI errors
Rusheel86 Mar 19, 2026
eeda3c7
chore: format and lint code
Rusheel86 Mar 19, 2026
64421e1
fix: lint and format
Rusheel86 Mar 19, 2026
df0833b
chore: trigger CI
Rusheel86 Mar 19, 2026
9cb6592
style: format with black --skip-magic-trailing-comma for Python 3.9 c…
Rusheel86 Mar 20, 2026
03e5e9b
fix: add type ignore comment for mypy no-any-return in utils.py
Rusheel86 Mar 20, 2026
5fb4d4f
style: reformat utils.py after adding type ignore comment
Rusheel86 Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: class index to ignore from the loss computation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -122,6 +124,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.ignore_index = ignore_index
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand Down Expand Up @@ -163,6 +166,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

mask: torch.Tensor | None = None
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
Expand All @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if mask is not None:
input = input * mask
target = target * mask

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
Expand Down
8 changes: 8 additions & 0 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -99,6 +100,7 @@ def __init__(

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -124,6 +126,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -161,6 +164,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if self.ignore_index is not None:
mask = (target != self.ignore_index).float()
input = input * mask
target = target * mask

Comment thread
Rusheel86 marked this conversation as resolved.
Outdated
loss: torch.Tensor | None = None
input = input.float()
target = target.float()
Expand Down
16 changes: 16 additions & 0 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -77,6 +78,7 @@ def __init__(
before any `reduction`.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: index of the class to ignore during calculation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -101,6 +103,7 @@ def __init__(
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
original_target = target
target = one_hot(target, num_classes=n_pred_ch)
Comment thread
Rusheel86 marked this conversation as resolved.

if self.ignore_index is not None:
mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target

if mask_src.shape[1] == 1:
mask = (mask_src != self.ignore_index).to(input.dtype)
else:
# Fallback for cases where target is already one-hot
mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype)

input = input * mask
target = target * mask

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
Expand Down
125 changes: 102 additions & 23 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,48 +39,76 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Compute the asymmetric focal Tversky loss for a two-channel prediction.

    This implementation hard-codes the binary case: channel 0 is treated as
    background and channel 1 as foreground (see ``dice_class[:, 0]`` /
    ``dice_class[:, 1]`` below).

    Args:
        y_pred: predicted probabilities, shape ``[B, C, spatial...]``.
        y_true: ground truth. Integer labels ``[B, 1, spatial...]`` when
            ``self.to_onehot_y`` is True; otherwise expected to already match
            ``y_pred``'s shape (one-hot).

    Returns:
        A scalar tensor for MEAN/SUM reduction, otherwise a ``[B, 2]``
        per-class loss tensor.

    Raises:
        ValueError: when ``y_true`` and ``y_pred`` shapes differ after the
            optional one-hot conversion.
    """
    n_pred_ch = y_pred.shape[1]

    # Keep the original (pre-one-hot) labels so the ignore mask can be built
    # from the raw label values even after y_true is converted below.
    original_y_true = y_true if self.ignore_index is not None else None

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            if self.ignore_index is not None:
                # Replace ignore_index with a valid class (0) so one_hot does not
                # fail on out-of-range labels; those voxels are zeroed via the mask.
                y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_true.shape != y_pred.shape:
        raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

    # Build the ignore mask after the one_hot conversion so it can be expanded
    # over the channel dimension. When ignore_index is None the mask is all-ones.
    mask = torch.ones_like(y_true)
    if self.ignore_index is not None:
        if original_y_true is not None and self.to_onehot_y:
            # Use the original labels to build the spatial mask.
            spatial_mask = (original_y_true != self.ignore_index).float()
        elif self.ignore_index < y_true.shape[1]:
            # Target already one-hot: mask where the ignored class channel is set.
            spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
        else:
            # Sentinel ignore value outside the channel range: a voxel is valid
            # iff any class channel is set.
            spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
        mask = spatial_mask.expand_as(y_true)
        y_pred = y_pred * mask
        y_true = y_true * mask

    # Clip the prediction to avoid NaN. Note: clamping re-introduces epsilon at
    # masked (ignored) voxels, which is why fn/fp below re-apply the mask.
    y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
    axis = list(range(2, len(y_pred.shape)))

    # True positives (tp), false negatives (fn) and false positives (fp),
    # reduced over the spatial dimensions only.
    tp = torch.sum(y_true * y_pred, dim=axis)
    fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis)
    fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis)
    dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

    # Per-class losses: plain Tversky complement for background; for foreground,
    # (1 - d) ** (1 - gamma) — algebraically equal to the paper's
    # (1 - d) * (1 - d) ** (-gamma) but without the 0 ** (-gamma) singularity.
    back_dice = 1 - dice_class[:, 0]
    fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma)

    # Stack to [B, 2] and apply the configured reduction.
    loss = torch.stack([back_dice, fore_dice], dim=-1)
    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(loss)
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(loss)
    return loss


Expand All @@ -103,27 +131,36 @@ def __init__(
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

# Save original for masking
original_y_true = y_true if self.ignore_index is not None else None

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
if self.ignore_index is not None:
# Replace ignore_index with valid class before one_hot
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
Expand All @@ -132,13 +169,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

# Build mask from original labels if available
spatial_mask: torch.Tensor | None = None
if self.ignore_index is not None:
if original_y_true is not None and self.to_onehot_y:
spatial_mask = (original_y_true != self.ignore_index).float()
elif self.ignore_index < y_true.shape[1]:
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
else:
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()

if spatial_mask is not None:
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W]

if self.reduction == LossReduction.MEAN.value:
if self.ignore_index is not None and spatial_mask is not None:
# Apply mask to loss, then average over valid elements only
# loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
masked_loss = loss * spatial_mask.expand_as(loss)
return masked_loss.sum() / (spatial_mask.expand_as(loss).sum().clamp(min=1e-5))
return loss.mean()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if self.reduction == LossReduction.SUM.value:
return loss.sum()
return loss


Expand All @@ -162,6 +222,7 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
Expand All @@ -170,8 +231,7 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -187,10 +247,12 @@ def __init__(
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma, delta=self.delta, ignore_index=ignore_index
)
self.ignore_index = ignore_index

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand All @@ -207,25 +269,42 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

# Transform binary inputs to 2-channel space
if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")

n_pred_ch = y_pred.shape[1]
# Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
if self.ignore_index is not None:
mask = (y_true != self.ignore_index).float()
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
y_true = one_hot(y_true_clean, num_classes=self.num_classes)
# Keep the channel-wise mask
y_true = y_true * mask
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
y_true = one_hot(y_true, num_classes=self.num_classes)

# Check if shapes match
if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
if self.ignore_index is not None:
# Create mask for valid pixels
mask = (y_true != self.ignore_index).float()
# Set ignore_index values to 0 before conversion
y_true_clean = y_true * mask
# Convert to 2-channel
y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1)
# Apply mask to both channels so ignored pixels are all zeros
y_true = y_true * mask
else:
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1:
raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}")

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
Expand Down
Loading
Loading