Skip to content

Commit b547410

Browse files
committed
fix bug: scalar `alpha` was still passed to the loss function even though a warning said it would be ignored
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 9b5d011 commit b547410

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

monai/losses/focal_loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
164164
loss: Optional[torch.Tensor] = None
165165
input = input.float()
166166
target = target.float()
167-
167+
alpha_arg = self.alpha
168168
if self.use_softmax:
169169
if not self.include_background and self.alpha is not None:
170170
if isinstance(self.alpha, (float, int)):
171+
alpha_arg = None
171172
warnings.warn(
172173
"`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
173174
)
174-
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
175+
loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
175176
else:
176-
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
177+
loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
177178

178179
num_of_classes = target.shape[1]
179180
if self.class_weight is not None and num_of_classes != 1:

tests/losses/test_focal_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@
7979

8080
TEST_ALPHA_BROADCASTING = []
8181
for case in TEST_DEVICES:
82-
device = case[0]
82+
dev = case[0]
8383
for include_background in [True, False]:
8484
for use_softmax in [True, False]:
85-
TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
85+
TEST_ALPHA_BROADCASTING.append([dev, include_background, use_softmax])
8686

8787

8888
class TestFocalLoss(unittest.TestCase):

0 commit comments

Comments (0)