Skip to content

Commit 522802f

Browse files
ytl0623 authored and Rusheel86 committed
Weights in alpha for FocalLoss (Project-MONAI#8665)
Fixes Project-MONAI#8601 Support alpha as a list, tuple, or tensor of floats, in addition to the existing scalar support. <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent e27c567 commit 522802f

2 files changed

Lines changed: 143 additions & 16 deletions

File tree

monai/losses/focal_loss.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,67 @@ def __init__(
6969
include_background: bool = True,
7070
to_onehot_y: bool = False,
7171
gamma: float = 2.0,
72-
alpha: float | None = None,
72+
alpha: float | Sequence[float] | None = None,
7373
weight: Sequence[float] | float | int | torch.Tensor | None = None,
7474
reduction: LossReduction | str = LossReduction.MEAN,
7575
use_softmax: bool = False,
7676
ignore_index: int | None = None,
7777
) -> None:
7878
"""
7979
Args:
80+
84+
include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
85+
In that case, a scalar `alpha` is ignored when using softmax; a sequence `alpha` (explicit class weights) is still applied.
86+
to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
87+
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
88+
alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
89+
The value should be in [0, 1].
90+
If a sequence is provided, its length must match the number of classes
91+
(excluding the background class if `include_background=False`).
92+
Defaults to None.
93+
weight: weights to apply to the voxels of each class. If None no weights are applied.
94+
The input can be a single value (same weight for all classes), a sequence of values (the length
95+
of the sequence should be the same as the number of classes. If not ``include_background``,
96+
the number of classes should not include the background category class 0).
97+
The value/values should be no less than 0. Defaults to None.
98+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
99+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
100+
101+
- ``"none"``: no reduction will be applied.
102+
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
103+
- ``"sum"``: the output will be summed.
104+
105+
use_softmax: whether to use softmax to transform the original logits into probabilities.
106+
If True, softmax is used. If False, sigmoid is used. Defaults to False.
107+
ignore_index: index of the class to ignore during calculation.
108+
109+
Example:
110+
>>> import torch
111+
>>> from monai.losses import FocalLoss
112+
>>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
113+
>>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
114+
>>> fl = FocalLoss(to_onehot_y=True)
115+
>>> fl(pred, grnd)
82116
"""
83117
super().__init__(reduction=LossReduction(reduction).value)
84118
self.include_background = include_background
85119
self.to_onehot_y = to_onehot_y
86120
self.gamma = gamma
87-
self.alpha = alpha
88121
self.weight = weight
89122
self.use_softmax = use_softmax
90124
self.ignore_index = ignore_index
125+
126+
self.alpha: float | torch.Tensor | None
127+
if alpha is None:
128+
self.alpha = None
129+
elif isinstance(alpha, (float, int)):
130+
self.alpha = float(alpha)
131+
else:
132+
self.alpha = torch.as_tensor(alpha)
91133
weight = torch.as_tensor(weight) if weight is not None else None
92134
self.register_buffer("class_weight", weight)
93135
self.class_weight: None | torch.Tensor
@@ -125,14 +167,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
125167

126168
input = input.float()
127169
target = target.float()
128-
129-
if self.use_softmax and input.shape[1] > 1:
170+
alpha_arg = self.alpha
171+
if self.use_softmax:
130172
if not self.include_background and self.alpha is not None:
131-
self.alpha = None
132-
warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
133-
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
173+
if isinstance(self.alpha, (float, int)):
174+
alpha_arg = None
175+
warnings.warn(
176+
"`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
177+
)
178+
loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
134179
else:
135-
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
180+
loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
136188

137189
if mask is not None:
138190
loss = loss * mask
@@ -167,7 +219,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
167219

168220

169221
def softmax_focal_loss(
170-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
222+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
171223
) -> torch.Tensor:
172224
"""
173225
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -179,8 +231,22 @@ def softmax_focal_loss(
179231
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
180232

181233
if alpha is not None:
182-
# (1-alpha) for the background class and alpha for the other classes
183-
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
234+
if isinstance(alpha, torch.Tensor):
235+
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
236+
else:
237+
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
238+
239+
if alpha_t.ndim == 0: # scalar
240+
alpha_val = alpha_t.item()
241+
# (1-alpha) for the background class and alpha for the other classes
242+
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
243+
else: # tensor (sequence)
244+
if alpha_t.shape[0] != target.shape[1]:
245+
raise ValueError(
246+
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
247+
)
248+
alpha_fac = alpha_t
249+
184250
broadcast_dims = [-1] + [1] * len(target.shape[2:])
185251
alpha_fac = alpha_fac.view(broadcast_dims)
186252
loss = alpha_fac * loss
@@ -189,7 +255,7 @@ def softmax_focal_loss(
189255

190256

191257
def sigmoid_focal_loss(
192-
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
258+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
193259
) -> torch.Tensor:
194260
"""
195261
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -212,8 +278,27 @@ def sigmoid_focal_loss(
212278
loss = (invprobs * gamma).exp() * loss
213279

214280
if alpha is not None:
215-
# alpha if t==1; (1-alpha) if t==0
216-
alpha_factor = target * alpha + (1 - target) * (1 - alpha)
281+
if isinstance(alpha, torch.Tensor):
282+
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
283+
else:
284+
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
285+
286+
if alpha_t.ndim == 0: # scalar
287+
# alpha if t==1; (1-alpha) if t==0
288+
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
289+
else: # tensor (sequence)
290+
if alpha_t.shape[0] != target.shape[1]:
291+
raise ValueError(
292+
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
293+
)
294+
# Reshape alpha for broadcasting: (1, C, 1, 1...)
295+
broadcast_dims = [-1] + [1] * len(target.shape[2:])
296+
alpha_t = alpha_t.view(broadcast_dims)
297+
# Apply per-class weight only to positive samples
298+
# For positive samples (target==1): multiply by alpha[c]
299+
# For negative samples (target==0): keep weight as 1.0
300+
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
301+
217302
loss = alpha_factor * loss
218303

219304
return loss

tests/losses/test_focal_loss.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121

2222
from monai.losses import FocalLoss
2323
from monai.networks import one_hot
24-
from tests.test_utils import test_script_save
24+
from tests.test_utils import TEST_DEVICES, test_script_save
2525

2626
TEST_CASES = []
27-
for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
27+
for case in TEST_DEVICES:
28+
device = case[0]
2829
input_data = {
2930
"input": torch.tensor(
3031
[[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
@@ -77,6 +78,13 @@
7778
TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
7879
TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])
7980

81+
TEST_ALPHA_BROADCASTING = []
82+
for case in TEST_DEVICES:
83+
device = case[0]
84+
for include_background in [True, False]:
85+
for use_softmax in [True, False]:
86+
TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
87+
8088

8189
class TestFocalLoss(unittest.TestCase):
8290
@parameterized.expand(TEST_CASES)
@@ -374,6 +382,40 @@ def test_script(self):
374382
test_input = torch.ones(2, 2, 8, 8)
375383
test_script_save(loss, test_input, test_input)
376384

385+
@parameterized.expand(TEST_ALPHA_BROADCASTING)
386+
def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):
387+
"""
388+
Test FocalLoss with alpha as a sequence for proper broadcasting.
389+
"""
390+
num_classes = 3
391+
batch_size = 2
392+
spatial_dims = (4, 4)
393+
394+
logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
395+
target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
396+
397+
if include_background:
398+
alpha_seq = [0.1, 0.5, 2.0]
399+
else:
400+
alpha_seq = [0.5, 2.0]
401+
402+
loss_func = FocalLoss(
403+
to_onehot_y=True,
404+
gamma=2.0,
405+
alpha=alpha_seq,
406+
include_background=include_background,
407+
use_softmax=use_softmax,
408+
reduction="mean",
409+
)
410+
411+
result = loss_func(logits, target)
412+
413+
self.assertTrue(torch.is_tensor(result))
414+
self.assertEqual(result.ndim, 0)
415+
self.assertTrue(
416+
result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}"
417+
)
418+
377419

378420
if __name__ == "__main__":
379421
unittest.main()

0 commit comments

Comments
 (0)