diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b23fbac7d9..d520446e25 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -342,6 +342,28 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in return spatial_pad.compute_pad_width(spatial_shape) +def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]: + """Coerce an ROI spec (scalar, sequence, tensor or ndarray) to a list of Python ints.""" + if isinstance(data, (torch.Tensor, np.ndarray)): + data = data.tolist() + if isinstance(data, (str, bytes)): # a str is a Sequence, guard so it is not iterated into digits + raise TypeError(f"ROI spec must be numeric, got {type(data).__name__}.") + if isinstance(data, Sequence): + return [int(i) for i in data] + return [int(data)] + + +def _broadcast_int_pair( + a: Sequence[int] | int | NdarrayOrTensor, b: Sequence[int] | int | NdarrayOrTensor +) -> tuple[list[int], list[int]]: + """Coerce a pair of ROI specs to two equal-length int lists, broadcasting a scalar to match.""" + list_a, list_b = _to_int_list(a), _to_int_list(b) + n = max(len(list_a), len(list_b)) + if len(list_a) not in (1, n) or len(list_b) not in (1, n): + raise ValueError(f"ROI specs must have matching lengths or be scalar, got {len(list_a)} and {len(list_b)}.") + return (list_a * n if len(list_a) == 1 else list_a), (list_b * n if len(list_b) == 1 else list_b) + + class Crop(InvertibleTransform, LazyTransform): """ Perform crop operations on the input image. @@ -379,31 +401,22 @@ def compute_slices( roi_slices: list of slices for each of the spatial dimensions. """ - roi_start_t: torch.Tensor - if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.") return ensure_tuple(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu") - roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device="cpu") - _zeros = torch.zeros_like(roi_center_t) - half = torch.divide(roi_size_t, 2, rounding_mode="floor") - roi_start_t = torch.maximum(roi_center_t - half, _zeros) - roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t) + centers, sizes = _broadcast_int_pair(roi_center, roi_size) + starts = [max(c - s // 2, 0) for c, s in zip(centers, sizes)] + ends = [st + s for st, s in zip(starts, sizes)] else: if roi_start is None or roi_end is None: raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True) - roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t)) - roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True) - roi_end_t = torch.maximum(roi_end_t, roi_start_t) - # convert to slices (accounting for 1d) - if roi_start_t.numel() == 1: - return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) - return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]) + starts, ends = _broadcast_int_pair(roi_start, roi_end) + starts = [max(s, 0) for s in starts] + # clamp each end to its own start so no slice has negative width + return ensure_tuple([slice(s, max(e, s)) for s, e in zip(starts, ends)]) def __call__( # type: ignore[override] self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None diff --git a/tests/transforms/test_center_spatial_crop.py b/tests/transforms/test_center_spatial_crop.py index 9120f30163..d380c9a687 100644 --- a/tests/transforms/test_center_spatial_crop.py +++ b/tests/transforms/test_center_spatial_crop.py @@ -14,10 +14,14 @@ import unittest import numpy as np +import torch from parameterized import parameterized +from monai.data.meta_obj import get_track_meta, set_track_meta from monai.transforms import CenterSpatialCrop +from monai.transforms.croppad.array import Crop from tests.croppers import CropTest +from tests.test_utils import SkipIfBeforePyTorchVersion TEST_SHAPES = [ [{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3), True], @@ -50,6 +54,29 @@ def test_value(self, input_param, input_arr, expected_arr): def test_pending_ops(self, input_param, input_shape, _, align_corners): self.crop_test_pending_ops(input_param, input_shape, align_corners) + def test_compute_slices_broadcast(self): + self.assertEqual(Crop.compute_slices(roi_center=2, roi_size=(4, 6, 8)), (slice(0, 4), slice(0, 6), slice(0, 8))) + self.assertEqual(Crop.compute_slices(roi_start=1, roi_end=(3, 5, 7)), (slice(1, 3), slice(1, 5), slice(1, 7))) + with self.assertRaises(ValueError): + Crop.compute_slices(roi_center=(2, 3), roi_size=(4, 5, 6)) + with self.assertRaises(ValueError): + Crop.compute_slices(roi_start=(1, 2), roi_end=(3, 5, 7)) + with self.assertRaises(TypeError): + Crop.compute_slices(roi_center="10", roi_size=(4, 6)) + + @SkipIfBeforePyTorchVersion((2, 1)) + def test_torch_compile(self): + prev_track_meta = get_track_meta() + set_track_meta(False) + try: + # eager backend traces the transform without needing the Inductor C++ compiler + cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16)), backend="eager") + img = torch.rand(1, 1, 32, 32, dtype=torch.float32) + self.assertEqual(tuple(cropper(img).shape), (1, 1, 16, 16)) + finally: + set_track_meta(prev_track_meta) + torch._dynamo.reset() + if __name__ == "__main__": unittest.main()