Skip to content

Commit de0b762

Browse files
committed
Enable global coordinates in spatial crop transforms
Add convenience transforms for converting points between world and image coordinate spaces, and extend SpatialCropd to accept string dictionary keys for ROI parameters, enabling deferred coordinate resolution at call time. New transforms: - TransformPointsWorldToImaged: world-to-image coordinate conversion - TransformPointsImageToWorldd: image-to-world coordinate conversion SpatialCropd changes: - roi_center, roi_size, roi_start, roi_end now accept string keys - When strings are provided, coordinates are resolved from the data dictionary at __call__ time (zero overhead for existing usage) - Tensors from ApplyTransformToPoints are automatically flattened and rounded to integers - Inverse override with check=False for string-key path to handle recreated cropper identity mismatch Signed-off-by: Emanuilo Jovanovic <emanuilo.jovanovic@smartcat.io>
1 parent b3fff92 commit de0b762

File tree

6 files changed

+564
-10
lines changed

6 files changed

+564
-10
lines changed

monai/transforms/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,12 @@
676676
ToTensord,
677677
ToTensorD,
678678
ToTensorDict,
679+
TransformPointsImageToWorldd,
680+
TransformPointsImageToWorldD,
681+
TransformPointsImageToWorldDict,
682+
TransformPointsWorldToImaged,
683+
TransformPointsWorldToImageD,
684+
TransformPointsWorldToImageDict,
679685
Transposed,
680686
TransposeD,
681687
TransposeDict,

monai/transforms/croppad/dictionary.py

Lines changed: 130 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from monai.transforms.traits import LazyTrait, MultiSampleTrait
5151
from monai.transforms.transform import LazyTransform, MapTransform, Randomizable
5252
from monai.transforms.utils import is_positive
53-
from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep
53+
from monai.utils import MAX_SEED, Method, PytorchPadMode, TraceKeys, ensure_tuple_rep
5454

5555
__all__ = [
5656
"Padd",
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
431431
- a spatial center and size
432432
- the start and end coordinates of the ROI
433433
434+
ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
435+
string dictionary keys. When a string is provided, the actual coordinate values are read from the
436+
data dictionary at call time. This enables pipelines where coordinates are computed by earlier
437+
transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
438+
data dictionary under the given key.
439+
440+
Example::
441+
442+
from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
443+
444+
pipeline = Compose([
445+
TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
446+
TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
447+
SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
448+
])
449+
434450
This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
435451
for more information.
436452
"""
437453

438454
def __init__(
    self,
    keys: KeysCollection,
    roi_center: Sequence[int] | int | str | None = None,
    roi_size: Sequence[int] | int | str | None = None,
    roi_start: Sequence[int] | int | str | None = None,
    roi_end: Sequence[int] | int | str | None = None,
    roi_slices: Sequence[slice] | None = None,
    allow_missing_keys: bool = False,
    lazy: bool = False,
):
    """
    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        roi_center: voxel coordinates for center of the crop ROI, or a string key used to
            look the coordinates up in the data dictionary at call time.
        roi_size: size of the crop ROI; if a dimension of ROI size is larger than the image
            size, that dimension will not be cropped. Can also be a string key.
        roi_start: voxel coordinates for start of the crop ROI, or a string key used to
            look the coordinates up in the data dictionary at call time.
        roi_end: voxel coordinates for end of the crop ROI; if a coordinate is out of the
            image, the end coordinate of the image is used. Can also be a string key.
        roi_slices: list of slices for each of the spatial dimensions.
        allow_missing_keys: don't raise exception if key is missing.
        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
    """
    self._roi_center = roi_center
    self._roi_size = roi_size
    self._roi_start = roi_start
    self._roi_end = roi_end
    self._roi_slices = roi_slices
    # Any string-valued ROI parameter means the coordinates must be resolved
    # from the data dictionary at __call__ time instead of at construction.
    self._has_str_roi = any(isinstance(v, str) for v in (roi_center, roi_size, roi_start, roi_end))

    if self._has_str_roi:
        # Placeholder cropper for the string-key path; it is replaced on
        # self.cropper inside __call__ once the keys have been resolved.
        cropper = SpatialCrop(roi_start=[0], roi_end=[1], lazy=lazy)
    else:
        cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
    super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)
465495

496+
@staticmethod
497+
def _resolve_roi_param(val, d):
498+
"""Resolve an ROI parameter from the data dictionary if it is a string key.
499+
500+
Args:
501+
val: the ROI parameter value. If a string, it is used as a key to look up
502+
the actual value from ``d``. Otherwise returned as-is.
503+
d: the data dictionary.
504+
505+
Returns:
506+
The resolved ROI parameter. Tensors and numpy arrays are flattened to 1-D
507+
and rounded to int64 so they can be consumed by ``Crop.compute_slices``.
508+
509+
Raises:
510+
KeyError: if ``val`` is a string key that does not exist in ``d``.
511+
"""
512+
if not isinstance(val, str):
513+
return val
514+
if val not in d:
515+
raise KeyError(f"ROI key '{val}' not found in the data dictionary.")
516+
resolved = d[val]
517+
# ApplyTransformToPoints outputs tensors of shape (C, N, dims).
518+
# A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
519+
# Flatten to 1-D and round to integers for compute_slices.
520+
# Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
521+
if isinstance(resolved, np.ndarray):
522+
resolved = torch.from_numpy(resolved)
523+
if isinstance(resolved, torch.Tensor):
524+
resolved = torch.round(resolved.flatten()).to(torch.int64)
525+
return resolved
526+
527+
@property
def requires_current_data(self):
    """bool: True when any ROI parameter is a string key, i.e. the transform
    must read the data dictionary at call time to resolve coordinates."""
    return self._has_str_roi
531+
532+
def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:
    """
    Crop each keyed item, resolving any string ROI parameters first.

    Args:
        data: dictionary of data items to be transformed.
        lazy: whether to execute lazily. If ``None``, uses the instance default.

    Returns:
        Dictionary with cropped data for each key.
    """
    # Fast path: fixed ROI values were supplied at construction time.
    if not self._has_str_roi:
        return super().__call__(data, lazy=lazy)

    d = dict(data)
    center = self._resolve_roi_param(self._roi_center, d)
    size = self._resolve_roi_param(self._roi_size, d)
    start = self._resolve_roi_param(self._roi_start, d)
    end = self._resolve_roi_param(self._roi_end, d)

    lazy_ = self.lazy if lazy is None else lazy
    # Rebuild the cropper now that the string keys are resolved; it is kept on
    # self.cropper so the effective crop remains introspectable after the call.
    self.cropper = SpatialCrop(
        roi_center=center,
        roi_size=size,
        roi_start=start,
        roi_end=end,
        roi_slices=self._roi_slices,
        lazy=lazy_,
    )
    for key in self.key_iterator(d):
        d[key] = self.cropper(d[key], lazy=lazy_)
    return d
559+
560+
def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:
    """
    Inverse of the crop transform, restoring the original spatial dimensions via padding.

    For the string-key path, ``self.cropper`` is recreated on every ``__call__``, so its
    ``id()`` won't match the one recorded in the MetaTensor's transform stack. This
    override skips that identity check and applies the inverse directly from the crop
    info stored on the MetaTensor.

    Args:
        data: dictionary of cropped ``MetaTensor`` items.

    Returns:
        Dictionary with inverse-transformed (padded) data for each key.
    """
    if not self._has_str_roi:
        # Fixed-ROI path: the standard inverse (with its identity check) applies.
        return super().inverse(data)
    d = dict(data)
    for key in self.key_iterator(d):
        tracked = self.cropper.pop_transform(d[key], check=False)
        pad_amounts = tracked[TraceKeys.EXTRA_INFO]["cropped"]
        padder = BorderPad(pad_amounts)
        # Padding here undoes a crop; don't record it as a new transform.
        with padder.trace_transform(False):
            d[key] = padder(d[key])
    return d
585+
466586

467587
class CenterSpatialCropd(Cropd):
468588
"""

monai/transforms/utility/dictionary.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@
192192
"ApplyTransformToPointsd",
193193
"ApplyTransformToPointsD",
194194
"ApplyTransformToPointsDict",
195+
"TransformPointsWorldToImaged",
196+
"TransformPointsWorldToImageD",
197+
"TransformPointsWorldToImageDict",
198+
"TransformPointsImageToWorldd",
199+
"TransformPointsImageToWorldD",
200+
"TransformPointsImageToWorldDict",
195201
"FlattenSequenced",
196202
"FlattenSequenceD",
197203
"FlattenSequenceDict",
@@ -1910,6 +1916,86 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
19101916
return d
19111917

19121918

1919+
class TransformPointsWorldToImaged(ApplyTransformToPointsd):
    """
    Dictionary-based transform converting point coordinates from world space to image space.

    A thin convenience wrapper around :py:class:`monai.transforms.ApplyTransformToPointsd`
    fixed to ``invert_affine=True``: the reference image's affine matrix is inverted so
    that world-space points land in the image's coordinate space.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        refer_keys: The key of the reference image used to derive the affine transformation.
            This is required because the affine must come from a reference image.
            It can also be a sequence of keys, in which case each refers to the affine applied
            to the matching points in ``keys``.
        dtype: The desired data type for the output.
        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
        allow_missing_keys: Don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        refer_keys: KeysCollection,
        dtype: DtypeLike | torch.dtype = torch.float64,
        affine_lps_to_ras: bool = False,
        allow_missing_keys: bool = False,
    ):
        # Delegate to the generic points transform: affine=None forces the
        # affine to come from the reference image, and invert_affine=True
        # selects the world -> image direction.
        super().__init__(
            keys=keys,
            refer_keys=refer_keys,
            affine=None,
            invert_affine=True,
            dtype=dtype,
            affine_lps_to_ras=affine_lps_to_ras,
            allow_missing_keys=allow_missing_keys,
        )
1958+
1959+
class TransformPointsImageToWorldd(ApplyTransformToPointsd):
    """
    Dictionary-based transform converting point coordinates from image space to world space.

    A thin convenience wrapper around :py:class:`monai.transforms.ApplyTransformToPointsd`
    fixed to ``invert_affine=False``: the reference image's affine matrix is applied
    directly so that image-space points land in world coordinates.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        refer_keys: The key of the reference image used to derive the affine transformation.
            This is required because the affine must come from a reference image.
            It can also be a sequence of keys, in which case each refers to the affine applied
            to the matching points in ``keys``.
        dtype: The desired data type for the output.
        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
        allow_missing_keys: Don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        refer_keys: KeysCollection,
        dtype: DtypeLike | torch.dtype = torch.float64,
        affine_lps_to_ras: bool = False,
        allow_missing_keys: bool = False,
    ):
        # Delegate to the generic points transform: affine=None forces the
        # affine to come from the reference image, and invert_affine=False
        # selects the image -> world direction.
        super().__init__(
            keys=keys,
            refer_keys=refer_keys,
            affine=None,
            invert_affine=False,
            dtype=dtype,
            affine_lps_to_ras=affine_lps_to_ras,
            allow_missing_keys=allow_missing_keys,
        )
1997+
1998+
19131999
class FlattenSequenced(MapTransform, ReduceTrait):
19142000
"""
19152001
Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.
@@ -1975,4 +2061,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
19752061
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
19762062
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
19772063
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
2064+
# CamelCase aliases following the module's `D`/`Dict` naming convention.
TransformPointsWorldToImageD = TransformPointsWorldToImageDict = TransformPointsWorldToImaged
TransformPointsImageToWorldD = TransformPointsImageToWorldDict = TransformPointsImageToWorldd
19782066
FlattenSequenceD = FlattenSequenceDict = FlattenSequenced

0 commit comments

Comments
 (0)