@@ -69,25 +69,67 @@ def __init__(
6969 include_background : bool = True ,
7070 to_onehot_y : bool = False ,
7171 gamma : float = 2.0 ,
72- alpha : float | None = None ,
72+ alpha : float | Sequence [ float ] | None = None ,
7373 weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
7474 reduction : LossReduction | str = LossReduction .MEAN ,
7575 use_softmax : bool = False ,
7676 ignore_index : int | None = None ,
7777 ) -> None :
7878 """
7979 Args:
80+                include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
81+                    If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
82+                to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
83+                gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
84+                alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
85+                    The value should be in [0, 1].
86+                    If a sequence is provided, its length must match the number of classes
87+                    (excluding the background class if `include_background=False`).
88+                    Defaults to None.
89+                weight: weights to apply to the voxels of each class. If None no weights are applied.
90+                    The input can be a single value (same weight for all classes), a sequence of values (the length
91+                    of the sequence should be the same as the number of classes. If not ``include_background``,
92+                    the number of classes should not include the background category class 0).
93+                    The value/values should be no less than 0. Defaults to None.
94+                reduction: {``"none"``, ``"mean"``, ``"sum"``}
95+                    Specifies the reduction to apply to the output. Defaults to ``"mean"``.
96+
97+                    - ``"none"``: no reduction will be applied.
98+                    - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
99+                    - ``"sum"``: the output will be summed.
100+
101+                use_softmax: whether to use softmax to transform the original logits into probabilities.
102+                    If True, softmax is used. If False, sigmoid is used. Defaults to False.
103+                ignore_index: index of the class to ignore during calculation. Defaults to None.
104+
105+            Example:
106+                >>> import torch
107+                >>> from monai.losses import FocalLoss
108+                >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
109+                >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
110+                >>> fl = FocalLoss(to_onehot_y=True)
111+                >>> fl(pred, grnd)
82116 """
83117 super ().__init__ (reduction = LossReduction (reduction ).value )
84118 self .include_background = include_background
85119 self .to_onehot_y = to_onehot_y
86120 self .gamma = gamma
87- self .alpha = alpha
88121 self .weight = weight
89122        self.use_softmax = use_softmax
90124 self .ignore_index = ignore_index
125+
126+ self .alpha : float | torch .Tensor | None
127+ if alpha is None :
128+ self .alpha = None
129+ elif isinstance (alpha , (float , int )):
130+ self .alpha = float (alpha )
131+ else :
132+ self .alpha = torch .as_tensor (alpha )
91133 weight = torch .as_tensor (weight ) if weight is not None else None
92134 self .register_buffer ("class_weight" , weight )
93135 self .class_weight : None | torch .Tensor
@@ -125,14 +167,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
125167
126168 input = input .float ()
127169 target = target .float ()
128-
129- if self .use_softmax and input . shape [ 1 ] > 1 :
170+ alpha_arg = self . alpha
171+ if self .use_softmax :
130172 if not self .include_background and self .alpha is not None :
131- self .alpha = None
132- warnings .warn ("`include_background=False`, `alpha` ignored when using softmax." )
133- loss = softmax_focal_loss (input , target , self .gamma , self .alpha )
173+ if isinstance (self .alpha , (float , int )):
174+ alpha_arg = None
175+ warnings .warn (
176+ "`include_background=False`, scalar `alpha` ignored when using softmax." , stacklevel = 2
177+ )
178+ loss = softmax_focal_loss (input , target , self .gamma , alpha_arg )
134179        else:
135-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
180+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
136188
137189 if mask is not None :
138190 loss = loss * mask
@@ -167,7 +219,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
167219
168220
169221def softmax_focal_loss (
170- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | None = None
222+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
171223) -> torch .Tensor :
172224 """
173225 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -179,8 +231,22 @@ def softmax_focal_loss(
179231 loss : torch .Tensor = - (1 - input_ls .exp ()).pow (gamma ) * input_ls * target
180232
181233 if alpha is not None :
182- # (1-alpha) for the background class and alpha for the other classes
183- alpha_fac = torch .tensor ([1 - alpha ] + [alpha ] * (target .shape [1 ] - 1 )).to (loss )
234+ if isinstance (alpha , torch .Tensor ):
235+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
236+ else :
237+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
238+
239+ if alpha_t .ndim == 0 : # scalar
240+ alpha_val = alpha_t .item ()
241+ # (1-alpha) for the background class and alpha for the other classes
242+ alpha_fac = torch .tensor ([1 - alpha_val ] + [alpha_val ] * (target .shape [1 ] - 1 )).to (loss )
243+ else : # tensor (sequence)
244+ if alpha_t .shape [0 ] != target .shape [1 ]:
245+ raise ValueError (
246+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
247+ )
248+ alpha_fac = alpha_t
249+
184250 broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
185251 alpha_fac = alpha_fac .view (broadcast_dims )
186252 loss = alpha_fac * loss
@@ -189,7 +255,7 @@ def softmax_focal_loss(
189255
190256
191257def sigmoid_focal_loss (
192- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | None = None
258+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
193259) -> torch .Tensor :
194260 """
195261 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -212,8 +278,27 @@ def sigmoid_focal_loss(
212278 loss = (invprobs * gamma ).exp () * loss
213279
214280 if alpha is not None :
215- # alpha if t==1; (1-alpha) if t==0
216- alpha_factor = target * alpha + (1 - target ) * (1 - alpha )
281+ if isinstance (alpha , torch .Tensor ):
282+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
283+ else :
284+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
285+
286+ if alpha_t .ndim == 0 : # scalar
287+ # alpha if t==1; (1-alpha) if t==0
288+ alpha_factor = target * alpha_t + (1 - target ) * (1 - alpha_t )
289+ else : # tensor (sequence)
290+ if alpha_t .shape [0 ] != target .shape [1 ]:
291+ raise ValueError (
292+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
293+ )
294+ # Reshape alpha for broadcasting: (1, C, 1, 1...)
295+ broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
296+ alpha_t = alpha_t .view (broadcast_dims )
297+ # Apply per-class weight only to positive samples
298+ # For positive samples (target==1): multiply by alpha[c]
299+ # For negative samples (target==0): keep weight as 1.0
300+ alpha_factor = torch .where (target == 1 , alpha_t , torch .ones_like (alpha_t ))
301+
217302 loss = alpha_factor * loss
218303
219304 return loss
0 commit comments