Skip to content
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

from .adversarial_loss import PatchAdversarialLoss
from .aucm_loss import AUCMLoss
from .barlow_twins import BarlowTwinsLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
Expand Down
207 changes: 207 additions & 0 deletions monai/losses/aucm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss

from monai.utils import LossReduction


class AUCMLoss(_Loss):
    """
    AUC-Margin loss with a squared-hinge surrogate for direct AUROC optimization.

    Instead of a per-sample classification loss, this objective targets the Area Under
    the ROC Curve via margin-based constraints on positive and negative prediction
    scores. Two variants are provided: ``'v1'`` folds the class prior (imbalance
    ratio) into the objective, while ``'v2'`` drops that dependency for better
    generalization.

    Reference:
        Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao.
        "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification."
        Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
        https://arxiv.org/abs/2012.03173

    Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py

    Example:
        >>> import torch
        >>> from monai.losses import AUCMLoss
        >>> loss_fn = AUCMLoss()
        >>> input = torch.randn(32, 1, requires_grad=True)
        >>> target = torch.randint(0, 2, (32, 1)).float()
        >>> loss = loss_fn(input, target)
    """

    def __init__(
        self,
        margin: float = 1.0,
        imratio: float | None = None,
        version: str = "v1",
        reduction: LossReduction | str = LossReduction.MEAN,
    ) -> None:
        """
        Args:
            margin: margin of the squared-hinge surrogate loss (default: ``1.0``).
            imratio: ratio of positive samples to total samples in the training data.
                When omitted, it is estimated from each mini-batch.
                Ignored when ``version`` is set to ``'v2'``.
            version: ``'v1'`` keeps the class-prior terms in the objective, ``'v2'``
                removes this dependency (default: ``'v1'``).
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                Note: the loss is computed at the batch level and always returns a
                scalar, so this argument is accepted only for API consistency and
                has no effect.

        Raises:
            ValueError: When ``version`` is not one of ["v1", "v2"].
            ValueError: When ``imratio`` is not in [0, 1].

        Example:
            >>> import torch
            >>> from monai.losses import AUCMLoss
            >>> loss_fn = AUCMLoss(version='v2')
            >>> input = torch.randn(32, 1, requires_grad=True)
            >>> target = torch.randint(0, 2, (32, 1)).float()
            >>> loss = loss_fn(input, target)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if version not in ("v1", "v2"):
            raise ValueError(f"version should be 'v1' or 'v2', got {version}")
        if imratio is not None and not (0.0 <= imratio <= 1.0):
            raise ValueError(f"imratio must be in [0, 1], got {imratio}")
        self.margin = margin
        self.imratio = imratio
        self.version = version
        # Learnable scalars of the min-max AUC formulation: ``a``/``b`` track the
        # expected scores of the positive/negative class; ``alpha`` is the dual variable.
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification.
            target: the shape should be B1HW[D], with values 0 or 1.

        Returns:
            torch.Tensor: scalar AUCM loss.

        Raises:
            ValueError: When input or target have incorrect shapes.
            ValueError: When input or target have fewer than 2 dimensions.
            ValueError: When target contains non-binary values.
        """
        # Shape validation: both tensors must be (B, 1, ...) and agree exactly.
        if input.ndim < 2 or target.ndim < 2:
            raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)")
        if input.shape[1] != 1:
            raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
        if target.shape[1] != 1:
            raise ValueError(f"Target should have 1 channel, got {target.shape[1]}")
        if input.shape != target.shape:
            raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}")

        # The loss is defined over individual scores, so spatial structure is irrelevant.
        input = input.flatten()
        target = target.flatten()

        if input.numel() == 0:
            raise ValueError("Input and target must contain at least one element.")

        if not ((target == 0) | (target == 1)).all():
            raise ValueError("Target must contain only binary values (0 or 1)")

        is_pos = (target == 1).float()
        is_neg = (target == 0).float()

        # Squared deviations of scores from the learnable class centers a and b.
        sq_dev_pos = (input - self.a) ** 2
        sq_dev_neg = (input - self.b) ** 2

        # Note:
        # v1 uses global expectations (normalized by total number of samples),
        # following the original LibAUC implementation.
        # v2 uses class-conditional expectations (normalized by number of samples
        # in each class), implemented via non-zero averaging.
        # These behaviors differ and should not be unified.
        if self.version == "v1":
            if self.imratio is None:
                p = float(is_pos.mean().item())
            else:
                p = float(self.imratio)
            q = 1.0 - p

            e_pos = self._global_mean(sq_dev_pos, is_pos)
            e_neg = self._global_mean(sq_dev_neg, is_neg)

            # is_pos + is_neg is all-ones after the binary check, so this is a plain mean.
            cross = self._global_mean(p * input * is_neg - q * input * is_pos, is_pos + is_neg)

            return q * e_pos + p * e_neg + 2 * self.alpha * (p * q * self.margin + cross) - p * q * self.alpha**2

        # v2: class-conditional expectations, no dependency on the class prior.
        e_pos = self._class_mean(sq_dev_pos, is_pos)
        e_neg = self._class_mean(sq_dev_neg, is_neg)
        s_pos = self._class_mean(input, is_pos)
        s_neg = self._class_mean(input, is_neg)
        return e_pos + e_neg + 2 * self.alpha * (self.margin + s_neg - s_pos) - self.alpha**2

    def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Mean of ``tensor * mask`` over all elements.

        Values outside the mask are zeroed out, and the result is normalized by the
        total element count (not by the number of masked elements), i.e. the global
        expectation E[mask * tensor].

        Args:
            tensor: Input tensor.
            mask: Binary mask tensor of the same shape as ``tensor``.

        Returns:
            Scalar tensor representing the global mean; zero for empty tensors.
        """
        selected = tensor * mask
        if selected.numel() == 0:
            return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
        return selected.mean()

    def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Mean of ``tensor`` restricted to the masked elements.

        The result is normalized by the number of masked elements, i.e. the
        class-conditional expectation E[tensor | mask].

        Args:
            tensor: Input tensor.
            mask: Binary mask tensor of the same shape as ``tensor``.

        Returns:
            Scalar tensor representing the class-conditional mean.
            Returns 0 if no elements are selected by the mask.
        """
        count = mask.sum()
        if count.item() == 0:
            return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
        return (tensor * mask).sum() / count
142 changes: 142 additions & 0 deletions tests/losses/test_aucm_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import AUCMLoss
from tests.test_utils import test_script_save

# Small deterministic cases with hand-computed expected values:
# (version, prediction, target, expected loss).
# With freshly initialized parameters (a = b = alpha = 0) and margin = 1, the
# expected losses below follow directly from the v1/v2 objective formulas.
_PRED_2 = torch.tensor([[1.0], [2.0]])
_TARGET_2 = torch.tensor([[1.0], [0.0]])
_PRED_4 = torch.tensor([[0.5], [1.0], [2.0], [3.0]])
_TARGET_4 = torch.tensor([[1.0], [1.0], [0.0], [0.0]])

TEST_CASES = [
    ("v1", _PRED_2, _TARGET_2, 1.25),
    ("v2", _PRED_2, _TARGET_2, 5.0),
    ("v1", _PRED_4, _TARGET_4, 1.78125),
    ("v2", _PRED_4, _TARGET_4, 7.125),
]


# Constructor keyword combinations that must raise ValueError.
BAD_INIT_ARGS = [
    ({"version": "invalid"},),
    ({"version": "V1"},),
    ({"imratio": 1.5},),
    ({"imratio": -0.1},),
    ({"version": "v2", "imratio": 2.0},),
]

# Valid (margin, imratio) constructor combinations, exercised for both versions.
GOOD_INIT_ARGS = [(0.5, None), (1.0, 0.25), (2.0, 0.5), (1.0, 0.0), (1.0, 1.0)]


class TestAUCMLoss(unittest.TestCase):
    """Unit tests for AUCMLoss covering correctness, edge cases, and scriptability."""

    @parameterized.expand([("v1",), ("v2",)])
    def test_versions(self, version):
        """Test AUCMLoss with different versions."""
        loss_fn = AUCMLoss(version=version)
        pred = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(pred, target)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    @parameterized.expand(TEST_CASES)
    def test_known_values(self, version, pred, target, expected):
        """Test AUCMLoss against fixed manually computed values."""
        loss = AUCMLoss(version=version)(pred, target)
        np.testing.assert_allclose(loss.detach().cpu().numpy(), expected, atol=1e-5, rtol=1e-5)

    @parameterized.expand(GOOD_INIT_ARGS)
    def test_constructor_combinations(self, margin, imratio):
        """Test that valid margin/imratio combinations yield finite scalar losses."""
        for version in ("v1", "v2"):
            loss_fn = AUCMLoss(margin=margin, imratio=imratio, version=version)
            pred = torch.randn(16, 1)
            target = torch.randint(0, 2, (16, 1)).float()
            loss = loss_fn(pred, target)
            self.assertEqual(loss.ndim, 0)
            self.assertTrue(torch.isfinite(loss).item())

    @parameterized.expand([("v1",), ("v2",)])
    def test_high_dimensional(self, version):
        """Test AUCMLoss with higher dimensional preds (e.g., segmentation)."""
        loss_fn = AUCMLoss(version=version)

        pred = torch.randn(2, 1, 8, 8, requires_grad=True)
        target = torch.randint(0, 2, (2, 1, 8, 8)).float()

        loss = loss_fn(pred, target)

        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    @parameterized.expand([("v1", 0.0), ("v1", 1.0), ("v2", 0.0), ("v2", 1.0)])
    def test_single_class_targets(self, version, fill_value):
        """Test totally blank (all-negative) and all-positive targets for both versions."""
        loss_fn = AUCMLoss(version=version)
        pred = torch.randn(16, 1)
        target = torch.full((16, 1), fill_value)
        loss = loss_fn(pred, target)
        self.assertEqual(loss.ndim, 0)
        self.assertTrue(torch.isfinite(loss).item())

    def test_imbalanced(self):
        """Test AUCMLoss with highly imbalanced targets."""
        loss_fn = AUCMLoss(version="v1")

        pred = torch.randn(32, 1)
        target = torch.zeros(32, 1)
        target[0] = 1.0  # only one positive

        loss = loss_fn(pred, target)

        self.assertIsInstance(loss, torch.Tensor)

    @parameterized.expand(BAD_INIT_ARGS)
    def test_bad_init_args(self, kwargs):
        """Test that invalid constructor argument combinations raise ValueError."""
        with self.assertRaises(ValueError):
            AUCMLoss(**kwargs)

    def test_invalid_pred_shape(self):
        """Test that invalid pred shape raises ValueError."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32, 2)  # Wrong channel
        target = torch.randint(0, 2, (32, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(pred, target)

    def test_invalid_target_shape(self):
        """Test that invalid target shape raises ValueError."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32, 1)
        target = torch.randint(0, 2, (32, 2)).float()  # Wrong channel
        with self.assertRaises(ValueError):
            loss_fn(pred, target)

    def test_insufficient_dimensions(self):
        """Test that tensors with insufficient dimensions raise ValueError."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32)  # 1D tensor
        target = torch.randint(0, 2, (32, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(pred, target)

    def test_shape_mismatch(self):
        """Test that mismatched shapes raise ValueError."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32, 1)
        target = torch.randint(0, 2, (16, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(pred, target)

    def test_empty_input(self):
        """Test that empty tensors raise ValueError."""
        loss_fn = AUCMLoss()
        with self.assertRaises(ValueError):
            loss_fn(torch.randn(0, 1), torch.zeros(0, 1))

    def test_non_binary_target(self):
        """Test that non-binary target values raise ValueError."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32, 1)
        target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8)  # 32x1, still non-binary
        with self.assertRaises(ValueError):
            loss_fn(pred, target)

    def test_backward(self):
        """Test that gradients can be computed."""
        loss_fn = AUCMLoss()
        pred = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(pred, target)
        loss.backward()
        self.assertIsNotNone(pred.grad)

    def test_script_save(self):
        """Test that the loss can be saved as TorchScript."""
        loss_fn = AUCMLoss()
        test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float())


# Support running this test module directly, e.g. ``python test_aucm_loss.py``.
if __name__ == "__main__":
    unittest.main()
Loading