Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,28 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, in
return spatial_pad.compute_pad_width(spatial_shape)


def _to_int_list(data: Sequence[int] | int | NdarrayOrTensor) -> list[int]:
"""Coerce an ROI spec (scalar, sequence, tensor or ndarray) to a list of Python ints."""
if isinstance(data, (torch.Tensor, np.ndarray)):
data = data.tolist()
if isinstance(data, (str, bytes)): # a str is a Sequence, guard so it is not iterated into digits
raise TypeError(f"ROI spec must be numeric, got {type(data).__name__}.")
if isinstance(data, Sequence):
return [int(i) for i in data]
return [int(data)]
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _broadcast_int_pair(
a: Sequence[int] | int | NdarrayOrTensor, b: Sequence[int] | int | NdarrayOrTensor
) -> tuple[list[int], list[int]]:
"""Coerce a pair of ROI specs to two equal-length int lists, broadcasting a scalar to match."""
list_a, list_b = _to_int_list(a), _to_int_list(b)
n = max(len(list_a), len(list_b))
if len(list_a) not in (1, n) or len(list_b) not in (1, n):
raise ValueError(f"ROI specs must have matching lengths or be scalar, got {len(list_a)} and {len(list_b)}.")
return (list_a * n if len(list_a) == 1 else list_a), (list_b * n if len(list_b) == 1 else list_b)


class Crop(InvertibleTransform, LazyTransform):
"""
Perform crop operations on the input image.
Expand Down Expand Up @@ -379,31 +401,22 @@ def compute_slices(
roi_slices: list of slices for each of the spatial dimensions.

"""
roi_start_t: torch.Tensor

if roi_slices:
if not all(s.step is None or s.step == 1 for s in roi_slices):
raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.")
return ensure_tuple(roi_slices)
else:
if roi_center is not None and roi_size is not None:
roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device="cpu")
_zeros = torch.zeros_like(roi_center_t)
half = torch.divide(roi_size_t, 2, rounding_mode="floor")
roi_start_t = torch.maximum(roi_center_t - half, _zeros)
roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t)
centers, sizes = _broadcast_int_pair(roi_center, roi_size)
starts = [max(c - s // 2, 0) for c, s in zip(centers, sizes)]
ends = [st + s for st, s in zip(starts, sizes)]
else:
if roi_start is None or roi_end is None:
raise ValueError("please specify either roi_center, roi_size or roi_start, roi_end.")
roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True)
roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t))
roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True)
roi_end_t = torch.maximum(roi_end_t, roi_start_t)
# convert to slices (accounting for 1d)
if roi_start_t.numel() == 1:
return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))])
return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())])
starts, ends = _broadcast_int_pair(roi_start, roi_end)
starts = [max(s, 0) for s in starts]
# clamp each end to its own start so no slice has negative width
return ensure_tuple([slice(s, max(e, s)) for s, e in zip(starts, ends)])

def __call__( # type: ignore[override]
self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None
Expand Down
27 changes: 27 additions & 0 deletions tests/transforms/test_center_spatial_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.transforms import CenterSpatialCrop
from monai.transforms.croppad.array import Crop
from tests.croppers import CropTest
from tests.test_utils import SkipIfBeforePyTorchVersion

TEST_SHAPES = [
[{"roi_size": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3), True],
Expand Down Expand Up @@ -50,6 +54,29 @@ def test_value(self, input_param, input_arr, expected_arr):
def test_pending_ops(self, input_param, input_shape, _, align_corners):
self.crop_test_pending_ops(input_param, input_shape, align_corners)

def test_compute_slices_broadcast(self):
self.assertEqual(Crop.compute_slices(roi_center=2, roi_size=(4, 6, 8)), (slice(0, 4), slice(0, 6), slice(0, 8)))
self.assertEqual(Crop.compute_slices(roi_start=1, roi_end=(3, 5, 7)), (slice(1, 3), slice(1, 5), slice(1, 7)))
with self.assertRaises(ValueError):
Crop.compute_slices(roi_center=(2, 3), roi_size=(4, 5, 6))
with self.assertRaises(ValueError):
Crop.compute_slices(roi_start=(1, 2), roi_end=(3, 5, 7))
with self.assertRaises(TypeError):
Crop.compute_slices(roi_center="10", roi_size=(4, 6))

@SkipIfBeforePyTorchVersion((2, 1))
def test_torch_compile(self):
prev_track_meta = get_track_meta()
set_track_meta(False)
try:
# eager backend traces the transform without needing the Inductor C++ compiler
cropper = torch.compile(CenterSpatialCrop(roi_size=(1, 16, 16)), backend="eager")
img = torch.rand(1, 1, 32, 32, dtype=torch.float32)
self.assertEqual(tuple(cropper(img).shape), (1, 1, 16, 16))
finally:
set_track_meta(prev_track_meta)
torch._dynamo.reset()
Comment thread
coderabbitai[bot] marked this conversation as resolved.


if __name__ == "__main__":
unittest.main()
Loading