5050from monai .transforms .traits import LazyTrait , MultiSampleTrait
5151from monai .transforms .transform import LazyTransform , MapTransform , Randomizable
5252from monai .transforms .utils import is_positive
53- from monai .utils import MAX_SEED , Method , PytorchPadMode , ensure_tuple_rep
53+ from monai .utils import MAX_SEED , Method , PytorchPadMode , TraceKeys , ensure_tuple_rep
5454
5555__all__ = [
5656 "Padd" ,
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
431431 - a spatial center and size
432432 - the start and end coordinates of the ROI
433433
434+ ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
435+ string dictionary keys. When a string is provided, the actual coordinate values are read from the
436+ data dictionary at call time. This enables pipelines where coordinates are computed by earlier
437+ transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
438+ data dictionary under the given key.
439+
440+ Example::
441+
442+ from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
443+
444+ pipeline = Compose([
445+ TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
446+ TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
447+ SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
448+ ])
449+
434450 This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
435451 for more information.
436452 """
437453
438454 def __init__ (
439455 self ,
440456 keys : KeysCollection ,
441- roi_center : Sequence [int ] | int | None = None ,
442- roi_size : Sequence [int ] | int | None = None ,
443- roi_start : Sequence [int ] | int | None = None ,
444- roi_end : Sequence [int ] | int | None = None ,
457+ roi_center : Sequence [int ] | int | str | None = None ,
458+ roi_size : Sequence [int ] | int | str | None = None ,
459+ roi_start : Sequence [int ] | int | str | None = None ,
460+ roi_end : Sequence [int ] | int | str | None = None ,
445461 roi_slices : Sequence [slice ] | None = None ,
446462 allow_missing_keys : bool = False ,
447463 lazy : bool = False ,
@@ -450,19 +466,123 @@ def __init__(
450466 Args:
451467 keys: keys of the corresponding items to be transformed.
452468 See also: :py:class:`monai.transforms.compose.MapTransform`
453- roi_center: voxel coordinates for center of the crop ROI.
469+ roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
470+ the coordinates from the data dictionary.
454471 roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
455- will not crop that dimension of the image.
456- roi_start: voxel coordinates for start of the crop ROI.
472+ will not crop that dimension of the image. Can also be a string key.
473+ roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
474+ the coordinates from the data dictionary.
457475 roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
458- use the end coordinate of image.
476+ use the end coordinate of image. Can also be a string key.
459477 roi_slices: list of slices for each of the spatial dimensions.
460478 allow_missing_keys: don't raise exception if key is missing.
461479 lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
462480 """
463- cropper = SpatialCrop (roi_center , roi_size , roi_start , roi_end , roi_slices , lazy = lazy )
481+ self ._roi_center = roi_center
482+ self ._roi_size = roi_size
483+ self ._roi_start = roi_start
484+ self ._roi_end = roi_end
485+ self ._roi_slices = roi_slices
486+ self ._has_str_roi = any (isinstance (v , str ) for v in [roi_center , roi_size , roi_start , roi_end ])
487+
488+ if not self ._has_str_roi :
489+ cropper = SpatialCrop (roi_center , roi_size , roi_start , roi_end , roi_slices , lazy = lazy )
490+ else :
491+ # Placeholder cropper for the string-key path. Replaced on self.cropper at
492+ # __call__ time once string keys are resolved from the data dictionary.
493+ cropper = SpatialCrop (roi_start = [0 ], roi_end = [1 ], lazy = lazy )
464494 super ().__init__ (keys , cropper = cropper , allow_missing_keys = allow_missing_keys , lazy = lazy )
465495
496+ @staticmethod
497+ def _resolve_roi_param (val , d ):
498+ """Resolve an ROI parameter from the data dictionary if it is a string key.
499+
500+ Args:
501+ val: the ROI parameter value. If a string, it is used as a key to look up
502+ the actual value from ``d``. Otherwise returned as-is.
503+ d: the data dictionary.
504+
505+ Returns:
506+ The resolved ROI parameter. Tensors and numpy arrays are flattened to 1-D
507+ and rounded to int64 so they can be consumed by ``Crop.compute_slices``.
508+
509+ Raises:
510+ KeyError: if ``val`` is a string key that does not exist in ``d``.
511+ """
512+ if not isinstance (val , str ):
513+ return val
514+ if val not in d :
515+ raise KeyError (f"ROI key '{ val } ' not found in the data dictionary." )
516+ resolved = d [val ]
517+ # ApplyTransformToPoints outputs tensors of shape (C, N, dims).
518+ # A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
519+ # Flatten to 1-D and round to integers for compute_slices.
520+ # Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
521+ if isinstance (resolved , np .ndarray ):
522+ resolved = torch .from_numpy (resolved )
523+ if isinstance (resolved , torch .Tensor ):
524+ resolved = torch .round (resolved .flatten ()).to (torch .int64 )
525+ return resolved
526+
527+ @property
528+ def requires_current_data (self ):
529+ """bool: Whether this transform requires the current data dictionary to resolve ROI parameters."""
530+ return self ._has_str_roi
531+
532+ def __call__ (self , data : Mapping [Hashable , torch .Tensor ], lazy : bool | None = None ) -> dict [Hashable , torch .Tensor ]:
533+ """
534+ Args:
535+ data: dictionary of data items to be transformed.
536+ lazy: whether to execute lazily. If ``None``, uses the instance default.
537+
538+ Returns:
539+ Dictionary with cropped data for each key.
540+ """
541+ if not self ._has_str_roi :
542+ return super ().__call__ (data , lazy = lazy )
543+
544+ d = dict (data )
545+ roi_center = self ._resolve_roi_param (self ._roi_center , d )
546+ roi_size = self ._resolve_roi_param (self ._roi_size , d )
547+ roi_start = self ._resolve_roi_param (self ._roi_start , d )
548+ roi_end = self ._resolve_roi_param (self ._roi_end , d )
549+
550+ lazy_ = self .lazy if lazy is None else lazy
551+ self .cropper = SpatialCrop (
552+ roi_center = roi_center , roi_size = roi_size ,
553+ roi_start = roi_start , roi_end = roi_end ,
554+ roi_slices = self ._roi_slices , lazy = lazy_ ,
555+ )
556+ for key in self .key_iterator (d ):
557+ d [key ] = self .cropper (d [key ], lazy = lazy_ )
558+ return d
559+
560+ def inverse (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor ]:
561+ """
562+ Inverse of the crop transform, restoring the original spatial dimensions via padding.
563+
564+ For the string-key path, ``self.cropper`` is recreated on each ``__call__``, so its
565+ ``id()`` won't match the one stored in the MetaTensor's transform stack. This override
566+ bypasses the ID check and applies the inverse directly using the crop info stored in the
567+ MetaTensor.
568+
569+ Args:
570+ data: dictionary of cropped ``MetaTensor`` items.
571+
572+ Returns:
573+ Dictionary with inverse-transformed (padded) data for each key.
574+ """
575+ if not self ._has_str_roi :
576+ return super ().inverse (data )
577+ d = dict (data )
578+ for key in self .key_iterator (d ):
579+ transform = self .cropper .pop_transform (d [key ], check = False )
580+ cropped = transform [TraceKeys .EXTRA_INFO ]["cropped" ]
581+ inverse_transform = BorderPad (cropped )
582+ with inverse_transform .trace_transform (False ):
583+ d [key ] = inverse_transform (d [key ])
584+ return d
585+
466586
467587class CenterSpatialCropd (Cropd ):
468588 """
0 commit comments