Skip to content

Commit 7d8e0b7

Browse files
author
sewon.jeon
committed
Fix Invertd confusion with postprocessing transforms
- Modified Invertd to only invert preprocessing transforms - When orig_key == key, limits transform_info to preprocessing transforms only - Counts invertible transforms in preprocessing and uses only that many from applied_operations - Fixes issue #8396 where Lambdad before Invertd caused errors - Added test case to verify the fix The issue occurred because Invertd would try to invert all transforms in applied_operations, including those from postprocessing. When Lambdad was applied in postprocessing before Invertd, it would be at the top of the stack, causing ID mismatch errors when preprocessing transforms tried to pop themselves. Fix Invertd confusion with postprocessing transforms - Modified Invertd to only invert preprocessing transforms - When orig_key == key, limits transform_info to preprocessing transforms only - Counts invertible transforms in preprocessing and uses only that many from applied_operations - Fixes issue #8396 where Lambdad before Invertd caused errors - Added test case to verify the fix The issue occurred because Invertd would try to invert all transforms in applied_operations, including those from postprocessing. When Lambdad was applied in postprocessing before Invertd, it would be at the top of the stack, causing ID mismatch errors when preprocessing transforms tried to pop themselves. Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent b92b2ce commit 7d8e0b7

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

monai/transforms/post/dictionary.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,33 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
687687

688688
orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
689689
if orig_key in d and isinstance(d[orig_key], MetaTensor):
690-
transform_info = d[orig_key].applied_operations
690+
all_transforms = d[orig_key].applied_operations
691691
meta_info = d[orig_key].meta
692-
else:
692+
693+
# If orig_key == key, the data at d[orig_key] may have been modified by
694+
# postprocessing transforms. We need to exclude any transforms that were
695+
# added after the preprocessing pipeline completed.
696+
# When orig_key == key, filter out postprocessing transforms to prevent
697+
# confusion during inversion (see issue #8396)
698+
if orig_key == key:
699+
num_preproc_transforms = 0
700+
try:
701+
if hasattr(self.transform, 'transforms'):
702+
for t in self.transform.flatten().transforms:
703+
if isinstance(t, InvertibleTransform):
704+
num_preproc_transforms += 1
705+
elif isinstance(self.transform, InvertibleTransform):
706+
num_preproc_transforms = 1
707+
except AttributeError:
708+
# Fallback: use all transforms if flatten fails
709+
num_preproc_transforms = len(all_transforms)
710+
711+
if num_preproc_transforms > 0:
712+
transform_info = all_transforms[:num_preproc_transforms]
713+
else:
714+
transform_info = all_transforms
715+
else:
716+
transform_info = all_transforms
693717
transform_info = d[InvertibleTransform.trace_key(orig_key)]
694718
meta_info = d.get(orig_meta_key, {})
695719
if nearest_interp:

tests/transforms/inverse/test_invertd.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,49 @@ def test_invert(self):
137137

138138
set_determinism(seed=None)
139139

140+
def test_invert_with_postproc_lambdad(self):
141+
"""Test that Invertd works correctly when postprocessing contains invertible transforms like Lambdad."""
142+
set_determinism(seed=0)
143+
144+
# Create test images
145+
im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
146+
147+
# Define preprocessing transforms
148+
preproc = Compose([
149+
LoadImaged(KEYS, image_only=True),
150+
EnsureChannelFirstd(KEYS),
151+
Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
152+
ScaleIntensityd("image", minv=1, maxv=10),
153+
ResizeWithPadOrCropd(KEYS, 100),
154+
])
155+
156+
# Define postprocessing with Lambdad before Invertd (the problematic case)
157+
from monai.transforms import Lambdad
158+
postproc = Compose([
159+
# This Lambdad should not interfere with Invertd
160+
Lambdad(["pred"], lambda x: x), # Identity transform
161+
# Invertd should only invert the preprocessing transforms
162+
Invertd(["pred"], preproc, orig_keys=["image"], nearest_interp=True),
163+
])
164+
165+
# Apply preprocessing
166+
data = {"image": im_fname, "label": seg_fname}
167+
preprocessed = preproc(data)
168+
169+
# Create prediction (copy from preprocessed image)
170+
preprocessed["pred"] = preprocessed["image"].clone()
171+
172+
# Apply postprocessing with Lambdad before Invertd
173+
# This should work without errors - the main issue was that it would fail
174+
result = postproc(preprocessed)
175+
# Check that the inversion was successful
176+
self.assertIn("pred", result)
177+
# Check that the shape was correctly inverted
178+
self.assertTupleEqual(result["pred"].shape[1:], (101, 100, 107))
179+
# The fact that we got here without an exception means the fix is working
180+
181+
set_determinism(seed=None)
182+
140183

141184
if __name__ == "__main__":
142185
unittest.main()

0 commit comments

Comments
 (0)