From 22d5211f7090d4e4dc7e636989fbc0c77e16076f Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Mon, 26 Jan 2026 01:14:26 +0530 Subject: [PATCH 1/7] Add AUC-Margin loss for AUROC optimization (#4609) Signed-off-by: Shubham Chandravanshi --- monai/losses/__init__.py | 1 + monai/losses/aucm_loss.py | 137 +++++++++++++++++++++++++++++++++ tests/losses/test_aucm_loss.py | 78 +++++++++++++++++++ 3 files changed, 216 insertions(+) create mode 100644 monai/losses/aucm_loss.py create mode 100644 tests/losses/test_aucm_loss.py diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 41935be204..853c355831 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .adversarial_loss import PatchAdversarialLoss +from .aucm_loss import AUCMLoss from .barlow_twins import BarlowTwinsLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py new file mode 100644 index 0000000000..6f3250b4d8 --- /dev/null +++ b/monai/losses/aucm_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +class AUCMLoss(_Loss): + """ + AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC. + + The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints + on positive and negative predictions. It supports two versions: 'v1' includes class prior + information, while 'v2' removes this dependency for better generalization. + + Reference: + Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao. + "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification." + Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021. + https://arxiv.org/abs/2012.03173 + + Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py + + Example: + >>> import torch + >>> from monai.losses import AUCMLoss + >>> loss_fn = AUCMLoss() + >>> input = torch.randn(32, 1, requires_grad=True) + >>> target = torch.randint(0, 2, (32, 1)).float() + >>> loss = loss_fn(input, target) + """ + + def __init__( + self, + margin: float = 1.0, + imratio: float | None = None, + version: str = "v1", + reduction: LossReduction | str = LossReduction.MEAN, + ) -> None: + """ + Args: + margin: margin for squared-hinge surrogate loss (default: ``1.0``). + imratio: the ratio of the number of positive samples to the number of total samples in the training dataset. + If this value is not given, it will be automatically calculated with mini-batch samples. + This value is ignored when ``version`` is set to ``'v2'``. + version: whether to include prior class information in the objective function (default: ``'v1'``). + 'v1' includes class prior, 'v2' removes this dependency. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + + Raises: + ValueError: When ``version`` is not one of ["v1", "v2"]. + + Example: + >>> import torch + >>> from monai.losses import AUCMLoss + >>> loss_fn = AUCMLoss(version='v2') + >>> input = torch.randn(32, 1, requires_grad=True) + >>> target = torch.randint(0, 2, (32, 1)).float() + >>> loss = loss_fn(input, target) + """ + super().__init__(reduction=LossReduction(reduction).value) + if version not in ["v1", "v2"]: + raise ValueError(f"version should be 'v1' or 'v2', got {version}") + self.margin = margin + self.imratio = imratio + self.version = version + self.a = nn.Parameter(torch.tensor(0.0)) + self.b = nn.Parameter(torch.tensor(0.0)) + self.alpha = nn.Parameter(torch.tensor(0.0)) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification. + target: the shape should be B1HW[D], with values 0 or 1. + + Raises: + ValueError: When input or target have incorrect shapes. + """ + if input.shape[1] != 1: + raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") + if target.shape[1] != 1: + raise ValueError(f"Target should have 1 channel, got {target.shape[1]}") + if input.shape != target.shape: + raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}") + + input = input.flatten() + target = target.flatten() + + pos_mask = (target == 1).float() + neg_mask = (target == 0).float() + + if self.version == "v1": + p = self.imratio if self.imratio is not None else pos_mask.mean() + loss = ( + (1 - p) * self._safe_mean((input - self.a) ** 2 * pos_mask) + + p * self._safe_mean((input - self.b) ** 2 * neg_mask) + + 2 + * self.alpha + * (p * (1 - p) * self.margin + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask)) + - p * (1 - p) * self.alpha**2 + ) + else: + loss = ( + self._safe_mean((input - self.a) ** 2 * pos_mask) + + self._safe_mean((input - self.b) ** 2 * neg_mask) + + 2 * self.alpha * (self.margin + self._safe_mean(input * neg_mask) - self._safe_mean(input * pos_mask)) + - self.alpha**2 + ) + + return loss + + def _safe_mean(self, tensor: torch.Tensor) -> torch.Tensor: + """Compute mean safely, returning 0 if tensor is empty.""" + if tensor.numel() == 0 or tensor.count_nonzero() == 0: + return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype) + return tensor.sum() / tensor.count_nonzero() diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py new file mode 100644 index 0000000000..684a2b737c --- /dev/null +++ b/tests/losses/test_aucm_loss.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.losses import AUCMLoss +from tests.test_utils import test_script_save + + +class TestAUCMLoss(unittest.TestCase): + def test_v1(self): + loss_fn = AUCMLoss(version="v1") + input = torch.randn(32, 1, requires_grad=True) + target = torch.randint(0, 2, (32, 1)).float() + loss = loss_fn(input, target) + self.assertIsInstance(loss, torch.Tensor) + self.assertEqual(loss.ndim, 0) + + def test_v2(self): + loss_fn = AUCMLoss(version="v2") + input = torch.randn(32, 1, requires_grad=True) + target = torch.randint(0, 2, (32, 1)).float() + loss = loss_fn(input, target) + self.assertIsInstance(loss, torch.Tensor) + self.assertEqual(loss.ndim, 0) + + def test_invalid_version(self): + with self.assertRaises(ValueError): + AUCMLoss(version="invalid") + + def test_invalid_input_shape(self): + loss_fn = AUCMLoss() + input = torch.randn(32, 2) # Wrong channel + target = torch.randint(0, 2, (32, 1)).float() + with self.assertRaises(ValueError): + loss_fn(input, target) + + def test_invalid_target_shape(self): + loss_fn = AUCMLoss() + input = torch.randn(32, 1) + target = torch.randint(0, 2, (32, 2)).float() # Wrong channel + with self.assertRaises(ValueError): + loss_fn(input, target) + + def test_shape_mismatch(self): + loss_fn = AUCMLoss() + input = torch.randn(32, 1) + target = torch.randint(0, 2, (16, 1)).float() + with self.assertRaises(ValueError): + loss_fn(input, target) + + def test_backward(self): + loss_fn = AUCMLoss() + input = torch.randn(32, 1, requires_grad=True) + target = torch.randint(0, 2, (32, 1)).float() + loss = loss_fn(input, target) + loss.backward() + self.assertIsNotNone(input.grad) + + def test_script_save(self): + loss_fn = AUCMLoss() + test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float()) + + +if __name__ == "__main__": + unittest.main() From f692165c3fef40f0fe07147d086f29d421cfcd08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Jan 2026 20:00:53 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/aucm_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 6f3250b4d8..11c84eafd4 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -11,7 +11,6 @@ from __future__ import annotations -import warnings import torch import torch.nn as nn From 96c9c369c47dc8e1487486253cff29fe5ceefe86 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Mon, 26 Jan 2026 02:04:55 +0530 Subject: [PATCH 3/7] Correct masked mean computation in AUCMLoss and update docstrings Signed-off-by: Shubham Chandravanshi --- monai/losses/aucm_loss.py | 28 +++++++++++++++++----------- tests/losses/test_aucm_loss.py | 10 ++++++++++ 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 11c84eafd4..75bb32b98f 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -11,7 +11,6 @@ from __future__ import annotations - import torch import torch.nn as nn from torch.nn.modules.loss import _Loss @@ -93,6 +92,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification. target: the shape should be B1HW[D], with values 0 or 1. + Returns: + torch.Tensor: scalar AUCM loss. + Raises: ValueError: When input or target have incorrect shapes. """ @@ -112,25 +114,29 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.version == "v1": p = self.imratio if self.imratio is not None else pos_mask.mean() loss = ( - (1 - p) * self._safe_mean((input - self.a) ** 2 * pos_mask) - + p * self._safe_mean((input - self.b) ** 2 * neg_mask) + (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask) + + p * self._safe_mean((input - self.b) ** 2, neg_mask) + 2 * self.alpha - * (p * (1 - p) * self.margin + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask)) + * ( + p * (1 - p) * self.margin + + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask) + ) - p * (1 - p) * self.alpha**2 ) else: loss = ( - self._safe_mean((input - self.a) ** 2 * pos_mask) - + self._safe_mean((input - self.b) ** 2 * neg_mask) - + 2 * self.alpha * (self.margin + self._safe_mean(input * neg_mask) - self._safe_mean(input * pos_mask)) + self._safe_mean((input - self.a) ** 2, pos_mask) + + self._safe_mean((input - self.b) ** 2, neg_mask) + + 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask)) - self.alpha**2 ) return loss - def _safe_mean(self, tensor: torch.Tensor) -> torch.Tensor: - """Compute mean safely, returning 0 if tensor is empty.""" - if tensor.numel() == 0 or tensor.count_nonzero() == 0: + def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Compute mean safely over masked elements.""" + denom = mask.sum() + if denom == 0: return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype) - return tensor.sum() / tensor.count_nonzero() + return (tensor * mask).sum() / denom diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py index 684a2b737c..609d6ff51b 100644 --- a/tests/losses/test_aucm_loss.py +++ b/tests/losses/test_aucm_loss.py @@ -20,7 +20,10 @@ class TestAUCMLoss(unittest.TestCase): + """Test cases for AUCMLoss.""" + def test_v1(self): + """Test AUCMLoss with version 'v1'.""" loss_fn = AUCMLoss(version="v1") input = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() @@ -29,6 +32,7 @@ def test_v1(self): self.assertEqual(loss.ndim, 0) def test_v2(self): + """Test AUCMLoss with version 'v2'.""" loss_fn = AUCMLoss(version="v2") input = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() @@ -37,10 +41,12 @@ def test_v2(self): self.assertEqual(loss.ndim, 0) def test_invalid_version(self): + """Test that invalid version raises ValueError.""" with self.assertRaises(ValueError): AUCMLoss(version="invalid") def test_invalid_input_shape(self): + """Test that invalid input shape raises ValueError.""" loss_fn = AUCMLoss() input = torch.randn(32, 2) # Wrong channel target = torch.randint(0, 2, (32, 1)).float() @@ -48,6 +54,7 @@ def test_invalid_input_shape(self): loss_fn(input, target) def test_invalid_target_shape(self): + """Test that invalid target shape raises ValueError.""" loss_fn = AUCMLoss() input = torch.randn(32, 1) target = torch.randint(0, 2, (32, 2)).float() # Wrong channel @@ -55,6 +62,7 @@ def test_invalid_target_shape(self): loss_fn(input, target) def test_shape_mismatch(self): + """Test that mismatched shapes raise ValueError.""" loss_fn = AUCMLoss() input = torch.randn(32, 1) target = torch.randint(0, 2, (16, 1)).float() @@ -62,6 +70,7 @@ def test_shape_mismatch(self): loss_fn(input, target) def test_backward(self): + """Test that gradients can be computed.""" loss_fn = AUCMLoss() input = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() @@ -70,6 +79,7 @@ def test_backward(self): self.assertIsNotNone(input.grad) def test_script_save(self): + """Test that the loss can be saved as TorchScript.""" loss_fn = AUCMLoss() test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float()) From 38d81695e355b1c8bc40618c05700d85a14b073e Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Mon, 26 Jan 2026 02:57:45 +0530 Subject: [PATCH 4/7] Validate binary targets, clarify reduction, and fix AUCM typing Signed-off-by: Shubham Chandravanshi --- monai/losses/aucm_loss.py | 12 +++++++----- tests/losses/test_aucm_loss.py | 8 ++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 75bb32b98f..41cb88d433 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -60,10 +60,8 @@ def __init__( 'v1' includes class prior, 'v2' removes this dependency. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. - - - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. + Note: This loss is computed at the batch level and always returns a scalar. + The reduction parameter is accepted for API consistency but has no effect. Raises: ValueError: When ``version`` is not one of ["v1", "v2"]. @@ -97,6 +95,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When input or target have incorrect shapes. + ValueError: When target contains non-binary values. """ if input.shape[1] != 1: raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") @@ -108,11 +107,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input = input.flatten() target = target.flatten() + if not torch.all((target == 0) | (target == 1)): + raise ValueError("Target must contain only binary values (0 or 1)") + pos_mask = (target == 1).float() neg_mask = (target == 0).float() if self.version == "v1": - p = self.imratio if self.imratio is not None else pos_mask.mean() + p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item()) loss = ( (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask) + p * self._safe_mean((input - self.b) ** 2, neg_mask) diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py index 609d6ff51b..9a0803ca11 100644 --- a/tests/losses/test_aucm_loss.py +++ b/tests/losses/test_aucm_loss.py @@ -69,6 +69,14 @@ def test_shape_mismatch(self): with self.assertRaises(ValueError): loss_fn(input, target) + def test_non_binary_target(self): + """Test that non-binary target values raise ValueError.""" + loss_fn = AUCMLoss() + input = torch.randn(32, 1) + target = torch.tensor([[0.5], [1.0], [2.0]] * 10 + [[0.0]]) # Contains non-binary values + with self.assertRaises(ValueError): + loss_fn(input, target) + def test_backward(self): """Test that gradients can be computed.""" loss_fn = AUCMLoss() From 3623f9684f8f9f8b757b04dbfff8b4c14c2226ea Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Mon, 26 Jan 2026 03:13:49 +0530 Subject: [PATCH 5/7] Validate imratio and input shape, added test cases for it and fix non-binary target test Signed-off-by: Shubham Chandravanshi --- monai/losses/aucm_loss.py | 6 ++++++ tests/losses/test_aucm_loss.py | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 41cb88d433..799a13bd10 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -65,6 +65,7 @@ def __init__( Raises: ValueError: When ``version`` is not one of ["v1", "v2"]. + ValueError: When ``imratio`` is not in [0, 1]. Example: >>> import torch @@ -77,6 +78,8 @@ def __init__( super().__init__(reduction=LossReduction(reduction).value) if version not in ["v1", "v2"]: raise ValueError(f"version should be 'v1' or 'v2', got {version}") + if imratio is not None and not (0.0 <= imratio <= 1.0): + raise ValueError(f"imratio must be in [0, 1], got {imratio}") self.margin = margin self.imratio = imratio self.version = version @@ -95,8 +98,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When input or target have incorrect shapes. + ValueError: When input or target have fewer than 2 dimensions. ValueError: When target contains non-binary values. """ + if input.ndim < 2 or target.ndim < 2: + raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)") if input.shape[1] != 1: raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") if target.shape[1] != 1: diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py index 9a0803ca11..e4010f90ec 100644 --- a/tests/losses/test_aucm_loss.py +++ b/tests/losses/test_aucm_loss.py @@ -45,6 +45,13 @@ def test_invalid_version(self): with self.assertRaises(ValueError): AUCMLoss(version="invalid") + def test_invalid_imratio(self): + """Test that invalid imratio raises ValueError.""" + with self.assertRaises(ValueError): + AUCMLoss(imratio=1.5) + with self.assertRaises(ValueError): + AUCMLoss(imratio=-0.1) + def test_invalid_input_shape(self): """Test that invalid input shape raises ValueError.""" loss_fn = AUCMLoss() @@ -61,6 +68,14 @@ def test_invalid_target_shape(self): with self.assertRaises(ValueError): loss_fn(input, target) + def test_insufficient_dimensions(self): + """Test that tensors with insufficient dimensions raise ValueError.""" + loss_fn = AUCMLoss() + input = torch.randn(32) # 1D tensor + target = torch.randint(0, 2, (32, 1)).float() + with self.assertRaises(ValueError): + loss_fn(input, target) + def test_shape_mismatch(self): """Test that mismatched shapes raise ValueError.""" loss_fn = AUCMLoss() @@ -73,7 +88,7 @@ def test_non_binary_target(self): """Test that non-binary target values raise ValueError.""" loss_fn = AUCMLoss() input = torch.randn(32, 1) - target = torch.tensor([[0.5], [1.0], [2.0]] * 10 + [[0.0]]) # Contains non-binary values + target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) # 32x1, still non-binary with self.assertRaises(ValueError): loss_fn(input, target) From 25c2702596ea42b20a5b17ac1af5de596f7e1b8e Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 29 Mar 2026 21:48:40 +0530 Subject: [PATCH 6/7] Refactor AUCMLoss implementation and improve tests Signed-off-by: Shubham Chandravanshi --- monai/losses/aucm_loss.py | 87 +++++++++++++++++++++++++++------- tests/losses/test_aucm_loss.py | 27 ++++++----- 2 files changed, 85 insertions(+), 29 deletions(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 799a13bd10..324e9adb43 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -119,32 +119,83 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: pos_mask = (target == 1).float() neg_mask = (target == 0).float() + mean_pos_sq = (input - self.a) ** 2 + mean_neg_sq = (input - self.b) ** 2 + + # Note: + # v1 uses global expectations (normalized by total number of samples), + # following the original LibAUC implementation. + # v2 uses class-conditional expectations (normalized by number of samples + # in each class), implemented via non-zero averaging. + # These behaviors differ and should not be unified. if self.version == "v1": p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item()) + p1 = 1.0 - p + + mean_pos = self._global_mean(mean_pos_sq, pos_mask) + mean_neg = self._global_mean(mean_neg_sq, neg_mask) + + interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask) + loss = ( - (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask) - + p * self._safe_mean((input - self.b) ** 2, neg_mask) - + 2 - * self.alpha - * ( - p * (1 - p) * self.margin - + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask) - ) - - p * (1 - p) * self.alpha**2 + p1 * mean_pos + + p * mean_neg + + 2 * self.alpha * (p * p1 * self.margin + interaction) + - p * p1 * self.alpha**2 ) - else: + + else: # v2 + mean_pos = self._class_mean(mean_pos_sq, pos_mask) + mean_neg = self._class_mean(mean_neg_sq, neg_mask) + + mean_input_pos = self._class_mean(input, pos_mask) + mean_input_neg = self._class_mean(input, neg_mask) + loss = ( - self._safe_mean((input - self.a) ** 2, pos_mask) - + self._safe_mean((input - self.b) ** 2, neg_mask) - + 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask)) - - self.alpha**2 + mean_pos + mean_neg + 2 * self.alpha * (self.margin + mean_input_neg - mean_input_pos) - self.alpha**2 ) return loss - def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """Compute mean safely over masked elements.""" + def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Compute the global mean of a masked tensor. + + This computes the mean over all elements, where values outside the mask + are zeroed out. The result is normalized by the total number of elements, + not by the number of masked elements. + + This corresponds to a global expectation: + E[mask * tensor] + + Args: + tensor: Input tensor. + mask: Binary mask tensor of the same shape as ``tensor``. + + Returns: + Scalar tensor representing the global mean. + """ + return (tensor * mask).mean() + + def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Compute the class-conditional mean of a masked tensor. + + This computes the mean over only the masked (non-zero) elements, i.e., + the result is normalized by the number of masked elements. + + This corresponds to a class-conditional expectation: + E[tensor | mask] + + Args: + tensor: Input tensor. + mask: Binary mask tensor of the same shape as ``tensor``. + + Returns: + Scalar tensor representing the class-conditional mean. + Returns 0 if no elements are selected by the mask. + """ denom = mask.sum() - if denom == 0: - return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype) + if denom.item() == 0: + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) return (tensor * mask).sum() / denom diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py index e4010f90ec..f234a3c74e 100644 --- a/tests/losses/test_aucm_loss.py +++ b/tests/losses/test_aucm_loss.py @@ -14,31 +14,36 @@ import unittest import torch +from parameterized import parameterized from monai.losses import AUCMLoss from tests.test_utils import test_script_save +FIXED_INPUT = torch.tensor([[1.0], [2.0]]) +FIXED_TARGET = torch.tensor([[1.0], [0.0]]) + +EXPECTED_V1 = 1.25 +EXPECTED_V2 = 5.0 + class TestAUCMLoss(unittest.TestCase): """Test cases for AUCMLoss.""" - def test_v1(self): - """Test AUCMLoss with version 'v1'.""" - loss_fn = AUCMLoss(version="v1") + @parameterized.expand([("v1",), ("v2",)]) + def test_versions(self, version): + """Test AUCMLoss with different versions.""" + loss_fn = AUCMLoss(version=version) input = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() loss = loss_fn(input, target) self.assertIsInstance(loss, torch.Tensor) self.assertEqual(loss.ndim, 0) - def test_v2(self): - """Test AUCMLoss with version 'v2'.""" - loss_fn = AUCMLoss(version="v2") - input = torch.randn(32, 1, requires_grad=True) - target = torch.randint(0, 2, (32, 1)).float() - loss = loss_fn(input, target) - self.assertIsInstance(loss, torch.Tensor) - self.assertEqual(loss.ndim, 0) + @parameterized.expand([("v1", EXPECTED_V1), ("v2", EXPECTED_V2)]) + def test_known_values(self, version, expected): + """Test AUCMLoss against fixed manually computed values.""" + loss = AUCMLoss(version=version)(FIXED_INPUT, FIXED_TARGET) + self.assertAlmostEqual(loss.item(), expected, places=5) def test_invalid_version(self): """Test that invalid version raises ValueError.""" From a2f51013326b53d754fe2767308b648e24f247b4 Mon Sep 17 00:00:00 2001 From: Shubham Chandravanshi Date: Sun, 29 Mar 2026 23:55:41 +0530 Subject: [PATCH 7/7] Address coderabbitai review comments for AUCMLoss and tests Signed-off-by: Shubham Chandravanshi --- monai/losses/aucm_loss.py | 8 +++- tests/losses/test_aucm_loss.py | 80 ++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 28 deletions(-) diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py index 324e9adb43..c6fa24e0cb 100644 --- a/monai/losses/aucm_loss.py +++ b/monai/losses/aucm_loss.py @@ -113,6 +113,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input = input.flatten() target = target.flatten() + if input.numel() == 0: + raise ValueError("Input and target must contain at least one element.") + if not torch.all((target == 0) | (target == 1)): raise ValueError("Target must contain only binary values (0 or 1)") @@ -175,7 +178,10 @@ def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor Returns: Scalar tensor representing the global mean. """ - return (tensor * mask).mean() + masked = tensor * mask + if masked.numel() == 0: + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) + return masked.mean() def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """ diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py index f234a3c74e..dfa10dff1b 100644 --- a/tests/losses/test_aucm_loss.py +++ b/tests/losses/test_aucm_loss.py @@ -13,37 +13,63 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.losses import AUCMLoss from tests.test_utils import test_script_save -FIXED_INPUT = torch.tensor([[1.0], [2.0]]) -FIXED_TARGET = torch.tensor([[1.0], [0.0]]) - -EXPECTED_V1 = 1.25 -EXPECTED_V2 = 5.0 +TEST_CASES = [ + # small deterministic cases (with expected values) + ("v1", torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [0.0]]), 1.25), + ("v2", torch.tensor([[1.0], [2.0]]), torch.tensor([[1.0], [0.0]]), 5.0), +] class TestAUCMLoss(unittest.TestCase): - """Test cases for AUCMLoss.""" + """Unit tests for AUCMLoss covering correctness, edge cases, and scriptability.""" @parameterized.expand([("v1",), ("v2",)]) def test_versions(self, version): """Test AUCMLoss with different versions.""" loss_fn = AUCMLoss(version=version) - input = torch.randn(32, 1, requires_grad=True) + pred = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() - loss = loss_fn(input, target) + loss = loss_fn(pred, target) self.assertIsInstance(loss, torch.Tensor) self.assertEqual(loss.ndim, 0) - @parameterized.expand([("v1", EXPECTED_V1), ("v2", EXPECTED_V2)]) - def test_known_values(self, version, expected): + @parameterized.expand(TEST_CASES) + def test_known_values(self, version, pred, target, expected): """Test AUCMLoss against fixed manually computed values.""" - loss = AUCMLoss(version=version)(FIXED_INPUT, FIXED_TARGET) - self.assertAlmostEqual(loss.item(), expected, places=5) + loss = AUCMLoss(version=version)(pred, target) + np.testing.assert_allclose(loss.detach().cpu().numpy(), expected, atol=1e-5, rtol=1e-5) + + @parameterized.expand([("v1",), ("v2",)]) + def test_high_dimensional(self, version): + """Test AUCMLoss with higher dimensional preds (e.g., segmentation).""" + loss_fn = AUCMLoss(version=version) + + pred = torch.randn(2, 1, 8, 8, requires_grad=True) + target = torch.randint(0, 2, (2, 1, 8, 8)).float() + + loss = loss_fn(pred, target) + + self.assertIsInstance(loss, torch.Tensor) + self.assertEqual(loss.ndim, 0) + + def test_imbalanced(self): + """Test AUCMLoss with highly imbalanced targets.""" + loss_fn = AUCMLoss(version="v1") + + pred = torch.randn(32, 1) + target = torch.zeros(32, 1) + target[0] = 1.0 # only one positive + + loss = loss_fn(pred, target) + + self.assertIsInstance(loss, torch.Tensor) def test_invalid_version(self): """Test that invalid version raises ValueError.""" @@ -57,54 +83,54 @@ def test_invalid_imratio(self): with self.assertRaises(ValueError): AUCMLoss(imratio=-0.1) - def test_invalid_input_shape(self): - """Test that invalid input shape raises ValueError.""" + def test_invalid_pred_shape(self): + """Test that invalid pred shape raises ValueError.""" loss_fn = AUCMLoss() - input = torch.randn(32, 2) # Wrong channel + pred = torch.randn(32, 2) # Wrong channel target = torch.randint(0, 2, (32, 1)).float() with self.assertRaises(ValueError): - loss_fn(input, target) + loss_fn(pred, target) def test_invalid_target_shape(self): """Test that invalid target shape raises ValueError.""" loss_fn = AUCMLoss() - input = torch.randn(32, 1) + pred = torch.randn(32, 1) target = torch.randint(0, 2, (32, 2)).float() # Wrong channel with self.assertRaises(ValueError): - loss_fn(input, target) + loss_fn(pred, target) def test_insufficient_dimensions(self): """Test that tensors with insufficient dimensions raise ValueError.""" loss_fn = AUCMLoss() - input = torch.randn(32) # 1D tensor + pred = torch.randn(32) # 1D tensor target = torch.randint(0, 2, (32, 1)).float() with self.assertRaises(ValueError): - loss_fn(input, target) + loss_fn(pred, target) def test_shape_mismatch(self): """Test that mismatched shapes raise ValueError.""" loss_fn = AUCMLoss() - input = torch.randn(32, 1) + pred = torch.randn(32, 1) target = torch.randint(0, 2, (16, 1)).float() with self.assertRaises(ValueError): - loss_fn(input, target) + loss_fn(pred, target) def test_non_binary_target(self): """Test that non-binary target values raise ValueError.""" loss_fn = AUCMLoss() - input = torch.randn(32, 1) + pred = torch.randn(32, 1) target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) # 32x1, still non-binary with self.assertRaises(ValueError): - loss_fn(input, target) + loss_fn(pred, target) def test_backward(self): """Test that gradients can be computed.""" loss_fn = AUCMLoss() - input = torch.randn(32, 1, requires_grad=True) + pred = torch.randn(32, 1, requires_grad=True) target = torch.randint(0, 2, (32, 1)).float() - loss = loss_fn(input, target) + loss = loss_fn(pred, target) loss.backward() - self.assertIsNotNone(input.grad) + self.assertIsNotNone(pred.grad) def test_script_save(self): """Test that the loss can be saved as TorchScript."""