 from __future__ import annotations

 import unittest
+from itertools import product

 import numpy as np
 import torch
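The newly imported `itertools.product` is what lets each test below sweep every combination of the two activation flags in a single loop. A quick illustration of the values it yields (standard library behavior, nothing MONAI-specific):

```python
from itertools import product

# Enumerates all four (use_softmax, use_sigmoid) flag combinations.
print(list(product([True, False], repeat=2)))
# [(True, True), (True, False), (False, True), (False, False)]
```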
@@ -205,59 +206,71 @@ def test_consistency_with_cross_entropy_classification_01(self):
         self.assertNotAlmostEqual(max_error, 0.0, places=3)

     def test_bin_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0

             # initialize the focal loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_empty_class_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0

             # initialize the focal loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_multi_class_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 6  # labels 0 to 5
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
             # initialize the focal loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
@@ -270,15 +283,15 @@ def test_multi_class_seg_2d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_bin_seg_3d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2  # labels 0, 1
             # define 3d examples
             target = torch.tensor(
@@ -294,11 +307,15 @@ def test_bin_seg_3d(self):
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W, D)
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = target_one_hot.clone().float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0

             # initialize the focal loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W, D)
@@ -309,10 +326,10 @@ def test_bin_seg_3d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

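Every hunk above follows the same pattern: when both activation flags are off, the test feeds the loss a probability map (the one-hot tensor of 0s and 1s) instead of large-magnitude logits, and still expects a near-zero result. The dispatch the tests appear to rely on can be sketched as below. This is a minimal illustration, not MONAI's actual `FocalLoss` implementation; the names `focal_probs` and `focal_loss_sketch` and the `gamma`/`eps` defaults are assumptions made for the sketch.

```python
import torch

def focal_probs(pred: torch.Tensor, use_softmax: bool, use_sigmoid: bool) -> torch.Tensor:
    # Assumed dispatch: softmax logits, sigmoid logits, or raw probabilities.
    if use_softmax:
        return pred.softmax(dim=1)  # channel dim holds the classes
    if use_sigmoid:
        return pred.sigmoid()
    return pred  # neither flag set: input is treated as probabilities

def focal_loss_sketch(pred, target_one_hot, gamma=2.0, use_softmax=False, use_sigmoid=False, eps=1e-7):
    p = focal_probs(pred, use_softmax, use_sigmoid)
    # p_t is the probability assigned to the true class at each location.
    p_t = torch.where(target_one_hot.bool(), p, 1.0 - p).clamp_min(eps)
    return (-((1.0 - p_t) ** gamma) * p_t.log()).mean()
```

With a perfect probability map, p_t is 1 everywhere, so both the modulating factor (1 - p_t)^gamma and log(p_t) vanish, which is why `assertAlmostEqual(focal_loss_good, 0.0, places=3)` is expected to pass in all four flag combinations.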
@@ -369,8 +386,8 @@ def test_warnings(self):
             loss(chn_input, chn_target)

     def test_script(self):
-        for use_softmax in [True, False]:
-            loss = FocalLoss(use_softmax=use_softmax)
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
+            loss = FocalLoss(use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             test_input = torch.ones(2, 2, 8, 8)
             test_script_save(loss, test_input, test_input)

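For reference, a standalone call mirroring what the new tests exercise. The `use_sigmoid` constructor argument is inferred from the test code above (it is the flag this diff introduces), so treat the exact signature as an assumption:

```python
import torch
import torch.nn.functional as F
from monai.losses import FocalLoss

target = torch.tensor([[0, 0], [0, 1]]).unsqueeze(0)  # shape (1, H, W)
# Probabilities, not logits, since both activation flags are off.
pred = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
loss = FocalLoss(to_onehot_y=True, use_softmax=False, use_sigmoid=False)
print(float(loss(pred, target.unsqueeze(1))))  # expected to be close to 0.0
```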