6464 GridSamplePadMode ,
6565 InterpolateMode ,
6666 NumpyPadMode ,
67+ SpaceKeys ,
6768 convert_to_cupy ,
6869 convert_to_dst_type ,
6970 convert_to_numpy ,
7576 issequenceiterable ,
7677 optional_import ,
7778)
79+ from monai .utils .deprecate_utils import deprecated_arg_default
7880from monai .utils .enums import GridPatchSort , PatchKeys , TraceKeys , TransformBackends
7981from monai .utils .misc import ImageMetaKey as Key
8082from monai .utils .module import look_up_option
@@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):
556558
557559 backend = [TransformBackends .NUMPY , TransformBackends .TORCH ]
558560
561+ @deprecated_arg_default (
562+ name = "labels" ,
563+ old_default = (("L" , "R" ), ("P" , "A" ), ("I" , "S" )),
564+ new_default = None ,
565+ msg_suffix = (
566+ "Default value changed to None meaning that the transform now uses the 'space' of a "
567+ "meta-tensor, if applicable, to determine appropriate axis labels."
568+ ),
569+ )
559570 def __init__ (
560571 self ,
561572 axcodes : str | None = None ,
562573 as_closest_canonical : bool = False ,
563- labels : Sequence [tuple [str , str ]] | None = (( "L" , "R" ), ( "P" , "A" ), ( "I" , "S" )) ,
574+ labels : Sequence [tuple [str , str ]] | None = None ,
564575 lazy : bool = False ,
565576 ) -> None :
566577 """
@@ -573,7 +584,14 @@ def __init__(
573584 as_closest_canonical: if True, load the image as closest to canonical axis format.
574585 labels: optional, None or sequence of (2,) sequences
575586 (2,) sequences are labels for (beginning, end) of output axis.
576- Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
587+ If ``None``, an appropriate value is chosen depending on the
588+ value of the ``"space"`` metadata item of a metatensor: if
589+ ``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
590+ ('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
591+ input is not a meta-tensor or has no ``"space"`` item, the
592+ value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
593+ ``None``, the provided value is always used and the ``"space"``
594+ metadata item (if any) of the input is ignored.
577595 lazy: a flag to indicate whether this transform should execute lazily or not.
578596 Defaults to False
579597
@@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619637 raise ValueError (f"data_array must have at least one spatial dimension, got { spatial_shape } ." )
620638 affine_ : np .ndarray
621639 affine_np : np .ndarray
640+ labels = self .labels
622641 if isinstance (data_array , MetaTensor ):
623642 affine_np , * _ = convert_data_type (data_array .peek_pending_affine (), np .ndarray )
624643 affine_ = to_affine_nd (sr , affine_np )
644+
645+ # Set up "labels" such that LPS tensors are handled correctly by default
646+ if (
647+ self .labels is None
648+ and "space" in data_array .meta
649+ and SpaceKeys (data_array .meta ["space" ]) == SpaceKeys .LPS
650+ ):
651+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
652+
625653 else :
626654 warnings .warn ("`data_array` is not of type `MetaTensor, assuming affine to be identity." )
627655 # default to identity
@@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640668 f"{ self .__class__ .__name__ } : spatial shape = { spatial_shape } , channels = { data_array .shape [0 ]} ,"
641669 "please make sure the input is in the channel-first format."
642670 )
643- dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = self . labels )
671+ dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = labels )
644672 if len (dst ) < sr :
645673 raise ValueError (
646674 f"axcodes must match data_array spatially, got axcodes={ len (self .axcodes )} D data_array={ sr } D"
@@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653681 transform = self .pop_transform (data )
654682 # Create inverse transform
655683 orig_affine = transform [TraceKeys .EXTRA_INFO ]["original_affine" ]
656- orig_axcodes = nib .orientations .aff2axcodes (orig_affine )
657- inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = self .labels )
684+ labels = self .labels
685+
686+ # Set up "labels" such that LPS tensors are handled correctly by default
687+ if (
688+ isinstance (data , MetaTensor )
689+ and self .labels is None
690+ and "space" in data .meta
691+ and SpaceKeys (data .meta ["space" ]) == SpaceKeys .LPS
692+ ):
693+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
694+
695+ orig_axcodes = nib .orientations .aff2axcodes (orig_affine , labels = labels )
696+ inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = labels )
658697 # Apply inverse
659698 with inverse_transform .trace_transform (False ):
660699 data = inverse_transform (data )
0 commit comments