From ad8344446128742fa05153b2db75e2ad09620426 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 30 Oct 2025 10:24:24 +0800 Subject: [PATCH 1/7] Generalize AsymmetricUnifiedFocalLoss for multi-class and align interface Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 276 +++++++++++++++++------------ 1 file changed, 166 insertions(+), 110 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 8484eb67ed..dcfb223b4f 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -14,6 +14,7 @@ import warnings import torch +import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -23,18 +24,17 @@ class AsymmetricFocalTverskyLoss(_Loss): """ AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - - Actually, it's only supported for binary image segmentation now. + It treats the background class (index 0) differently from all foreground classes (indices 1...N). Reimplementation of the Asymmetric Focal Tversky Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, - to_onehot_y: bool = False, + include_background: bool = True, delta: float = 0.7, gamma: float = 0.75, epsilon: float = 1e-7, @@ -42,13 +42,14 @@ def __init__( ) -> None: """ Args: - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include loss computation for the background class. Defaults to True. + delta : weight of the background. Defaults to 0.7. (Used to weigh FNs and FPs in Tversky index) + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". """ super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y + self.include_background = include_background self.delta = delta self.gamma = gamma self.epsilon = epsilon @@ -56,16 +57,18 @@ def __init__( def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + # Exclude background if needed + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + y_pred = y_pred[:, 1:] + y_true = y_true[:, 1:] + + # Clip predictions y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) @@ -74,45 +77,63 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fn = torch.sum(y_true * (1 - y_pred), dim=axis) fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - - # Calculate losses separately for each class, enhancing both classes - back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) - - # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) - return loss + # dice_class shape is (B, C) + + n_classes = dice_class.shape[1] + + if not self.include_background: + # All classes are foreground, apply foreground logic + loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C) + elif n_classes == 1: + # Single class, must be foreground (BG was excluded or not provided) + loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1) + else: + # Asymmetric logic: class 0 is BG, others are FG + back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1) + fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma) # (B, C-1) + loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) # (B, C) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) # mean over batch and classes + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) # sum over batch and classes + if self.reduction == LossReduction.NONE.value: + return loss # returns (B, C) losses + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') class AsymmetricFocalLoss(_Loss): """ - AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - - Actually, it's only supported for binary image segmentation now. + AsymmetricFocalLoss is a variant of FocalLoss, which attentions to the foreground class. + It treats the background class (index 0) differently from all foreground classes (indices 1...N). + Background class (0): applies gamma exponent to (1-p) + Foreground classes (1..N): no gamma exponent Reimplementation of the Asymmetric Focal Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, - to_onehot_y: bool = False, + include_background: bool = True, delta: float = 0.7, - gamma: float = 2, + gamma: float = 2.0, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include loss computation for the background class. Defaults to True. + delta : weight of the foreground. Defaults to 0.7. (1-delta is weight of background) + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". """ super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y + self.include_background = include_background self.delta = delta self.gamma = gamma self.epsilon = epsilon @@ -120,121 +141,156 @@ def __init__( def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - cross_entropy = -y_true * torch.log(y_pred) - - back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] - back_ce = (1 - self.delta) * back_ce - - fore_ce = cross_entropy[:, 1] - fore_ce = self.delta * fore_ce + # Exclude background if needed + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + y_pred = y_pred[:, 1:] + y_true = y_true[:, 1:] - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) - return loss + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) + cross_entropy = -y_true * torch.log(y_pred) # Shape (B, C, H, W, [D]) + + n_classes = y_pred.shape[1] + + if not self.include_background: + # All classes are foreground, apply foreground logic + loss = self.delta * cross_entropy # (B, C, H, W) + elif n_classes == 1: + # Single class, must be foreground + loss = self.delta * cross_entropy # (B, 1, H, W) + else: + # Asymmetric logic: class 0 is BG, others are FG + # (B, H, W) + back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] + # (B, C-1, H, W) + fore_ce = self.delta * cross_entropy[:, 1:] + + loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W) + + # Apply reduction + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) # mean over batch, class, and spatial + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) + if self.reduction == LossReduction.NONE.value: + return loss # returns (B, C, H, W) + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') class AsymmetricUnifiedFocalLoss(_Loss): """ - AsymmetricUnifiedFocalLoss is a variant of Focal Loss. - - Actually, it's only supported for binary image segmentation now + AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss + and AsymmetricFocalTverskyLoss. Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics + Michael Yeung, Computerized Medical Imaging and Graphics """ def __init__( self, + include_background: bool = True, to_onehot_y: bool = False, - num_classes: int = 2, - weight: float = 0.5, - gamma: float = 0.5, - delta: float = 0.7, + sigmoid: bool = False, + softmax: bool = False, + lambda_focal: float = 0.5, + focal_loss_gamma: float = 2.0, + focal_loss_delta: float = 0.7, + tversky_loss_gamma: float = 0.75, + tversky_loss_delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: + include_background: whether to include loss computation for the background class. Defaults to True. to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. + sigmoid: if True, apply a sigmoid activation to the input y_pred. + softmax: if True, apply a softmax activation to the input y_pred. + lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). + The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5. + focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0. + focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7. + tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75. + tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7. + reduction: specifies the reduction to apply to the output: "none", "mean", "sum". Example: >>> import torch >>> from monai.losses import AsymmetricUnifiedFocalLoss - >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) - >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) + >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32) + >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64) + >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True) >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) + self.include_background = include_background self.to_onehot_y = to_onehot_y - self.num_classes = num_classes - self.gamma = gamma - self.delta = delta - self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.sigmoid = sigmoid + self.softmax = softmax + self.lambda_focal = lambda_focal + + if sigmoid and softmax: + raise ValueError("Both sigmoid and softmax cannot be True.") + + self.asy_focal_loss = AsymmetricFocalLoss( + include_background=self.include_background, + gamma=focal_loss_gamma, + delta=focal_loss_delta, + reduction=self.reduction, + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + include_background=self.include_background, + gamma=tversky_loss_gamma, + delta=tversky_loss_delta, + reduction=self.reduction, + ) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: - y_pred : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - - Raises: - ValueError: When input and target are different shape - ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 - ValueError: When num_classes - ValueError: When the number of classes entered does not match the expected number + y_pred : the shape should be BNH[WD]. + y_true : the shape should be BNH[WD] or B1H[WD]. """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: - raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) + n_pred_ch = y_pred.shape[1] - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") + y_pred_act = y_pred + if self.sigmoid: + y_pred_act = torch.sigmoid(y_pred) + elif self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, softmax=True ignored.") + else: + y_pred_act = torch.softmax(y_pred, dim=1) - n_pred_ch = y_pred.shape[1] if self.to_onehot_y: - if n_pred_ch == 1: + if n_pred_ch == 1 and not self.sigmoid: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: + elif n_pred_ch > 1 or self.sigmoid: + # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion + if y_true.shape[1] != 1: + y_true = y_true.unsqueeze(1) y_true = one_hot(y_true, num_classes=n_pred_ch) + + # Ensure y_true has the same shape as y_pred_act + if y_true.shape != y_pred_act.shape: + # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid + if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1: + y_true = y_true.unsqueeze(1) # Add channel dim + + if y_true.shape != y_pred_act.shape: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \ + f"after activations/one-hot") - asy_focal_loss = self.asy_focal_loss(y_pred, y_true) - asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) - loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + f_loss = self.asy_focal_loss(y_pred_act, y_true) + t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true) - if self.reduction == LossReduction.SUM.value: - return torch.sum(loss) # sum over the batch and channel dims - if self.reduction == LossReduction.NONE.value: - return loss # returns [N, num_classes] losses - if self.reduction == LossReduction.MEAN.value: - return torch.mean(loss) - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss + + return loss \ No newline at end of file From 0ba7863e91897d9d257789d718ebaf71627886cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 02:49:46 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index dcfb223b4f..70bf93625a 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -14,7 +14,6 @@ import warnings import torch -import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -68,8 +67,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = y_pred[:, 1:] y_true = y_true[:, 1:] - # Clip predictions - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) @@ -169,7 +166,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] # (B, C-1, H, W) fore_ce = self.delta * cross_entropy[:, 1:] - + loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W) # Apply reduction @@ -276,21 +273,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if y_true.shape[1] != 1: y_true = y_true.unsqueeze(1) y_true = one_hot(y_true, num_classes=n_pred_ch) - + # Ensure y_true has the same shape as y_pred_act if y_true.shape != y_pred_act.shape: - # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid + # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1: - y_true = y_true.unsqueeze(1) # Add channel dim - - if y_true.shape != y_pred_act.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \ - f"after activations/one-hot") + y_true = y_true.unsqueeze(1) # Add channel dim + if y_true.shape != y_pred_act.shape: + raise ValueError( + f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " + f"after activations/one-hot" + ) f_loss = self.asy_focal_loss(y_pred_act, y_true) t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true) loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss - return loss \ No newline at end of file + return loss From 30e8263489228c87bd3ecdae3fcb0e88b6ad9923 Mon Sep 17 00:00:00 2001 From: NabJa <32510324+NabJa@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:56:25 +0100 Subject: [PATCH 3/7] 8564 fourier positional encoding (#8570) Fixes #8564 . ### Description Add Fourier feature positional encodings to `PatchEmbeddingBlock`. It has been shown, that Fourier feature positional encodings are better suited for Anistropic images and videos: https://arxiv.org/abs/2509.02488 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: NabJa Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: ytl0623 --- monai/networks/blocks/patchembedding.py | 22 ++++++-- monai/networks/blocks/pos_embed_utils.py | 55 +++++++++++++++++++- tests/networks/blocks/test_patchembedding.py | 39 ++++++++++++++ 3 files changed, 112 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index fca566591a..4e8a6a0463 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Optional import numpy as np import torch @@ -19,14 +20,14 @@ import torch.nn.functional as F from torch.nn import LayerNorm -from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"} -SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"} class PatchEmbeddingBlock(nn.Module): @@ -53,6 +54,7 @@ def __init__( pos_embed_type: str = "learnable", dropout_rate: float = 0.0, spatial_dims: int = 3, + pos_embed_kwargs: Optional[dict] = None, ) -> None: """ Args: @@ -65,6 +67,8 @@ def __init__( pos_embed_type: position embedding layer type. dropout_rate: fraction of the input units to drop. spatial_dims: number of spatial dimensions. + pos_embed_kwargs: additional arguments for position embedding. For `sincos`, it can contain + `temperature` and for fourier it can contain `scales`. """ super().__init__() @@ -105,6 +109,8 @@ def __init__( self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) + pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs + if self.pos_embed_type == "none": pass elif self.pos_embed_type == "learnable": @@ -114,7 +120,17 @@ def __init__( for in_size, pa_size in zip(img_size, patch_size): grid_size.append(in_size // pa_size) - self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) + self.position_embeddings = build_sincos_position_embedding( + grid_size, hidden_size, spatial_dims, **pos_embed_kwargs + ) + elif self.pos_embed_type == "fourier": + grid_size = [] + for in_size, pa_size in zip(img_size, patch_size): + grid_size.append(in_size // pa_size) + + self.position_embeddings = build_fourier_position_embedding( + grid_size, hidden_size, spatial_dims, **pos_embed_kwargs + ) else: raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index a9c5176bc2..266be5e28c 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -__all__ = ["build_sincos_position_embedding"] +__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"] # From PyTorch internals @@ -32,6 +32,59 @@ def parse(x): return parse +def build_fourier_position_embedding( + grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0 +) -> torch.nn.Parameter: + """ + Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension, + spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant + points more distinguishable. + Position embedding is made anistropic by allowing setting different scales for each spatial dimension. + Reference: https://arxiv.org/abs/2509.02488 + + Args: + grid_size (int | List[int]): The size of the grid in each spatial dimension. + embed_dim (int): The dimension of the embedding. + spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D). + scales (float | List[float]): The scale for every spatial dimension. If a single float is provided, + the same scale is used for all dimensions. + + Returns: + pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter. + """ + + to_tuple = _ntuple(spatial_dims) + grid_size_t = to_tuple(grid_size) + if len(grid_size_t) != spatial_dims: + raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.") + + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even for Fourier position embedding") + + # Ensure scales is a tensor of shape (spatial_dims,) + if isinstance(scales, float): + scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float) + elif isinstance(scales, (list, tuple)): + if len(scales) != spatial_dims: + raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}") + scales_tensor = torch.tensor(scales, dtype=torch.float) + else: + raise TypeError(f"scales must be float or list of floats, got {type(scales)}") + + gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor + + position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t] + positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1) + positions = positions.flatten(end_dim=-2) + + x_proj = (2.0 * torch.pi * positions) @ gaussians.T + + pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False) + + return pos_emb + + def build_sincos_position_embedding( grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 ) -> torch.nn.Parameter: diff --git a/tests/networks/blocks/test_patchembedding.py b/tests/networks/blocks/test_patchembedding.py index 2945482649..95eba14e6f 100644 --- a/tests/networks/blocks/test_patchembedding.py +++ b/tests/networks/blocks/test_patchembedding.py @@ -87,6 +87,19 @@ def test_sincos_pos_embed(self): self.assertEqual(net.position_embeddings.requires_grad, False) + def test_fourier_pos_embed(self): + net = PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(8, 8, 8), + hidden_size=96, + num_heads=8, + pos_embed_type="fourier", + dropout_rate=0.5, + ) + + self.assertEqual(net.position_embeddings.requires_grad, False) + def test_learnable_pos_embed(self): net = PatchEmbeddingBlock( in_channels=1, @@ -101,6 +114,32 @@ def test_learnable_pos_embed(self): self.assertEqual(net.position_embeddings.requires_grad, True) def test_ill_arg(self): + with self.assertRaises(ValueError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + num_heads=12, + proj_type="conv", + dropout_rate=0.1, + pos_embed_type="fourier", + pos_embed_kwargs=dict(scales=[1.0, 1.0]), + ) + + with self.assertRaises(ValueError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(128, 128), + patch_size=(16, 16), + hidden_size=128, + num_heads=12, + proj_type="conv", + dropout_rate=0.1, + pos_embed_type="fourier", + pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]), + ) + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1, From d4938151b9833f44cfea559bfc2f2b02f906714e Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 31 Oct 2025 21:48:22 +0800 Subject: [PATCH 4/7] Refactor parameters for UnifiedFocalLoss class Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 70bf93625a..936d789ad4 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -192,29 +192,29 @@ class AsymmetricUnifiedFocalLoss(_Loss): def __init__( self, - include_background: bool = True, to_onehot_y: bool = False, - sigmoid: bool = False, - softmax: bool = False, + use_sigmoid: bool = False, + use_softmax: bool = False, lambda_focal: float = 0.5, focal_loss_gamma: float = 2.0, focal_loss_delta: float = 0.7, tversky_loss_gamma: float = 0.75, tversky_loss_delta: float = 0.7, + include_background: bool = True, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: - include_background: whether to include loss computation for the background class. Defaults to True. to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - sigmoid: if True, apply a sigmoid activation to the input y_pred. - softmax: if True, apply a softmax activation to the input y_pred. + use_sigmoid: if True, apply a sigmoid activation to the input y_pred. + use_softmax: if True, apply a softmax activation to the input y_pred. lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based). The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5. focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0. focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7. tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75. tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7. + include_background: whether to include loss computation for the background class. Defaults to True. reduction: specifies the reduction to apply to the output: "none", "mean", "sum". Example: @@ -222,18 +222,18 @@ def __init__( >>> from monai.losses import AsymmetricUnifiedFocalLoss >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32) >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True) + >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True) >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) - self.include_background = include_background self.to_onehot_y = to_onehot_y - self.sigmoid = sigmoid - self.softmax = softmax + self.use_sigmoid = use_sigmoid + self.use_softmax = use_softmax self.lambda_focal = lambda_focal + self.include_background = include_background - if sigmoid and softmax: - raise ValueError("Both sigmoid and softmax cannot be True.") + if self.use_sigmoid and self.use_softmax: + raise ValueError("Both use_sigmoid and use_softmax cannot be True.") self.asy_focal_loss = AsymmetricFocalLoss( include_background=self.include_background, @@ -257,18 +257,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] y_pred_act = y_pred - if self.sigmoid: + if self.use_sigmoid: y_pred_act = torch.sigmoid(y_pred) - elif self.softmax: + elif self.use_softmax: if n_pred_ch == 1: - warnings.warn("single channel prediction, softmax=True ignored.") + warnings.warn("single channel prediction, use_softmax=True ignored.") else: y_pred_act = torch.softmax(y_pred, dim=1) if self.to_onehot_y: - if n_pred_ch == 1 and not self.sigmoid: + if n_pred_ch == 1 and not self.use_sigmoid: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - elif n_pred_ch > 1 or self.sigmoid: + elif n_pred_ch > 1 or self.use_sigmoid: # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion if y_true.shape[1] != 1: y_true = y_true.unsqueeze(1) From befcfebff8603ff9a6dabc346a48a09729a1ffac Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 31 Oct 2025 22:43:13 +0800 Subject: [PATCH 5/7] Update unified_focal_loss.py Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 936d789ad4..6dd6ce4100 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -80,10 +80,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if not self.include_background: # All classes are foreground, apply foreground logic - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, C) + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, C) elif n_classes == 1: # Single class, must be foreground (BG was excluded or not provided) - loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma) # (B, 1) + loss = torch.pow(1.0 - dice_class, 1.0 / self.gamma) # (B, 1) else: # Asymmetric logic: class 0 is BG, others are FG back_dice_loss = (1.0 - dice_class[:, 0]).unsqueeze(1) # (B, 1) @@ -276,9 +276,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # Ensure y_true has the same shape as y_pred_act if y_true.shape != y_pred_act.shape: - # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid - if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1: - y_true = y_true.unsqueeze(1) # Add channel dim + if y_true.ndim == y_pred_act.ndim - 1: + y_true = y_true.unsqueeze(1) if y_true.shape != y_pred_act.shape: raise ValueError( From 04a98d4766d45cf621eafe190a85754c4b08a2fe Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 31 Oct 2025 23:01:35 +0800 Subject: [PATCH 6/7] test Signed-off-by: ytl0623 From ac66250c6b5daa598b0f43c87d32d2474040d7f9 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 31 Oct 2025 23:04:12 +0800 Subject: [PATCH 7/7] DCO Remediation Commit for NabJa <32510324+NabJa@users.noreply.github.com> I, NabJa <32510324+NabJa@users.noreply.github.com>, hereby add my Signed-off-by to this commit: 76c4391ca70e9c5f1515beb00b8cf2a24dc6727b Signed-off-by: NabJa <32510324+NabJa@users.noreply.github.com> Signed-off-by: ytl0623