From 9a0afe91ec528ee9e3d20a9f2f5c563b08e7f648 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 21:51:22 +0100 Subject: [PATCH 1/6] fix(transforms): make Crop.compute_slices torch.compile-friendly CenterSpatialCrop (and other Crop subclasses) failed under torch.compile because compute_slices round-tripped integer ROI specs through CPU tensors via convert_to_tensor(..., device="cpu"). Under tracing the shape-derived values are fake tensors whose storage is not allocated, so torch.as_tensor raised "data is not allocated yet". The ROI arithmetic is plain integer math, so compute it directly in Python. This is traceable and drops two CPU allocations per call. Behaviour is unchanged for int, sequence, tensor and ndarray ROI inputs. Fixes #8191 Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/croppad/array.py | 35 +++++++++++--------- tests/transforms/test_center_spatial_crop.py | 14 ++++++++ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b23fbac7d9..7c66baea8a 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -342,6 +342,15 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in return spatial_pad.compute_pad_width(spatial_shape) +def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: + """Coerce an ROI spec (scalar, sequence, tensor or ndarray) to a list of Python ints.""" + if isinstance(data, (torch.Tensor, np.ndarray)): + data = data.tolist() + if isinstance(data, Sequence): + return [int(i) for i in data] + return [int(data)] + + class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. @@ -379,31 +388,25 @@ def compute_slices( roi_slices: list of slices for each of the spatial dimensions. """ - roi_start_t: torch.Tensor - if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.") return ensure_tuple(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu") - roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device="cpu") - _zeros = torch.zeros_like(roi_center_t) - half = torch.divide(roi_size_t, 2, rounding_mode="floor") - roi_start_t = torch.maximum(roi_center_t - half, _zeros) - roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t) + centers = _to_int_list(roi_center) + sizes = _to_int_list(roi_size) + n = max(len(centers), len(sizes)) + centers = centers * n if len(centers) == 1 else centers + sizes = sizes * n if len(sizes) == 1 else sizes + starts = [max(c - s // 2, 0) for c, s in zip(centers, sizes)] + ends = [max(st + s, st) for st, s in zip(starts, sizes)] else: if roi_start is None or roi_end is None: raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True) - roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t)) - roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True) - roi_end_t = torch.maximum(roi_end_t, roi_start_t) - # convert to slices (accounting for 1d) - if roi_start_t.numel() == 1: - return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) - return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]) + starts = [max(s, 0) for s in _to_int_list(roi_start)] + ends = [max(e, st) for e, st in zip(_to_int_list(roi_end), starts)] + return ensure_tuple([slice(s, e) for s, e in zip(starts, ends)]) def __call__( # type: ignore[override] self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 9120f30163..4d7deb26b1 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -14,10 +14,13 @@ import unittest import numpy as np +import torch from parameterized import parameterized +from monai.data.meta_obj import set_track_meta from monai.transforms import CenterSpatialCrop from tests.croppers import CropTest +from tests.test_utils import SkipIfBeforePyTorchVersion TEST_SHAPES = [ [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3), True], @@ -50,6 +53,17 @@ def test_value(self, input_param, input_arr, expected_arr): def test_pending_ops(self, input_param, input_shape, _, align_corners): self.crop_test_pending_ops(input_param, input_shape, align_corners) + @SkipIfBeforePyTorchVersion((2, 1)) + def test_torch_compile(self): + set_track_meta(False) + try: + cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16))) + img = torch.rand(1, 1, 32, 32, dtype=torch.float32) + self.assertEqual(tuple(cropper(img).shape), (1, 1, 16, 16)) + finally: + set_track_meta(True) + torch._dynamo.reset() + if __name__ == "__main__": unittest.main() From 6b021d303d0d34b79712e8a719db95240caccc4f Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 22:18:35 +0100 Subject: [PATCH 2/6] fix(transforms): preserve ROI broadcasting in compute_slices The pure-Python rewrite zipped ROI lists directly, which silently truncated mismatched-length inputs and dropped scalar broadcasting for roi_start/roi_end. Restore the original semantics: broadcast a scalar against a sequence and raise on incompatible non-scalar lengths, via a shared _broadcast_int_pair helper. Adds a regression test for mixed scalar/sequence ROI inputs. Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/croppad/array.py | 44 ++++++++++++++++---- tests/transforms/test_center_spatial_crop.py | 9 ++++ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7c66baea8a..94f317b780 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -343,7 +343,15 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: - """Coerce an ROI spec (scalar, sequence, tensor or ndarray) to a list of Python ints.""" + """ + Coerce an ROI spec to a list of Python ints. + + Args: + data: an ROI value as a Python scalar, a sequence, a ``torch.Tensor`` or a ``numpy.ndarray``. + + Returns: + The values as a list of Python ints (a scalar becomes a single-element list). + """ if isinstance(data, (torch.Tensor, np.ndarray)): data = data.tolist() if isinstance(data, Sequence): @@ -351,6 +359,29 @@ def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: return [int(data)] +def _broadcast_int_pair( + a: Sequence[int] | int | NdarrayOrTensor, b: Sequence[int] | int | NdarrayOrTensor +) -> tuple[list[int], list[int]]: + """ + Coerce a pair of ROI specs to two equal-length int lists, broadcasting a scalar to match. + + Args: + a: first ROI spec (e.g. ``roi_center`` or ``roi_start``). + b: second ROI spec (e.g. ``roi_size`` or ``roi_end``). + + Returns: + The two specs as lists of Python ints, padded to a common length. + + Raises: + ValueError: when both are non-scalar sequences of differing lengths. + """ + list_a, list_b = _to_int_list(a), _to_int_list(b) + n = max(len(list_a), len(list_b)) + if len(list_a) not in (1, n) or len(list_b) not in (1, n): + raise ValueError(f"ROI specs must have matching lengths or be scalar, got {len(list_a)} and {len(list_b)}.") + return (list_a * n if len(list_a) == 1 else list_a), (list_b * n if len(list_b) == 1 else list_b) + + class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. @@ -394,18 +425,15 @@ def compute_slices( return ensure_tuple(roi_slices) else: if roi_center is not None and roi_size is not None: - centers = _to_int_list(roi_center) - sizes = _to_int_list(roi_size) - n = max(len(centers), len(sizes)) - centers = centers * n if len(centers) == 1 else centers - sizes = sizes * n if len(sizes) == 1 else sizes + centers, sizes = _broadcast_int_pair(roi_center, roi_size) starts = [max(c - s // 2, 0) for c, s in zip(centers, sizes)] ends = [max(st + s, st) for st, s in zip(starts, sizes)] else: if roi_start is None or roi_end is None: raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") - starts = [max(s, 0) for s in _to_int_list(roi_start)] - ends = [max(e, st) for e, st in zip(_to_int_list(roi_end), starts)] + starts, ends = _broadcast_int_pair(roi_start, roi_end) + starts = [max(s, 0) for s in starts] + ends = [max(e, st) for e, st in zip(ends, starts)] return ensure_tuple([slice(s, e) for s, e in zip(starts, ends)]) def __call__( # type: ignore[override] diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 4d7deb26b1..04be4dcb16 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -19,6 +19,7 @@ from monai.data.meta_obj import set_track_meta from monai.transforms import CenterSpatialCrop +from monai.transforms.croppad.array import Crop from tests.croppers import CropTest from tests.test_utils import SkipIfBeforePyTorchVersion @@ -53,6 +54,14 @@ def test_value(self, input_param, input_arr, expected_arr): def test_pending_ops(self, input_param, input_shape, _, align_corners): self.crop_test_pending_ops(input_param, input_shape, align_corners) + def test_compute_slices_broadcast(self): + self.assertEqual(Crop.compute_slices(roi_center=2, roi_size=(4, 6, 8)), (slice(0, 4), slice(0, 6), slice(0, 8))) + self.assertEqual(Crop.compute_slices(roi_start=1, roi_end=(3, 5, 7)), (slice(1, 3), slice(1, 5), slice(1, 7))) + with self.assertRaises(ValueError): + Crop.compute_slices(roi_center=(2, 3), roi_size=(4, 5, 6)) + with self.assertRaises(ValueError): + Crop.compute_slices(roi_start=(1, 2), roi_end=(3, 5, 7)) + @SkipIfBeforePyTorchVersion((2, 1)) def test_torch_compile(self): set_track_meta(False) From 7949d47e51a037689bb1a4dfeac7715ba9d91278 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 22:21:56 +0100 Subject: [PATCH 3/6] test: restore prior track_meta state in compile test set_track_meta is a process-global flag; save and restore the prior value instead of forcing it back to True, so the test is not order-dependent. Signed-off-by: Soumya Snigdha Kundu --- tests/transforms/test_center_spatial_crop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 04be4dcb16..8f9b69cc41 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.data.meta_obj import set_track_meta +from monai.data.meta_obj import get_track_meta, set_track_meta from monai.transforms import CenterSpatialCrop from monai.transforms.croppad.array import Crop from tests.croppers import CropTest @@ -64,13 +64,14 @@ def test_compute_slices_broadcast(self): @SkipIfBeforePyTorchVersion((2, 1)) def test_torch_compile(self): + prev_track_meta = get_track_meta() set_track_meta(False) try: cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16))) img = torch.rand(1, 1, 32, 32, dtype=torch.float32) self.assertEqual(tuple(cropper(img).shape), (1, 1, 16, 16)) finally: - set_track_meta(True) + set_track_meta(prev_track_meta) torch._dynamo.reset() From 93da7849ad327b47d3bde2d089ec953820f5483e Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 22:24:15 +0100 Subject: [PATCH 4/6] fix(transforms): reject str/bytes ROI specs in _to_int_list str/bytes are Sequences, so a value like "10" would be silently parsed into [1, 0]; raise TypeError instead and cover it with a test. Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/croppad/array.py | 5 +++++ tests/transforms/test_center_spatial_crop.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 94f317b780..66370c5b74 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -351,9 +351,14 @@ def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: Returns: The values as a list of Python ints (a scalar becomes a single-element list). + + Raises: + TypeError: when ``data`` is a ``str`` or ``bytes``. """ if isinstance(data, (torch.Tensor, np.ndarray)): data = data.tolist() + if isinstance(data, (str, bytes)): + raise TypeError(f"ROI spec must be numeric, got {type(data).__name__}.") if isinstance(data, Sequence): return [int(i) for i in data] return [int(data)] diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 8f9b69cc41..9ed66271fc 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -61,6 +61,8 @@ def test_compute_slices_broadcast(self): Crop.compute_slices(roi_center=(2, 3), roi_size=(4, 5, 6)) with self.assertRaises(ValueError): Crop.compute_slices(roi_start=(1, 2), roi_end=(3, 5, 7)) + with self.assertRaises(TypeError): + Crop.compute_slices(roi_center="10", roi_size=(4, 6)) @SkipIfBeforePyTorchVersion((2, 1)) def test_torch_compile(self): From 8e200de58e4d90dbac7bbdd78548b59d92d063cb Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 29 Jun 2026 22:33:12 +0100 Subject: [PATCH 5/6] refactor(transforms): tidy compute_slices ROI helpers Trim the oversized docstrings on the private _to_int_list/_broadcast_int_pair helpers to one-line summaries, and unify the end-clamp so it lives in a single place in the final slice comprehension instead of being re-derived in both branches. Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/croppad/array.py | 35 ++++++------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 66370c5b74..d520446e25 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -343,21 +343,10 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: - """ - Coerce an ROI spec to a list of Python ints. - - Args: - data: an ROI value as a Python scalar, a sequence, a ``torch.Tensor`` or a ``numpy.ndarray``. - - Returns: - The values as a list of Python ints (a scalar becomes a single-element list). - - Raises: - TypeError: when ``data`` is a ``str`` or ``bytes``. - """ + """Coerce an ROI spec (scalar, sequence, tensor or ndarray) to a list of Python ints.""" if isinstance(data, (torch.Tensor, np.ndarray)): data = data.tolist() - if isinstance(data, (str, bytes)): + if isinstance(data, (str, bytes)): # a str is a Sequence, guard so it is not iterated into digits raise TypeError(f"ROI spec must be numeric, got {type(data).__name__}.") if isinstance(data, Sequence): return [int(i) for i in data] @@ -367,19 +356,7 @@ def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: def _broadcast_int_pair( a: Sequence[int] | int | NdarrayOrTensor, b: Sequence[int] | int | NdarrayOrTensor ) -> tuple[list[int], list[int]]: - """ - Coerce a pair of ROI specs to two equal-length int lists, broadcasting a scalar to match. - - Args: - a: first ROI spec (e.g. ``roi_center`` or ``roi_start``). - b: second ROI spec (e.g. ``roi_size`` or ``roi_end``). - - Returns: - The two specs as lists of Python ints, padded to a common length. - - Raises: - ValueError: when both are non-scalar sequences of differing lengths. - """ + """Coerce a pair of ROI specs to two equal-length int lists, broadcasting a scalar to match.""" list_a, list_b = _to_int_list(a), _to_int_list(b) n = max(len(list_a), len(list_b)) if len(list_a) not in (1, n) or len(list_b) not in (1, n): @@ -432,14 +409,14 @@ def compute_slices( if roi_center is not None and roi_size is not None: centers, sizes = _broadcast_int_pair(roi_center, roi_size) starts = [max(c - s // 2, 0) for c, s in zip(centers, sizes)] - ends = [max(st + s, st) for st, s in zip(starts, sizes)] + ends = [st + s for st, s in zip(starts, sizes)] else: if roi_start is None or roi_end is None: raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") starts, ends = _broadcast_int_pair(roi_start, roi_end) starts = [max(s, 0) for s in starts] - ends = [max(e, st) for e, st in zip(ends, starts)] - return ensure_tuple([slice(s, e) for s, e in zip(starts, ends)]) + # clamp each end to its own start so no slice has negative width + return ensure_tuple([slice(s, max(e, s)) for s, e in zip(starts, ends)]) def __call__( # type: ignore[override] self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None From 9cef164423c0e181bd91e8ba6264e9ce68c3ad64 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 30 Jun 2026 00:39:04 +0100 Subject: [PATCH 6/6] tests: use eager backend for CenterSpatialCrop torch.compile test The Inductor backend forces a real C++ compilation, which fails on CI runners without a working toolchain (no MSVC on Windows) or with a clearml-corrupted sys.path on Linux. The eager backend still traces the transform through Dynamo, which is what this test verifies. Signed-off-by: Soumya Snigdha Kundu --- tests/transforms/test_center_spatial_crop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 9ed66271fc..d380c9a687 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -69,7 +69,8 @@ def test_torch_compile(self): prev_track_meta = get_track_meta() set_track_meta(False) try: - cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16))) + # eager backend traces the transform without needing the Inductor C++ compiler + cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16)), backend="eager") img = torch.rand(1, 1, 32, 32, dtype=torch.float32) self.assertEqual(tuple(cropper(img).shape), (1, 1, 16, 16)) finally: