@@ -239,27 +239,53 @@ class MaskedDiceLoss(DiceLoss):
239239
240240 """
241241
242- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
242+ def __init__ (
243+ self ,
244+ include_background : bool = True ,
245+ to_onehot_y : bool = False ,
246+ sigmoid : bool = False ,
247+ softmax : bool = False ,
248+ other_act : Callable | None = None ,
249+ squared_pred : bool = False ,
250+ jaccard : bool = False ,
251+ reduction : LossReduction | str = LossReduction .MEAN ,
252+ smooth_nr : float = 1e-5 ,
253+ smooth_dr : float = 1e-5 ,
254+ batch : bool = False ,
255+ weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
256+ soft_label : bool = False ,
257+ ) -> None :
243258 """
244259 Args follow :py:class:`monai.losses.DiceLoss`.
245260 """
246- super ().__init__ (* args , ** kwargs )
247- self .dice = DiceLoss (
248- include_background = self .include_background ,
249- to_onehot_y = self .to_onehot_y ,
261+ if other_act is not None and not callable (other_act ):
262+ raise TypeError (f"other_act must be None or callable but is { type (other_act ).__name__ } ." )
263+ if sigmoid and softmax :
264+ raise ValueError ("Incompatible values: sigmoid=True and softmax=True." )
265+ if other_act is not None and (sigmoid or softmax ):
266+ raise ValueError ("Incompatible values: other_act is not None and sigmoid=True or softmax=True." )
267+
268+ self .pre_sigmoid = sigmoid
269+ self .pre_softmax = softmax
270+ self .pre_other_act = other_act
271+
272+ super ().__init__ (
273+ include_background = include_background ,
274+ to_onehot_y = to_onehot_y ,
250275 sigmoid = False ,
251276 softmax = False ,
252277 other_act = None ,
253- squared_pred = self . squared_pred ,
254- jaccard = self . jaccard ,
255- reduction = self . reduction ,
256- smooth_nr = self . smooth_nr ,
257- smooth_dr = self . smooth_dr ,
258- batch = self . batch ,
259- weight = self . class_weight ,
260- soft_label = self . soft_label ,
278+ squared_pred = squared_pred ,
279+ jaccard = jaccard ,
280+ reduction = reduction ,
281+ smooth_nr = smooth_nr ,
282+ smooth_dr = smooth_dr ,
283+ batch = batch ,
284+ weight = weight ,
285+ soft_label = soft_label ,
261286 )
262- self .spatial_weighted = MaskedLoss (loss = self .dice .forward )
287+
288+ self .spatial_weighted = MaskedLoss (loss = super ().forward )
263289
264290 def forward (self , input : torch .Tensor , target : torch .Tensor , mask : torch .Tensor | None = None ) -> torch .Tensor :
265291 """
@@ -269,19 +295,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
269295 mask: the shape should B1H[WD] or 11H[WD].
270296 """
271297
272- if self .sigmoid :
298+ if self .pre_sigmoid :
273299 input = torch .sigmoid (input )
274300
275301 n_pred_ch = input .shape [1 ]
276- if self .softmax :
302+ if self .pre_softmax :
277303 if n_pred_ch == 1 :
278- warnings .warn ("single channel prediction, `softmax=True` ignored." )
304+ warnings .warn ("single channel prediction, `softmax=True` ignored." , stacklevel = 2 )
279305 else :
280306 input = torch .softmax (input , 1 )
281307
282- if self .other_act is not None :
283- input = self .other_act (input )
284- return self .spatial_weighted (input = input , target = target , mask = mask ) # type: ignore[no-any-return]
308+ if self .pre_other_act is not None :
309+ input = self .pre_other_act (input )
310+ return self .spatial_weighted (input = input , target = target , mask = mask ) # type: ignore[no-any-return]
285311
286312
287313class GeneralizedDiceLoss (_Loss ):
0 commit comments