4646 VoteEnsemble ,
4747)
4848from monai .transforms .transform import MapTransform
49- from monai .transforms .utility .array import ToTensor
49+ from monai .transforms .utility .array import ApplyTransformToPoints , ToTensor
5050from monai .transforms .utils import allow_missing_keys_mode , convert_applied_interp_mode
5151from monai .utils import PostFix , convert_to_tensor , ensure_tuple , ensure_tuple_rep
5252from monai .utils .type_conversion import convert_to_dst_type
@@ -527,6 +527,12 @@ class GenerateHeatmapd(MapTransform):
527527 heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
528528 ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
529529 have the same shape, affine, and spatial metadata as the reference images.
530+ coordinate_space: coordinate system of the input points. ``"voxel"`` keeps the existing behavior and treats
531+ points as voxel coordinates in the output heatmap space. ``"world"`` transforms points to reference-image
532+ voxel coordinates with ``ref_image_keys`` before generating heatmaps. If the points are a ``MetaTensor``
533+ with their own affine, that affine is used as the point-to-world transform.
534+ visibility_keys: optional keys to store a boolean visibility mask for each point after coordinate conversion.
535+ The value is ``True`` when the transformed point is finite and inside the heatmap spatial shape.
530536 spatial_shape: spatial dimensions of output heatmaps. Can be:
531537 - Single shape (tuple): applied to all keys
532538 - List of shapes: one per key (must match keys length)
@@ -542,6 +548,7 @@ class GenerateHeatmapd(MapTransform):
542548 ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
543549 ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
544550 ValueError: If input points have invalid shape (must be 2D array with shape (N, D)).
551+ ValueError: If ``coordinate_space="world"`` is used without a reference affine.
545552
546553 Example:
547554 .. code-block:: python
@@ -573,12 +580,24 @@ class GenerateHeatmapd(MapTransform):
573580 result = transform(data)
574581 # result["landmarks_heatmap"] has shape (2, 64, 64)
575582
583+ # World-space landmarks can be converted against the reference affine.
584+ transform = GenerateHeatmapd(
585+ keys="landmarks_world",
586+ heatmap_keys="landmark_heatmap",
587+ ref_image_keys="image",
588+ coordinate_space="world",
589+ visibility_keys="landmark_visible",
590+ sigma=2.0,
591+ )
592+
576593 Notes:
577594 - Default heatmap_keys are generated as "{key}_heatmap" for each input key
578595 - Shape inference precedence: static spatial_shape > ref_image
579596 - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions
580597 - Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D
581598 - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
599+ - ``coordinate_space="world"`` assumes that the points and reference affine use the same world-coordinate
600+ convention. Convert LPS/RAS conventions before calling this transform if needed.
582601 """
583602
584603 backend = GenerateHeatmap .backend
@@ -590,13 +609,20 @@ class GenerateHeatmapd(MapTransform):
590609 _ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
591610 _ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)."
592611 _ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
612+ _ERR_VISIBILITY_KEYS_LEN = "Argument `visibility_keys` length must match keys length when provided."
613+ _ERR_COORDINATE_SPACE_LEN = (
614+ "Argument `coordinate_space` length must match keys length when providing per-key values."
615+ )
616+ _SUPPORTED_COORDINATE_SPACES = {"voxel" , "world" }
593617
594618 def __init__ (
595619 self ,
596620 keys : KeysCollection ,
597621 sigma : Sequence [float ] | float = 5.0 ,
598622 heatmap_keys : KeysCollection | None = None ,
599623 ref_image_keys : KeysCollection | None = None ,
624+ coordinate_space : str | Sequence [str ] = "voxel" ,
625+ visibility_keys : KeysCollection | None = None ,
600626 spatial_shape : Sequence [int ] | Sequence [Sequence [int ]] | None = None ,
601627 truncated : float = 4.0 ,
602628 normalize : bool = True ,
@@ -606,22 +632,32 @@ def __init__(
606632 super ().__init__ (keys , allow_missing_keys )
607633 self .heatmap_keys = self ._prepare_heatmap_keys (heatmap_keys )
608634 self .ref_image_keys = self ._prepare_optional_keys (ref_image_keys )
635+ self .coordinate_spaces = self ._prepare_coordinate_spaces (coordinate_space )
636+ self .visibility_keys = self ._prepare_visibility_keys (visibility_keys )
609637 self .static_shapes = self ._prepare_shapes (spatial_shape )
610638 self .generator = GenerateHeatmap (
611639 sigma = sigma , spatial_shape = None , truncated = truncated , normalize = normalize , dtype = dtype
612640 )
641+ self .world_to_voxel = ApplyTransformToPoints (dtype = torch .float32 , invert_affine = True )
613642
614643 def __call__ (self , data : Mapping [Hashable , Any ]) -> dict [Hashable , Any ]:
615644 d = dict (data )
616- for key , out_key , ref_key , static_shape in self .key_iterator (
617- d , self .heatmap_keys , self .ref_image_keys , self .static_shapes
645+ for key , out_key , ref_key , coordinate_space , visibility_key , static_shape in self .key_iterator (
646+ d ,
647+ self .heatmap_keys ,
648+ self .ref_image_keys ,
649+ self .coordinate_spaces ,
650+ self .visibility_keys ,
651+ self .static_shapes ,
618652 ):
619653 points = d [key ]
620654 shape = self ._determine_shape (points , static_shape , d , ref_key )
655+ reference = d .get (ref_key ) if ref_key is not None and ref_key in d else None
656+ points = self ._convert_points (points , reference , coordinate_space )
657+ visibility = self ._compute_visibility (points , shape )
621658 # The GenerateHeatmap transform will handle type conversion based on input points
622659 heatmap = self .generator (points , spatial_shape = shape )
623660 # If there's a reference image and we need to match its type/device
624- reference = d .get (ref_key ) if ref_key is not None and ref_key in d else None
625661 if reference is not None and isinstance (reference , (torch .Tensor , np .ndarray )):
626662 # Convert to match reference type and device while preserving heatmap's dtype
627663 heatmap , _ , _ = convert_to_dst_type (
@@ -632,6 +668,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
632668 heatmap .affine = reference .affine
633669 self ._update_spatial_metadata (heatmap , shape )
634670 d [out_key ] = heatmap
671+ if visibility_key is not None :
672+ d [visibility_key ] = self ._convert_visibility (visibility , d [key ])
635673 return d
636674
637675 def _prepare_heatmap_keys (self , heatmap_keys : KeysCollection | None ) -> tuple [Hashable , ...]:
@@ -654,6 +692,34 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has
654692 raise ValueError (self ._ERR_REF_KEYS_LEN )
655693 return tuple (keys_tuple )
656694
695+ def _prepare_visibility_keys (self , maybe_keys : KeysCollection | None ) -> tuple [Hashable | None , ...]:
696+ if maybe_keys is None :
697+ return (None ,) * len (self .keys )
698+ keys_tuple = ensure_tuple (maybe_keys )
699+ if len (keys_tuple ) == 1 and len (self .keys ) > 1 :
700+ keys_tuple = keys_tuple * len (self .keys )
701+ if len (keys_tuple ) != len (self .keys ):
702+ raise ValueError (self ._ERR_VISIBILITY_KEYS_LEN )
703+ return tuple (keys_tuple )
704+
705+ def _prepare_coordinate_spaces (self , coordinate_space : str | Sequence [str ]) -> tuple [str , ...]:
706+ if isinstance (coordinate_space , str ):
707+ spaces = (coordinate_space ,) * len (self .keys )
708+ else :
709+ spaces = ensure_tuple (coordinate_space )
710+ if len (spaces ) == 1 and len (self .keys ) > 1 :
711+ spaces = spaces * len (self .keys )
712+ if len (spaces ) != len (self .keys ):
713+ raise ValueError (self ._ERR_COORDINATE_SPACE_LEN )
714+ spaces = tuple (str (space ).lower () for space in spaces )
715+ invalid = set (spaces ) - self ._SUPPORTED_COORDINATE_SPACES
716+ if invalid :
717+ raise ValueError (
718+ f"Unsupported coordinate_space value: { sorted (invalid )} . "
719+ f"Supported values are { sorted (self ._SUPPORTED_COORDINATE_SPACES )} ."
720+ )
721+ return spaces
722+
657723 def _prepare_shapes (
658724 self , spatial_shape : Sequence [int ] | Sequence [Sequence [int ]] | None
659725 ) -> tuple [tuple [int , ...] | None , ...]:
@@ -711,6 +777,46 @@ def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int
711777 """Set spatial_shape explicitly from resolved shape."""
712778 heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in spatial_shape )
713779
780+ def _convert_points (self , points : Any , reference : Any , coordinate_space : str ) -> Any :
781+ if coordinate_space == "voxel" :
782+ return points
783+
784+ affine = self ._get_reference_affine (reference )
785+ points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
786+ if points_t .ndim != 2 :
787+ raise ValueError (f"{ self ._ERR_INVALID_POINTS } Got { points_t .ndim } D tensor." )
788+
789+ if isinstance (points , MetaTensor ):
790+ points_to_transform = points .unsqueeze (0 )
791+ else :
792+ points_to_transform = points_t .unsqueeze (0 )
793+ converted = self .world_to_voxel (points_to_transform , affine ).squeeze (0 )
794+ return converted
795+
796+ def _get_reference_affine (self , reference : Any ) -> torch .Tensor :
797+ if reference is None :
798+ raise ValueError ("coordinate_space='world' requires ref_image_keys or a reference affine." )
799+ affine = getattr (reference , "affine" , None )
800+ if affine is not None :
801+ return affine
802+ if isinstance (reference , (torch .Tensor , np .ndarray )) and reference .shape in ((3 , 3 ), (4 , 4 )):
803+ return reference
804+ raise ValueError ("coordinate_space='world' requires reference data with an affine matrix." )
805+
806+ def _compute_visibility (self , points : Any , spatial_shape : tuple [int , ...]) -> torch .Tensor :
807+ points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
808+ if points_t .ndim != 2 :
809+ raise ValueError (f"{ self ._ERR_INVALID_POINTS } Got { points_t .ndim } D tensor." )
810+ bounds = torch .as_tensor (spatial_shape , dtype = points_t .dtype , device = points_t .device )
811+ return torch .isfinite (points_t ).all (dim = 1 ) & (points_t >= 0 ).all (dim = 1 ) & (points_t < bounds ).all (dim = 1 )
812+
813+ def _convert_visibility (self , visibility : torch .Tensor , points : Any ) -> NdarrayOrTensor :
814+ if isinstance (points , (MetaTensor , torch .Tensor )):
815+ return visibility .to (device = points .device , dtype = torch .bool )
816+ if isinstance (points , np .ndarray ):
817+ return visibility .cpu ().numpy ().astype (bool )
818+ return visibility .to (dtype = torch .bool )
819+
714820
715821GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
716822
0 commit comments