Skip to content

Commit c9566c6

Browse files
committed
Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 57fdd59 commit c9566c6

File tree

2 files changed

+208
-191
lines changed

2 files changed

+208
-191
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 94 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -20,221 +20,146 @@
2020
from monai.utils import LossReduction
2121

2222

23-
class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Asymmetric Focal Loss
    and Asymmetric Focal Tversky Loss to handle imbalanced medical image segmentation.

    It supports multi-class segmentation by treating channel 0 as background and
    channels 1..N as foreground, applying asymmetric weighting controlled by `delta`.

    Reimplementation of the Asymmetric Unified Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
      Michael Yeung, Computerized Medical Imaging and Graphics

    Example:
        >>> import torch
        >>> from monai.losses import AsymmetricUnifiedFocalLoss
        >>> # B, C, H, W = 1, 3, 32, 32
        >>> pred_logits = torch.randn(1, 3, 32, 32)
        >>> # Ground truth indices (B, 1, H, W)
        >>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
        >>> # Use softmax=True if input is logits
        >>> loss_func = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
        >>> loss = loss_func(pred_logits, grnd)
    """

    def __init__(
        self,
        weight: float = 0.5,
        delta: float = 0.6,
        gamma: float = 0.5,
        include_background: bool = True,
        to_onehot_y: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        use_softmax: bool = False,
        epsilon: float = 1e-7,
    ) -> None:
        """
        Args:
            weight: The weighting factor between Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
                Final Loss = weight * AFL + (1 - weight) * AFTL. Defaults to 0.5.
            delta: The balancing factor controls the weight of background vs foreground classes.
                Values > 0.5 give more weight to foreground (False Negatives). Defaults to 0.6.
            gamma: The focal exponent. Higher values focus more on hard examples. Defaults to 0.5.
            include_background: If False, channel index 0 (background category) is excluded from the loss calculation.
                Defaults to True.
            to_onehot_y: Whether to convert the label `target` into the one-hot format. Defaults to False.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
            use_softmax: Whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, assumes input is already probabilities. Defaults to False.
            epsilon: Small value to prevent division by zero or log(0). Defaults to 1e-7.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.weight = weight
        self.delta = delta
        self.gamma = gamma
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.use_softmax = use_softmax
        self.epsilon = epsilon

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD].

        Raises:
            ValueError: When input and target have incompatible shapes.
        """
        if self.use_softmax:
            input = torch.nn.functional.softmax(input, dim=1)

        n_pred_ch = input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            elif target.shape[1] == 1:
                # convert class-index labels (B, 1, ...) to one-hot (B, N, ...)
                target = one_hot(target, num_classes=n_pred_ch)

        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # Clip values for numerical stability (avoids log(0) below).
        input = torch.clamp(input, self.epsilon, 1.0 - self.epsilon)

        # Part A: Asymmetric Focal Loss
        # Cross Entropy: -target * log(input)
        cross_entropy = -target * torch.log(input)

        # Background (Channel 0): (1 - delta) * (1 - p)^gamma * CE
        # (focal modulation suppresses easy background pixels)
        back_ce = (1 - self.delta) * torch.pow(1 - input[:, 0:1], self.gamma) * cross_entropy[:, 0:1]

        # Foreground (Channel 1..N): delta * CE
        # (no focal modulation, preserving foreground gradients)
        fore_ce = self.delta * cross_entropy[:, 1:]

        if self.include_background:
            asy_focal_loss = torch.cat([back_ce, fore_ce], dim=1)
        else:
            asy_focal_loss = fore_ce

        # Part B: Asymmetric Focal Tversky Loss
        # Sum over spatial dimensions (batch and channel dims are preserved).
        reduce_axis = list(range(2, input.dim()))

        tp = torch.sum(target * input, dim=reduce_axis)
        fn = torch.sum(target * (1 - input), dim=reduce_axis)
        fp = torch.sum((1 - target) * input, dim=reduce_axis)

        # Tversky Index
        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

        # Background: 1 - Tversky index
        back_dice_loss = 1 - dice_class[:, 0:1]

        # Foreground: (1 - Tversky)^(1 - gamma).
        # Computed as a single pow on a clamped base rather than
        # (1 - d) * (1 - d)^(-gamma): when a class is predicted perfectly
        # (fn == fp == 0 makes d exactly 1) the latter evaluates 0 * inf -> NaN,
        # poisoning the whole loss; clamping keeps the result finite for any gamma.
        fore_dice_loss = torch.pow(torch.clamp(1 - dice_class[:, 1:], min=self.epsilon), 1 - self.gamma)

        if self.include_background:
            asy_focal_tversky_loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)
        else:
            asy_focal_tversky_loss = fore_dice_loss

        # Part C: Unified Combination & Reduction
        # Aggregate Focal Loss spatial dimensions to match Tversky Loss shape (B, C).
        if asy_focal_loss.dim() > 2:
            asy_focal_loss = torch.mean(asy_focal_loss, dim=reduce_axis)

        # Weighted sum of the two loss terms.
        total_loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(total_loss)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return total_loss  # returns (B, C) per-class losses
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(total_loss)
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

0 commit comments

Comments
 (0)