Skip to content

Commit 9bde913

Browse files
committed
fix activation order and inheritance
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent adff162 commit 9bde913

File tree

1 file changed

+46
-20
lines changed

1 file changed

+46
-20
lines changed

monai/losses/dice.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -239,27 +239,53 @@ class MaskedDiceLoss(DiceLoss):
239239
240240
"""
241241

242-
def __init__(self, *args: Any, **kwargs: Any) -> None:
242+
def __init__(
243+
self,
244+
include_background: bool = True,
245+
to_onehot_y: bool = False,
246+
sigmoid: bool = False,
247+
softmax: bool = False,
248+
other_act: Callable | None = None,
249+
squared_pred: bool = False,
250+
jaccard: bool = False,
251+
reduction: LossReduction | str = LossReduction.MEAN,
252+
smooth_nr: float = 1e-5,
253+
smooth_dr: float = 1e-5,
254+
batch: bool = False,
255+
weight: Sequence[float] | float | int | torch.Tensor | None = None,
256+
soft_label: bool = False,
257+
) -> None:
243258
"""
244259
Args follow :py:class:`monai.losses.DiceLoss`.
245260
"""
246-
super().__init__(*args, **kwargs)
247-
self.dice = DiceLoss(
248-
include_background=self.include_background,
249-
to_onehot_y=self.to_onehot_y,
261+
if other_act is not None and not callable(other_act):
262+
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
263+
if sigmoid and softmax:
264+
raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
265+
if other_act is not None and (sigmoid or softmax):
266+
raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.")
267+
268+
self.pre_sigmoid = sigmoid
269+
self.pre_softmax = softmax
270+
self.pre_other_act = other_act
271+
272+
super().__init__(
273+
include_background=include_background,
274+
to_onehot_y=to_onehot_y,
250275
sigmoid=False,
251276
softmax=False,
252277
other_act=None,
253-
squared_pred=self.squared_pred,
254-
jaccard=self.jaccard,
255-
reduction=self.reduction,
256-
smooth_nr=self.smooth_nr,
257-
smooth_dr=self.smooth_dr,
258-
batch=self.batch,
259-
weight=self.class_weight,
260-
soft_label=self.soft_label,
278+
squared_pred=squared_pred,
279+
jaccard=jaccard,
280+
reduction=reduction,
281+
smooth_nr=smooth_nr,
282+
smooth_dr=smooth_dr,
283+
batch=batch,
284+
weight=weight,
285+
soft_label=soft_label,
261286
)
262-
self.spatial_weighted = MaskedLoss(loss=self.dice.forward)
287+
288+
self.spatial_weighted = MaskedLoss(loss=super().forward)
263289

264290
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
265291
"""
@@ -269,19 +295,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
269295
mask: the shape should B1H[WD] or 11H[WD].
270296
"""
271297

272-
if self.sigmoid:
298+
if self.pre_sigmoid:
273299
input = torch.sigmoid(input)
274300

275301
n_pred_ch = input.shape[1]
276-
if self.softmax:
302+
if self.pre_softmax:
277303
if n_pred_ch == 1:
278-
warnings.warn("single channel prediction, `softmax=True` ignored.")
304+
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
279305
else:
280306
input = torch.softmax(input, 1)
281307

282-
if self.other_act is not None:
283-
input = self.other_act(input)
284-
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
308+
if self.pre_other_act is not None:
309+
input = self.pre_other_act(input)
310+
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
285311

286312

287313
class GeneralizedDiceLoss(_Loss):

0 commit comments

Comments
 (0)