Skip to content

Commit ab5c676

Browse files
committed
Assign group IDs to wrapped transforms in Compose
Extend Compose group tagging to discover traceable transforms wrapped in object attributes and container fields, including common patterns like transform/transforms wrappers. Preserve nested Compose boundaries by setting group IDs only on the nested Compose instance itself, without traversing its internal pipeline. Add a regression test covering wrapped transform attributes and nested container traversal. Signed-off-by: sewon jeon <irocks0922@gmail.com>
1 parent 821bc01 commit ab5c676

File tree

2 files changed

+76
-11
lines changed

2 files changed

+76
-11
lines changed

monai/transforms/compose.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -277,39 +277,67 @@ def _set_transform_groups(self):
277277
group_id = str(id(self))
278278
visited = set() # Track visited objects to avoid infinite recursion
279279

280-
def set_group_recursive(obj, gid):
280+
def set_group_recursive(obj, gid, allow_compose: bool = False):
281281
"""
282282
Recursively set a group ID on a transform and its wrapped transforms.
283283
284284
Args:
285285
obj: Transform instance to process.
286286
gid: Group identifier to assign.
287+
allow_compose: Whether to set group on ``Compose`` instances.
288+
``Compose`` internals are not traversed to preserve nested
289+
pipeline boundaries.
287290
288291
Returns:
289292
None.
290293
"""
294+
if obj is None or isinstance(obj, (bool, int, float, str, bytes)):
295+
return
296+
291297
# Avoid infinite recursion
292298
obj_id = id(obj)
293299
if obj_id in visited:
294300
return
295301
visited.add(obj_id)
296302

303+
if isinstance(obj, Compose):
304+
if allow_compose:
305+
obj._group = gid
306+
return
307+
297308
if isinstance(obj, TraceableTransform):
298309
obj._group = gid
299310

300-
# Handle wrapped transforms in dictionary transforms
301-
# Check common attribute patterns for wrapped transforms
302-
for attr_name in dir(obj):
303-
# Skip magic methods and common non-transform attributes
304-
if attr_name.startswith("__") or attr_name in ("transforms", "transform"):
305-
continue
306-
attr = getattr(obj, attr_name, None)
307-
if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose):
308-
# Recursively set group on nested transforms
311+
if isinstance(obj, Mapping):
312+
for attr in obj.values():
313+
set_group_recursive(attr, gid)
314+
return
315+
316+
if isinstance(obj, (list, tuple, set)):
317+
for attr in obj:
309318
set_group_recursive(attr, gid)
319+
return
320+
321+
attrs: list[Any] = []
322+
if hasattr(obj, "__dict__"):
323+
attrs.extend(vars(obj).values())
324+
325+
slots = getattr(type(obj), "__slots__", ())
326+
if isinstance(slots, str):
327+
slots = (slots,)
328+
for slot in slots:
329+
if slot.startswith("__"):
330+
continue
331+
try:
332+
attrs.append(getattr(obj, slot))
333+
except AttributeError:
334+
continue
335+
336+
for attr in attrs:
337+
set_group_recursive(attr, gid)
310338

311339
for transform in self.transforms:
312-
set_group_recursive(transform, group_id)
340+
set_group_recursive(transform, group_id, allow_compose=True)
313341

314342
@LazyTransform.lazy.setter # type: ignore
315343
def lazy(self, val: bool):

tests/transforms/compose/test_compose.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,43 @@ def test_data_loader_2(self):
268268
self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
269269
set_determinism(None)
270270

271+
def test_set_transform_groups_on_wrapped_transform_attributes(self):
272+
class _IdentityInvertible(mt.InvertibleTransform):
273+
def __call__(self, data):
274+
return data
275+
276+
def inverse(self, data):
277+
return data
278+
279+
class _WrapperWithTransform:
280+
def __init__(self):
281+
self.transform = _IdentityInvertible()
282+
283+
def __call__(self, data):
284+
return self.transform(data)
285+
286+
class _WrapperWithTransforms:
287+
def __init__(self):
288+
self.transforms = [_IdentityInvertible(), {"inner": _IdentityInvertible()}]
289+
290+
def __call__(self, data):
291+
for transform in self.transforms:
292+
if isinstance(transform, dict):
293+
for nested_transform in transform.values():
294+
data = nested_transform(data)
295+
else:
296+
data = transform(data)
297+
return data
298+
299+
wrapped_transform = _WrapperWithTransform()
300+
wrapped_transforms = _WrapperWithTransforms()
301+
composed = mt.Compose([wrapped_transform, wrapped_transforms])
302+
expected_group = str(id(composed))
303+
304+
self.assertEqual(getattr(wrapped_transform.transform, "_group", None), expected_group)
305+
self.assertEqual(getattr(wrapped_transforms.transforms[0], "_group", None), expected_group)
306+
self.assertEqual(getattr(wrapped_transforms.transforms[1]["inner"], "_group", None), expected_group)
307+
271308
def test_flatten_and_len(self):
272309
x = mt.EnsureChannelFirst(channel_dim="no_channel")
273310
t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])

0 commit comments

Comments
 (0)