From ad8344446128742fa05153b2db75e2ad09620426 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 30 Oct 2025 10:24:24 +0800 Subject: [PATCH 1/3] Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 276 +++++++++++++++++------------ 1 file changed, 166 insertions(+), 110 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 8484eb67ed..dcfb223b4f 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -14,6 +14,7 @@ import warnings import torch +import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -23,18 +24,17 @@ class AsymmetricFocalTverskyLoss(_Loss): """ AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - - Actually, it's only supported for binary image segmentation now. + It treats the background class (index 0) differently from all foreground classes (indices 1...N). Reimplementation of the Asymmetric Focal Tversky Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, - to_onehot_y: bool = False, + include_background: bool = True, delta: float = 0.7, gamma: float = 0.75, epsilon: float = 1e-7, @@ -42,13 +42,14 @@ def __init__( ) -> None: """ Args: - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include loss computation for the background class. Defaults to True. + delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index) + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". """ super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y + self.include_background = include_background self.delta = delta self.gamma = gamma self.epsilon = epsilon @@ -56,16 +57,18 @@ def __init__( def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + # Exclude background if needed + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + y_pred = y_pred[:, 1:] + y_true = y_true[:, 1:] + + # Clip predictions y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) @@ -74,45 +77,63 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fn = torch.sum(y_true * (1 - y_pred), dim=axis) fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - - # Calculate losses separately for each class, enhancing both classes - back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) - - # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) - return loss + # dice_class shape is (B, C) + + n_classes = dice_class.shape[1] + + if not self.include_background: + # All classes are foreground, apply foreground logic + loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C) + elif n_classes == 1: + # Single class, must be foreground (BG was excluded or not provided) + loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1) + else: + # Asymmetric logic: class 0 is BG, others are FG + back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1) + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1) + loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) # (B, C) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) # mean over batch and classes + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) # sum over batch and classes + if self.reduction == LossReduction.NONE.value: + return loss # returns (B, C) losses + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') class AsymmetricFocalLoss(_Loss): """ - AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - - Actually, it's only supported for binary image segmentation now. + AsymmetricFocalLoss is a variant of FocalLoss, which attentions to the foreground class. + It treats the background class (index 0) differently from all foreground classes (indices 1...N). + Background class (0): applies gamma exponent to (1-p) + Foreground classes (1..N): no gamma exponent Reimplementation of the Asymmetric Focal Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, - to_onehot_y: bool = False, + include_background: bool = True, delta: float = 0.7, - gamma: float = 2, + gamma: float = 2.0, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include loss computation for the background class. Defaults to True. + delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background) + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". """ super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y + self.include_background = include_background self.delta = delta self.gamma = gamma self.epsilon = epsilon @@ -120,121 +141,156 @@ def __init__( def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - cross_entropy = -y_true * torch.log(y_pred) - - back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] - back_ce = (1 - self.delta) * back_ce - - fore_ce = cross_entropy[:, 1] - fore_ce = self.delta * fore_ce + # Exclude background if needed + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + y_pred = y_pred[:, 1:] + y_true = y_true[:, 1:] - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) - return loss + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) + cross_entropy = -y_true * torch.log(y_pred) # Shape (B, C, H, W, [D]) + + n_classes = y_pred.shape[1] + + if not self.include_background: + # All classes are foreground, apply foreground logic + loss = self.delta * cross_entropy # (B, C, H, W) + elif n_classes == 1: + # Single class, must be foreground + loss = self.delta * cross_entropy # (B, 1, H, W) + else: + # Asymmetric logic: class 0 is BG, others are FG + # (B, H, W) + back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] + # (B, C-1, H, W) + fore_ce = self.delta * cross_entropy[:, 1:] + + loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) # mean over batch, class, and spatial + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) + if self.reduction == LossReduction.NONE.value: + return loss # returns (B, C, H, W) + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') class AsymmetricUnifiedFocalLoss(_Loss): """ - AsymmetricUnifiedFocalLoss is a variant of Focal Loss. - - Actually, it's only supported for binary image segmentation now + AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss + and AsymmetricFocalTverskyLoss. Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, + include_background: bool = True, to_onehot_y: bool = False, - num_classes: int = 2, - weight: float = 0.5, - gamma: float = 0.5, - delta: float = 0.7, + sigmoid: bool = False, + softmax: bool = False, + lambda_focal: float = 0.5, + focal_loss_gamma: float = 2.0, + focal_loss_delta: float = 0.7, + tversky_loss_gamma: float = 0.75, + tversky_loss_delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: + include_background: whether to include loss computation for the background class. Defaults to True. to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. + sigmoid: if True, apply a sigmoid activation to the input y_pred. + softmax: if True, apply a softmax activation to the input y_pred. + lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). + The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5. + focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0. + focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7. + tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75. + tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". Example: >>> import torch >>> from monai.losses import AsymmetricUnifiedFocalLoss - >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) - >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) + >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32) + >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64) + >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True) >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) + self.include_background = include_background self.to_onehot_y = to_onehot_y - self.num_classes = num_classes - self.gamma = gamma - self.delta = delta - self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.sigmoid = sigmoid + self.softmax = softmax + self.lambda_focal = lambda_focal + + if sigmoid and softmax: + raise ValueError("Both sigmoid and softmax cannot be True.") + + self.asy_focal_loss = AsymmetricFocalLoss( + include_background=self.include_background, + gamma=focal_loss_gamma, + delta=focal_loss_delta, + reduction=self.reduction, + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + include_background=self.include_background, + gamma=tversky_loss_gamma, + delta=tversky_loss_delta, + reduction=self.reduction, + ) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: - y_pred : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - - Raises: - ValueError: When input and target are different shape - ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 - ValueError: When num_classes - ValueError: When the number of classes entered does not match the expected number + y_pred : the shape should be BNH[WD]. + y_true : the shape should be BNH[WD] or B1H[WD]. """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: - raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) + n_pred_ch = y_pred.shape[1] - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") + y_pred_act = y_pred + if self.sigmoid: + y_pred_act = torch.sigmoid(y_pred) + elif self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, softmax=True ignored.") + else: + y_pred_act = torch.softmax(y_pred, dim=1) - n_pred_ch = y_pred.shape[1] if self.to_onehot_y: - if n_pred_ch == 1: + if n_pred_ch == 1 and not self.sigmoid: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: + elif n_pred_ch > 1 or self.sigmoid: + # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion + if y_true.shape[1] != 1: + y_true = y_true.unsqueeze(1) y_true = one_hot(y_true, num_classes=n_pred_ch) + + # Ensure y_true has the same shape as y_pred_act + if y_true.shape != y_pred_act.shape: + # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid + if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1: + y_true = y_true.unsqueeze(1) # Add channel dim + + if y_true.shape != y_pred_act.shape: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \ + f"after activations/one-hot") - asy_focal_loss = self.asy_focal_loss(y_pred, y_true) - asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) - loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + f_loss = self.asy_focal_loss(y_pred_act, y_true) + t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true) - if self.reduction == LossReduction.SUM.value: - return torch.sum(loss) # sum over the batch and channel dims - if self.reduction == LossReduction.NONE.value: - return loss # returns [N, num_classes] losses - if self.reduction == LossReduction.MEAN.value: - return torch.mean(loss) - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss + + return loss \ No newline at end of file From b4e0fcc342dbb32c67768583dbdb2cd8d7c3c787 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 02:49:46 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index dcfb223b4f..70bf93625a 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -14,7 +14,6 @@ import warnings import torch -import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -68,8 +67,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = y_pred[:, 1:] y_true = y_true[:, 1:] - # Clip predictions - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) @@ -169,7 +166,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] # (B, C-1, H, W) fore_ce = self.delta * cross_entropy[:, 1:] - + loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W) # Apply reduction @@ -276,21 +273,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if y_true.shape[1] != 1: y_true = y_true.unsqueeze(1) y_true = one_hot(y_true, num_classes=n_pred_ch) - + # Ensure y_true has the same shape as y_pred_act if y_true.shape != y_pred_act.shape: - # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid + # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1: - y_true = y_true.unsqueeze(1) # Add channel dim - - if y_true.shape != y_pred_act.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \ - f"after activations/one-hot") + y_true = y_true.unsqueeze(1) # Add channel dim + if y_true.shape != y_pred_act.shape: + raise ValueError( + f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " + f"after activations/one-hot" + ) f_loss = self.asy_focal_loss(y_pred_act, y_true) t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true) loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss - return loss \ No newline at end of file + return loss From ddf6898eae3be2e8007686c02cafd2b6441902cc Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 5 Nov 2025 16:39:30 +0800 Subject: [PATCH 3/3] timestep scheduling with np.linspace Signed-off-by: ytl0623 --- monai/networks/schedulers/ddim.py | 2 +- monai/networks/schedulers/ddpm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 50a680336d..993b826727 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -127,7 +127,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index e2b7ab55f5..7ef27108bc 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -125,7 +125,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N step_ratio = self.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: