6464 GridSamplePadMode ,
6565 InterpolateMode ,
6666 NumpyPadMode ,
67+ SpaceKeys ,
6768 convert_to_cupy ,
6869 convert_to_dst_type ,
6970 convert_to_numpy ,
@@ -560,7 +561,7 @@ def __init__(
560561 self ,
561562 axcodes : str | None = None ,
562563 as_closest_canonical : bool = False ,
563- labels : Sequence [tuple [str , str ]] | None = (( "L" , "R" ), ( "P" , "A" ), ( "I" , "S" )) ,
564+ labels : Sequence [tuple [str , str ]] | None = None ,
564565 lazy : bool = False ,
565566 ) -> None :
566567 """
@@ -573,7 +574,9 @@ def __init__(
573574 as_closest_canonical: if True, load the image as closest to canonical axis format.
574575 labels: optional, None or sequence of (2,) sequences
575576 (2,) sequences are labels for (beginning, end) of output axis.
576- Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
577+ Defaults to using the ``"space"`` attribute of a metatensor,
578+ where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
579+ otherwise (i.e. for plain tensors).
577580 lazy: a flag to indicate whether this transform should execute lazily or not.
578581 Defaults to False
579582
@@ -619,9 +622,15 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619622 raise ValueError (f"data_array must have at least one spatial dimension, got { spatial_shape } ." )
620623 affine_ : np .ndarray
621624 affine_np : np .ndarray
625+ labels = self .labels
622626 if isinstance (data_array , MetaTensor ):
623627 affine_np , * _ = convert_data_type (data_array .peek_pending_affine (), np .ndarray )
624628 affine_ = to_affine_nd (sr , affine_np )
629+
630+ # Set up "labels" such that LPS tensors are handled correctly by default
631+ if self .labels is None and SpaceKeys (data_array .meta ["space" ]) == SpaceKeys .LPS :
632+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
633+
625634 else :
626635 warnings .warn ("`data_array` is not of type `MetaTensor, assuming affine to be identity." )
627636 # default to identity
@@ -640,7 +649,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640649 f"{ self .__class__ .__name__ } : spatial shape = { spatial_shape } , channels = { data_array .shape [0 ]} ,"
641650 "please make sure the input is in the channel-first format."
642651 )
643- dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = self . labels )
652+ dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = labels )
644653 if len (dst ) < sr :
645654 raise ValueError (
646655 f"axcodes must match data_array spatially, got axcodes={ len (self .axcodes )} D data_array={ sr } D"
@@ -653,8 +662,18 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653662 transform = self .pop_transform (data )
654663 # Create inverse transform
655664 orig_affine = transform [TraceKeys .EXTRA_INFO ]["original_affine" ]
656- orig_axcodes = nib .orientations .aff2axcodes (orig_affine )
657- inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = self .labels )
665+ labels = self .labels
666+
667+ # Set up "labels" such that LPS tensors are handled correctly by default
668+ if (
669+ isinstance (data , MetaTensor ) and
670+ self .labels is None and
671+ SpaceKeys (data .meta ["space" ]) == SpaceKeys .LPS
672+ ):
673+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
674+
675+ orig_axcodes = nib .orientations .aff2axcodes (orig_affine , labels = labels )
676+ inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = labels )
658677 # Apply inverse
659678 with inverse_transform .trace_transform (False ):
660679 data = inverse_transform (data )
0 commit comments