
Commit fc2f9cd

José Morano (j-morano) authored and committed
Added focal_loss_with_probs and focal activation options in Dice+FL. (Fixes #8242)
Signed-off-by: j-morano <research.msj@gmail.com>
1 parent d388d1c commit fc2f9cd

4 files changed: 102 additions, 42 deletions

monai/losses/dice.py

Lines changed: 28 additions & 14 deletions

@@ -823,8 +823,10 @@ def __init__(
         self,
         include_background: bool = True,
         to_onehot_y: bool = False,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        sigmoid_dice: bool = False,
+        softmax_dice: bool = False,
+        sigmoid_focal: bool = True,
+        softmax_focal: bool = False,
         other_act: Callable | None = None,
         squared_pred: bool = False,
         jaccard: bool = False,
@@ -843,10 +845,10 @@ def __init__(
             include_background: if False channel index 0 (background category) is excluded from the calculation.
             to_onehot_y: whether to convert the ``target`` into the one-hot format,
                 using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
-            sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `FocalLoss`.
-            softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `FocalLoss`.
+            sigmoid_dice: if True, apply a sigmoid function to the prediction for the `DiceLoss`.
+            softmax_dice: if True, apply a softmax function to the prediction for the `DiceLoss`.
+            sigmoid_focal: if True, apply a sigmoid function to the prediction for `FocalLoss`.
+            softmax_focal: if True, apply a softmax function to the prediction for `FocalLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``.
                 for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
@@ -878,8 +880,8 @@ def __init__(
         self.dice = DiceLoss(
             include_background=include_background,
             to_onehot_y=False,
-            sigmoid=sigmoid,
-            softmax=softmax,
+            sigmoid=sigmoid_dice,
+            softmax=softmax_dice,
             other_act=other_act,
             squared_pred=squared_pred,
             jaccard=jaccard,
@@ -896,6 +898,8 @@ def __init__(
             weight=weight,
             alpha=alpha,
             reduction=reduction,
+            use_sigmoid=sigmoid_focal,
+            use_softmax=softmax_focal,
         )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
@@ -953,8 +957,14 @@ class GeneralizedDiceFocalLoss(_Loss):
            Defaults to True.
        to_onehot_y: whether to convert the ``target`` into the one-hot format,
            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
-        sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
-        softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
+        sigmoid_dice (bool, optional): if True, apply a sigmoid function to the prediction for `GeneralizedDiceLoss`.
+            Defaults to False.
+        softmax_dice (bool, optional): if True, apply a softmax function to the prediction for `GeneralizedDiceLoss`.
+            Defaults to False.
+        sigmoid_focal (bool, optional): if True, apply a sigmoid function to the prediction for `FocalLoss`.
+            Defaults to True.
+        softmax_focal (bool, optional): if True, apply a softmax function to the prediction for `FocalLoss`.
+            Defaults to False.
        other_act (Optional[Callable], optional): callable function to execute other activation layers,
            Defaults to ``None``. for example: `other_act = torch.tanh`.
            only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
@@ -987,8 +997,10 @@ def __init__(
         self,
         include_background: bool = True,
         to_onehot_y: bool = False,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        sigmoid_dice: bool = False,
+        softmax_dice: bool = False,
+        sigmoid_focal: bool = True,
+        softmax_focal: bool = False,
         other_act: Callable | None = None,
         w_type: Weight | str = Weight.SQUARE,
         reduction: LossReduction | str = LossReduction.MEAN,
@@ -1004,8 +1016,8 @@ def __init__(
         self.generalized_dice = GeneralizedDiceLoss(
             include_background=include_background,
             to_onehot_y=to_onehot_y,
-            sigmoid=sigmoid,
-            softmax=softmax,
+            sigmoid=sigmoid_dice,
+            softmax=softmax_dice,
             other_act=other_act,
             w_type=w_type,
             reduction=reduction,
@@ -1019,6 +1031,8 @@ def __init__(
             gamma=gamma,
             weight=weight,
             reduction=reduction,
+            use_sigmoid=sigmoid_focal,
+            use_softmax=softmax_focal,
         )
         if lambda_gdl < 0.0:
             raise ValueError("lambda_gdl should be no less than 0.0.")

monai/losses/focal_loss.py

Lines changed: 33 additions & 2 deletions

@@ -74,6 +74,7 @@ def __init__(
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
+        use_sigmoid: bool = True,
     ) -> None:
         """
         Args:
@@ -96,7 +97,9 @@ def __init__(
                 - ``"sum"``: the output will be summed.

             use_softmax: whether to use softmax to transform the original logits into probabilities.
-                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+                If True, softmax is used. Defaults to False.
+            use_sigmoid: whether to use sigmoid to transform the original logits into probabilities.
+                If True, sigmoid is used. Defaults to True.

         Example:
             >>> import torch
@@ -113,6 +116,7 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
+        self.use_sigmoid = use_sigmoid
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
@@ -161,8 +165,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 self.alpha = None
                 warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
             loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
-        else:
+        elif self.use_sigmoid:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = focal_loss_with_probs(input, target, self.gamma, self.alpha)

         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -253,3 +259,28 @@ def sigmoid_focal_loss(
         loss = alpha_factor * loss

     return loss
+
+
+def focal_loss_with_probs(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = x, pt = p if label is 1 or 1 - p if label is 0
+    """
+    # Compute pt (probability of true class)
+    pt = torch.where(target == 1, input, 1 - input)
+
+    # Compute focal loss components
+    log_pt = torch.log(torch.clamp(pt, min=1e-8))  # Avoid log(0)
+    focal_factor = (1 - pt).pow(gamma)  # (1 - pt)**gamma
+
+    loss = -focal_factor * log_pt
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = torch.where(target == 1, alpha, 1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
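A hedged sketch of the new probabilities path: when both `use_softmax` and `use_sigmoid` are False, `FocalLoss` falls through to `focal_loss_with_probs` and treats the input as probabilities rather than logits. Shapes and values below are illustrative only:

import torch
from monai.losses import FocalLoss

# Inputs are already probabilities (e.g. the network ends in its own softmax/sigmoid),
# so both activations are turned off here.
probs = torch.tensor([[[[0.9, 0.2], [0.1, 0.8]]]])   # shape (1, 1, 2, 2), values in [0, 1]
target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])  # binary ground truth, same shape

loss_fn = FocalLoss(use_softmax=False, use_sigmoid=False, gamma=2.0)
print(loss_fn(probs, target))

# Hand check for the first voxel: target is 1 and p is 0.9, so pt = 0.9 and
# FL = -(1 - 0.9)**2 * log(0.9) ≈ 0.00105; the printed value is the mean over all voxels.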

tests/losses/test_dice_focal_loss.py

Lines changed: 2 additions & 4 deletions

@@ -53,14 +53,12 @@ def test_result_no_onehot_no_bg(self, size, onehot):
            for lambda_focal in [0.5, 1.0, 1.5]:
                common_params = {
                    "include_background": False,
-                   "softmax": True,
                    "to_onehot_y": onehot,
                    "reduction": reduction,
                    "weight": weight,
                }
-               dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, **common_params)
-               dice = DiceLoss(**common_params)
-               common_params.pop("softmax", None)
+               dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, softmax_dice=True, **common_params)
+               dice = DiceLoss(softmax=True, **common_params)
                focal = FocalLoss(**common_params)
                result = dice_focal(pred, label)
                expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
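The updated test keeps the consistency relation DiceFocal = Dice + lambda_focal * Focal under the renamed arguments. A standalone sketch of that check, with illustrative shapes and assuming the post-commit API:

import torch
from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss

# Illustrative data: 2 samples, 3 channels, 16x16, with one-hot-style float labels.
pred = torch.rand(2, 3, 16, 16)
label = torch.randint(0, 2, (2, 3, 16, 16)).float()

lambda_focal = 0.5
dice_focal = DiceFocalLoss(include_background=False, softmax_dice=True, lambda_focal=lambda_focal)
dice = DiceLoss(include_background=False, softmax=True)
focal = FocalLoss(include_background=False)  # default sigmoid path, matching sigmoid_focal=True

expected = dice(pred, label) + lambda_focal * focal(pred, label)
assert torch.allclose(dice_focal(pred, label), expected)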

tests/losses/test_focal_loss.py

Lines changed: 39 additions & 22 deletions

@@ -12,6 +12,7 @@
 from __future__ import annotations

 import unittest
+from itertools import product

 import numpy as np
 import torch
@@ -205,59 +206,71 @@ def test_consistency_with_cross_entropy_classification_01(self):
         self.assertNotAlmostEqual(max_error, 0.0, places=3)

     def test_bin_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
+            if not use_sigmoid and not use_softmax:
+                # The prediction here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0

             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_empty_class_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The prediction here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0

             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_multi_class_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 6  # labels 0 to 5
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The prediction here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
@@ -270,15 +283,15 @@ def test_multi_class_seg_2d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_bin_seg_3d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2  # labels 0, 1
             # define 3d examples
             target = torch.tensor(
@@ -294,11 +307,15 @@ def test_bin_seg_3d(self):
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W, D)
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The prediction here are probabilities, not logits.
+                pred_very_good = target_one_hot.clone().float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0

             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
@@ -309,10 +326,10 @@ def test_bin_seg_3d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

@@ -369,8 +386,8 @@ def test_warnings(self):
             loss(chn_input, chn_target)

     def test_script(self):
-        for use_softmax in [True, False]:
-            loss = FocalLoss(use_softmax=use_softmax)
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
+            loss = FocalLoss(use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             test_input = torch.ones(2, 2, 8, 8)
             test_script_save(loss, test_input, test_input)

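For clarity, `product([True, False], repeat=2)` used throughout these tests simply enumerates the four (use_softmax, use_sigmoid) combinations:

from itertools import product

# The four activation combinations exercised by the updated tests:
for use_softmax, use_sigmoid in product([True, False], repeat=2):
    print(use_softmax, use_sigmoid)
# -> (True, True), (True, False), (False, True), (False, False)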