diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 2c4010176a..357b2b0353 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -67,45 +67,11 @@ def __init__( batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ - Args: - include_background: if False, channel index 0 (background category) is excluded from the calculation. - if the non-background segmentations are small compared to the total image size they can get overwhelmed - by the signal from the background so excluding it in such cases helps convergence. - to_onehot_y: whether to convert the ``target`` into the one-hot format, - using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. - sigmoid: if True, apply a sigmoid function to the prediction. - softmax: if True, apply a softmax function to the prediction. - other_act: callable function to execute other activation layers, Defaults to ``None``. for example: - ``other_act = torch.tanh``. - squared_pred: use squared versions of targets and predictions in the denominator or not. - jaccard: compute Jaccard Index (soft IoU) instead of dice or not. - reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the output. Defaults to ``"mean"``. - - - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. - - smooth_nr: a small constant added to the numerator to avoid zero. - smooth_dr: a small constant added to the denominator to avoid nan. - batch: whether to sum the intersection and union areas over the batch dimension before the dividing. - Defaults to False, a Dice loss value is computed independently from each item in the batch - before any `reduction`. - weight: weights to apply to the voxels of each class. If None no weights are applied. - The input can be a single value (same weight for all classes), a sequence of values (the length - of the sequence should be the same as the number of classes. If not ``include_background``, - the number of classes should not include the background category class 0). - The value/values should be no less than 0. Defaults to None. - soft_label: whether the target contains non-binary values (soft labels) or not. - If True a soft label formulation of the loss will be used. - - Raises: - TypeError: When ``other_act`` is not an ``Optional[Callable]``. - ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. - Incompatible values. - + Args follow standard MONAI DiceLoss with the addition of: + ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. """ super().__init__(reduction=LossReduction(reduction).value) if other_act is not None and not callable(other_act): @@ -126,29 +92,13 @@ def __init__( self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor self.soft_label = soft_label + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD], where N is the number of classes. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. - - Raises: - AssertionError: When input and target (after one hot transform if set) - have different shapes. - ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. - - Example: - >>> from monai.losses.dice import * # NOQA - >>> import torch - >>> from monai.losses.dice import DiceLoss - >>> B, C, H, W = 7, 5, 3, 2 - >>> input = torch.rand(B, C, H, W) - >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() - >>> target = one_hot(target_idx[:, None, ...], num_classes=C) - >>> self = DiceLoss(reduction='none') - >>> loss = self(input, target) - >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape """ if self.sigmoid: input = torch.sigmoid(input) @@ -163,27 +113,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.other_act is not None: input = self.other_act(input) + # Create valid mask if ignore_index is specified and target is in index format + valid_mask = None + if self.ignore_index is not None and target.shape[1] == 1: + valid_mask = (target != self.ignore_index).to(input.dtype) + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: target = one_hot(target, num_classes=n_pred_ch) + # Create valid mask if target was already one-hot but ignore_index channel is specified + if self.ignore_index is not None and valid_mask is None: + if 0 <= self.ignore_index < target.shape[1]: + valid_mask = torch.ones_like(target) + valid_mask[:, self.ignore_index] = 0.0 + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2) else: - # if skipping background, removing first channel target = target[:, 1:] input = input[:, 1:] + if valid_mask is not None and valid_mask.shape[1] == n_pred_ch: + valid_mask = valid_mask[:, 1:] if target.shape != input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") - # reducing only spatial dimensions (not batch nor channels) + # Apply mask to both predictions and targets to exclude ignored regions + if valid_mask is not None: + input = input * valid_mask + target = target * valid_mask + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: - # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis ord = 2 if self.squared_pred else 1 @@ -198,7 +163,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: - # make sure the lengths of weights are equal to the number of classes if self.class_weight.ndim == 0: self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) else: @@ -209,16 +173,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ) if self.class_weight.min() < 0: raise ValueError("the value/values of the `weight` should be no less than 0.") - # apply class_weight to loss f = f * self.class_weight.to(f) if self.reduction == LossReduction.MEAN.value: - f = torch.mean(f) # the batch and channel average + f = torch.mean(f) elif self.reduction == LossReduction.SUM.value: - f = torch.sum(f) # sum over the batch and channel dims + f = torch.sum(f) elif self.reduction == LossReduction.NONE.value: - # If we are not computing voxelwise loss components at least - # make sure a none reduction maintains a broadcastable shape broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) f = f.view(broadcast_shape) else: