Skip to content

Commit f2caaf8

Browse files
committed
feat: implement ignore_index support for losses and metrics
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent a1b0a4f commit f2caaf8

File tree

5 files changed

+43
-11
lines changed

5 files changed

+43
-11
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
262262

263263
# Check if shapes match
264264
if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
265-
y_true = torch.cat([1 - y_true, y_true], dim=1)
265+
if self.ignore_index is not None:
266+
# Create mask for valid pixels
267+
mask = (y_true != self.ignore_index).float()
268+
# Set ignore_index values to 0 before conversion
269+
y_true_clean = y_true * mask
270+
# Convert to 2-channel
271+
y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1)
272+
# Apply mask to both channels so ignored pixels are all zeros
273+
y_true = y_true * mask
274+
else:
275+
y_true = torch.cat([1 - y_true, y_true], dim=1)
276+
266277
if y_true.shape != y_pred.shape:
267278
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
268-
269-
if torch.max(y_true) != self.num_classes - 1:
270-
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
279+
if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1:
280+
raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}")
271281

272282
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
273283
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

monai/metrics/generalized_dice.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ def compute_generalized_dice(
154154
if y.shape != y_pred.shape:
155155
raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.")
156156

157+
# Apply ignore_index masking
158+
if ignore_index is not None:
159+
mask = (y != ignore_index).all(dim=1, keepdim=True).float()
160+
y_pred = y_pred * mask
161+
y = y * mask
162+
157163
n_channels = y_pred.shape[1]
158164
channels_to_use = list(range(n_channels))
159165

@@ -188,17 +194,17 @@ def compute_generalized_dice(
188194
else:
189195
w_full = torch.ones_like(y_o_float)
190196

191-
w = w_full[:, channels_to_use]
192-
193197
# Replace infinite values for non-appearing classes by the maximum weight
194-
for b_idx in range(w.shape[0]):
195-
batch_w = w[b_idx]
198+
for b_idx in range(w_full.shape[0]):
199+
batch_w = w_full[b_idx]
196200
infs = torch.isinf(batch_w)
197201
if infs.any():
198202
batch_w[infs] = 0
199203
max_w = torch.max(batch_w)
200204
batch_w[infs] = max_w if max_w > 0 else 1.0
201205

206+
w = w_full[:, channels_to_use]
207+
202208
if sum_over_classes:
203209
intersection = (intersection * w).sum(dim=1, keepdim=True)
204210
denominator = (denominator * w).sum(dim=1, keepdim=True)

monai/metrics/surface_dice.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,16 @@ def compute_surface_dice(
277277
distances_gt_pred <= class_thresholds[c]
278278
)
279279
else:
280-
areas_pred, areas_gt = areas # type: ignore
280+
# Handle areas being returned as a single item or a tuple
281+
if isinstance(areas, (list, tuple)):
282+
if len(areas) == 2:
283+
areas_pred, areas_gt = areas
284+
elif len(areas) == 1:
285+
areas_pred = areas_gt = areas[0]
286+
else:
287+
areas_pred = areas_gt = torch.tensor([], device=y_pred.device)
288+
else:
289+
areas_pred = areas_gt = areas
281290
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
282291
boundary_complete = areas_gt.sum() + areas_pred.sum()
283292
gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0

monai/metrics/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,15 @@ def get_edge_surface_distance(
374374

375375
distances = tuple(d if d is not None else edges_pred.new_empty((0,)) for d in distances)
376376

377-
areas = edge_results[3:] if use_subvoxels else ()
377+
areas = edge_results[2:] if use_subvoxels else ()
378+
379+
# Ensure areas is always a tuple of 2 when use_subvoxels=True
380+
if use_subvoxels and isinstance(areas, (list, tuple)):
381+
if len(areas) == 1:
382+
areas = (areas[0], areas[0])
383+
elif len(areas) != 2:
384+
# Unexpected length, create empty tensors
385+
areas = (torch.tensor([], device=y_pred.device), torch.tensor([], device=y_pred.device))
378386

379387
out = convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return]
380388

tests/metrics/test_ignore_index_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747

4848
@unittest.skipUnless(has_scipy, "Scipy required for surface metrics")
4949
class TestIgnoreIndexMetrics(unittest.TestCase):
50-
5150
@parameterized.expand(TEST_METRICS + SCIPY_METRICS)
5251
def test_metric_ignore_consistency(self, metric_class, kwargs):
5352
# Initialize metric with ignore_index

0 commit comments

Comments
 (0)