|
20 | 20 | from monai.utils import LossReduction |
21 | 21 |
|
22 | 22 |
|
class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Asymmetric Focal Loss
    and Asymmetric Focal Tversky Loss to handle imbalanced medical image segmentation.

    It supports multi-class segmentation by treating channel 0 as background and
    channels 1..N as foreground, applying asymmetric weighting controlled by `delta`.

    Reimplementation of the Asymmetric Unified Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
      Michael Yeung, Computerized Medical Imaging and Graphics

    Example:
        >>> import torch
        >>> from monai.losses import AsymmetricUnifiedFocalLoss
        >>> # B, C, H, W = 1, 3, 32, 32
        >>> pred_logits = torch.randn(1, 3, 32, 32)
        >>> # Ground truth indices (B, 1, H, W)
        >>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
        >>> # Use softmax=True if input is logits
        >>> loss_func = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
        >>> loss = loss_func(pred_logits, grnd)
    """

    def __init__(
        self,
        weight: float = 0.5,
        delta: float = 0.6,
        gamma: float = 0.5,
        include_background: bool = True,
        to_onehot_y: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        use_softmax: bool = False,
        epsilon: float = 1e-7,
    ) -> None:
        """
        Args:
            weight: The weighting factor between Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
                Final Loss = weight * AFL + (1 - weight) * AFTL. Defaults to 0.5.
            delta: The balancing factor controls the weight of background vs foreground classes.
                Values > 0.5 give more weight to foreground (False Negatives). Defaults to 0.6.
            gamma: The focal exponent. Higher values focus more on hard examples.
                Expected to be in [0, 1) for the Tversky term to stay finite. Defaults to 0.5.
            include_background: If False, channel index 0 (background category) is excluded from the loss calculation.
                Defaults to True.
            to_onehot_y: Whether to convert the label `target` into the one-hot format. Defaults to False.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
            use_softmax: Whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, assumes input is already probabilities. Defaults to False.
            epsilon: Small value to prevent division by zero or log(0). Defaults to 1e-7.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.weight = weight
        self.delta = delta
        self.gamma = gamma
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.use_softmax = use_softmax
        self.epsilon = epsilon

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD].

        Raises:
            ValueError: When input and target have incompatible shapes.
        """
        if self.use_softmax:
            input = torch.nn.functional.softmax(input, dim=1)

        n_pred_ch = input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                if target.shape[1] == 1:
                    target = one_hot(target, num_classes=n_pred_ch)

        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # Clip values for numerical stability (avoids log(0) below)
        input = torch.clamp(input, self.epsilon, 1.0 - self.epsilon)

        # Part A: Asymmetric Focal Loss
        # Cross Entropy: -target * log(input)
        cross_entropy = -target * torch.log(input)

        # Background (Channel 0): (1 - delta) * (1 - p)^gamma * CE — focal term suppresses easy background
        back_ce = (1 - self.delta) * torch.pow(1 - input[:, 0:1], self.gamma) * cross_entropy[:, 0:1]

        # Foreground (Channel 1..N): delta * CE — no focal suppression, per the reference paper
        fore_ce = self.delta * cross_entropy[:, 1:]

        # Combine
        if self.include_background:
            asy_focal_loss = torch.cat([back_ce, fore_ce], dim=1)
        else:
            asy_focal_loss = fore_ce

        # Part B: Asymmetric Focal Tversky Loss
        # Sum over spatial dimensions (Batch and Channel dims are preserved)
        reduce_axis = list(range(2, input.dim()))

        tp = torch.sum(target * input, dim=reduce_axis)
        fn = torch.sum(target * (1 - input), dim=reduce_axis)
        fp = torch.sum((1 - target) * input, dim=reduce_axis)

        # Tversky Index
        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

        # Background: 1 - Dice
        back_dice_loss = 1 - dice_class[:, 0:1]

        # Foreground: (1 - Dice)^(1 - gamma), computed with a single pow.
        # NOTE: the factored form `(1 - d) * pow(1 - d, -gamma)` evaluates to 0 * inf = NaN
        # when a class is predicted perfectly (d == 1, where the epsilon terms cancel);
        # pow(0, 1 - gamma) correctly yields 0 for gamma < 1.
        fore_dice_loss = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

        # Combine
        if self.include_background:
            asy_focal_tversky_loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)
        else:
            asy_focal_tversky_loss = fore_dice_loss

        # Part C: Unified Combination & Reduction
        # Aggregate Focal Loss spatial dimensions to match Tversky Loss shape (B, C)
        if asy_focal_loss.dim() > 2:
            asy_focal_loss = torch.mean(asy_focal_loss, dim=reduce_axis)

        # Weighted sum of the two components
        total_loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(total_loss)
        if self.reduction == LossReduction.NONE.value:
            return total_loss
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(total_loss)

        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
0 commit comments