Skip to content

Commit 8e34cc0

Browse files
pre-commit-ci[bot] and sewon.jeon
authored and committed
Fix Invertd confusion with postprocessing transforms
- Modified Invertd to only invert preprocessing transforms
- When orig_key == key, limits transform_info to preprocessing transforms only
- Counts invertible transforms in preprocessing and uses only that many from applied_operations
- Fixes issue #8396 where Lambdad before Invertd caused errors
- Added test case to verify the fix

The issue occurred because Invertd would try to invert all transforms in applied_operations, including those from postprocessing. When Lambdad was applied in postprocessing before Invertd, it would be at the top of the stack, causing ID mismatch errors when preprocessing transforms tried to pop themselves.

DCO Remediation Commit for sewon.jeon <sewon.jeon@connecteve.com>

I, sewon.jeon <sewon.jeon@connecteve.com>, hereby add my Signed-off-by to this commit: 342ff38

Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent 342ff38 commit 8e34cc0

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

monai/transforms/post/dictionary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
689689
if orig_key in d and isinstance(d[orig_key], MetaTensor):
690690
all_transforms = d[orig_key].applied_operations
691691
meta_info = d[orig_key].meta
692-
692+
693693
# If orig_key == key, the data at d[orig_key] may have been modified by
694694
# postprocessing transforms. We need to exclude any transforms that were
695695
# added after the preprocessing pipeline completed.
@@ -704,7 +704,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
704704
num_preproc_transforms += 1
705705
elif isinstance(self.transform, InvertibleTransform):
706706
num_preproc_transforms = 1
707-
707+
708708
# Use only the first N transforms from applied_operations,
709709
# where N is the number of preprocessing transforms
710710
# This excludes any postprocessing transforms like Lambdad

tests/transforms/inverse/test_invertd.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,10 @@ def test_invert(self):
140140
def test_invert_with_postproc_lambdad(self):
141141
"""Test that Invertd works correctly when postprocessing contains invertible transforms like Lambdad."""
142142
set_determinism(seed=0)
143-
143+
144144
# Create test images
145145
im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
146-
146+
147147
# Define preprocessing transforms
148148
preproc = Compose([
149149
LoadImaged(KEYS, image_only=True),
@@ -152,7 +152,7 @@ def test_invert_with_postproc_lambdad(self):
152152
ScaleIntensityd("image", minv=1, maxv=10),
153153
ResizeWithPadOrCropd(KEYS, 100),
154154
])
155-
155+
156156
# Define postprocessing with Lambdad before Invertd (the problematic case)
157157
from monai.transforms import Lambdad
158158
postproc = Compose([
@@ -161,14 +161,14 @@ def test_invert_with_postproc_lambdad(self):
161161
# Invertd should only invert the preprocessing transforms
162162
Invertd(["pred"], preproc, orig_keys=["image"], nearest_interp=True),
163163
])
164-
164+
165165
# Apply preprocessing
166166
data = {"image": im_fname, "label": seg_fname}
167167
preprocessed = preproc(data)
168-
168+
169169
# Create prediction (copy from preprocessed image)
170170
preprocessed["pred"] = preprocessed["image"].clone()
171-
171+
172172
# Apply postprocessing with Lambdad before Invertd
173173
# This should work without errors - the main issue was that it would fail
174174
result = postproc(preprocessed)
@@ -177,7 +177,7 @@ def test_invert_with_postproc_lambdad(self):
177177
# Check that the shape was correctly inverted
178178
self.assertTupleEqual(result["pred"].shape[1:], (101, 100, 107))
179179
# The fact that we got here without an exception means the fix is working
180-
180+
181181
set_determinism(seed=None)
182182

183183

0 commit comments

Comments (0)