Skip to content

Commit dbbfbb3

Browse files
committed
Add affine-aware landmark heatmap generation
Signed-off-by: Mustafa Merchant <mustafamerchant072@gmail.com>
1 parent b7d14c8 commit dbbfbb3

3 files changed

Lines changed: 218 additions & 4 deletions

File tree

docs/source/whatsnew_1_6.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Nested dot-notation key access in `ConfigParser`.
1010
- Auto3DSeg algo serialization migrated from pickle to JSON for improved security and portability.
1111
- Global coordinates support in spatial crop transforms. These now support global coordinate mode, allowing crops to be specified in world/global coordinates rather than local image indices, improving interoperability with physical-space annotations.
12+
- `GenerateHeatmapd` can convert world-coordinate landmarks to reference-image voxel space and emit landmark visibility masks.
1213
- `SoftclDiceLoss` and `SoftDiceclDiceLoss` enhanced with `DiceLoss`-compatible API
1314
- Variable expansion hardening has been added to the nnUNet app to eliminate code injection attacks when composing shell command lines, addressing concerns in [GHSA-rghg-q7wp-9767](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-rghg-q7wp-9767).
1415
- `NumpyReader` has been updated with an `allow_pickle` boolean argument to enable/disable pickle loading from `.npy/.npz` files. This was previously hard-coded to be enabled, but is now defined by this argument and disabled by default. This addresses [GHSA-qxq5-qhx6-94qw](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-qxq5-qhx6-94qw).

monai/transforms/post/dictionary.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
VoteEnsemble,
4747
)
4848
from monai.transforms.transform import MapTransform
49-
from monai.transforms.utility.array import ToTensor
49+
from monai.transforms.utility.array import ApplyTransformToPoints, ToTensor
5050
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
5151
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
5252
from monai.utils.type_conversion import convert_to_dst_type
@@ -527,6 +527,12 @@ class GenerateHeatmapd(MapTransform):
527527
heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
528528
ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
529529
have the same shape, affine, and spatial metadata as the reference images.
530+
coordinate_space: coordinate system of the input points. ``"voxel"`` keeps the existing behavior and treats
531+
points as voxel coordinates in the output heatmap space. ``"world"`` transforms points to reference-image
532+
voxel coordinates with ``ref_image_keys`` before generating heatmaps. If the points are a ``MetaTensor``
533+
with their own affine, that affine is used as the point-to-world transform.
534+
visibility_keys: optional keys to store a boolean visibility mask for each point after coordinate conversion.
535+
The value is ``True`` when the transformed point is finite and inside the heatmap spatial shape.
530536
spatial_shape: spatial dimensions of output heatmaps. Can be:
531537
- Single shape (tuple): applied to all keys
532538
- List of shapes: one per key (must match keys length)
@@ -542,6 +548,7 @@ class GenerateHeatmapd(MapTransform):
542548
ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
543549
ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
544550
ValueError: If input points have invalid shape (must be 2D array with shape (N, D)).
551+
ValueError: If ``coordinate_space="world"`` is used without a reference affine.
545552
546553
Example:
547554
.. code-block:: python
@@ -573,12 +580,24 @@ class GenerateHeatmapd(MapTransform):
573580
result = transform(data)
574581
# result["landmarks_heatmap"] has shape (2, 64, 64)
575582
583+
# World-space landmarks can be converted against the reference affine.
584+
transform = GenerateHeatmapd(
585+
keys="landmarks_world",
586+
heatmap_keys="landmark_heatmap",
587+
ref_image_keys="image",
588+
coordinate_space="world",
589+
visibility_keys="landmark_visible",
590+
sigma=2.0,
591+
)
592+
576593
Notes:
577594
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
578595
- Shape inference precedence: static spatial_shape > ref_image
579596
- Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions
580597
- Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D
581598
- When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
599+
- ``coordinate_space="world"`` assumes that the points and reference affine use the same world-coordinate
600+
convention. Convert LPS/RAS conventions before calling this transform if needed.
582601
"""
583602

584603
backend = GenerateHeatmap.backend
@@ -590,13 +609,20 @@ class GenerateHeatmapd(MapTransform):
590609
_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
591610
_ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)."
592611
_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
612+
_ERR_VISIBILITY_KEYS_LEN = "Argument `visibility_keys` length must match keys length when provided."
613+
_ERR_COORDINATE_SPACE_LEN = (
614+
"Argument `coordinate_space` length must match keys length when providing per-key values."
615+
)
616+
_SUPPORTED_COORDINATE_SPACES = {"voxel", "world"}
593617

594618
def __init__(
595619
self,
596620
keys: KeysCollection,
597621
sigma: Sequence[float] | float = 5.0,
598622
heatmap_keys: KeysCollection | None = None,
599623
ref_image_keys: KeysCollection | None = None,
624+
coordinate_space: str | Sequence[str] = "voxel",
625+
visibility_keys: KeysCollection | None = None,
600626
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,
601627
truncated: float = 4.0,
602628
normalize: bool = True,
@@ -606,22 +632,32 @@ def __init__(
606632
super().__init__(keys, allow_missing_keys)
607633
self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys)
608634
self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)
635+
self.coordinate_spaces = self._prepare_coordinate_spaces(coordinate_space)
636+
self.visibility_keys = self._prepare_visibility_keys(visibility_keys)
609637
self.static_shapes = self._prepare_shapes(spatial_shape)
610638
self.generator = GenerateHeatmap(
611639
sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype
612640
)
641+
self.world_to_voxel = ApplyTransformToPoints(dtype=torch.float32, invert_affine=True)
613642

614643
def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
615644
d = dict(data)
616-
for key, out_key, ref_key, static_shape in self.key_iterator(
617-
d, self.heatmap_keys, self.ref_image_keys, self.static_shapes
645+
for key, out_key, ref_key, coordinate_space, visibility_key, static_shape in self.key_iterator(
646+
d,
647+
self.heatmap_keys,
648+
self.ref_image_keys,
649+
self.coordinate_spaces,
650+
self.visibility_keys,
651+
self.static_shapes,
618652
):
619653
points = d[key]
620654
shape = self._determine_shape(points, static_shape, d, ref_key)
655+
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
656+
points = self._convert_points(points, reference, coordinate_space)
657+
visibility = self._compute_visibility(points, shape)
621658
# The GenerateHeatmap transform will handle type conversion based on input points
622659
heatmap = self.generator(points, spatial_shape=shape)
623660
# If there's a reference image and we need to match its type/device
624-
reference = d.get(ref_key) if ref_key is not None and ref_key in d else None
625661
if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):
626662
# Convert to match reference type and device while preserving heatmap's dtype
627663
heatmap, _, _ = convert_to_dst_type(
@@ -632,6 +668,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
632668
heatmap.affine = reference.affine
633669
self._update_spatial_metadata(heatmap, shape)
634670
d[out_key] = heatmap
671+
if visibility_key is not None:
672+
d[visibility_key] = self._convert_visibility(visibility, d[key])
635673
return d
636674

637675
def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:
@@ -654,6 +692,34 @@ def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Has
654692
raise ValueError(self._ERR_REF_KEYS_LEN)
655693
return tuple(keys_tuple)
656694

695+
def _prepare_visibility_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:
696+
if maybe_keys is None:
697+
return (None,) * len(self.keys)
698+
keys_tuple = ensure_tuple(maybe_keys)
699+
if len(keys_tuple) == 1 and len(self.keys) > 1:
700+
keys_tuple = keys_tuple * len(self.keys)
701+
if len(keys_tuple) != len(self.keys):
702+
raise ValueError(self._ERR_VISIBILITY_KEYS_LEN)
703+
return tuple(keys_tuple)
704+
705+
def _prepare_coordinate_spaces(self, coordinate_space: str | Sequence[str]) -> tuple[str, ...]:
706+
if isinstance(coordinate_space, str):
707+
spaces = (coordinate_space,) * len(self.keys)
708+
else:
709+
spaces = ensure_tuple(coordinate_space)
710+
if len(spaces) == 1 and len(self.keys) > 1:
711+
spaces = spaces * len(self.keys)
712+
if len(spaces) != len(self.keys):
713+
raise ValueError(self._ERR_COORDINATE_SPACE_LEN)
714+
spaces = tuple(str(space).lower() for space in spaces)
715+
invalid = set(spaces) - self._SUPPORTED_COORDINATE_SPACES
716+
if invalid:
717+
raise ValueError(
718+
f"Unsupported coordinate_space value: {sorted(invalid)}. "
719+
f"Supported values are {sorted(self._SUPPORTED_COORDINATE_SPACES)}."
720+
)
721+
return spaces
722+
657723
def _prepare_shapes(
658724
self, spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None
659725
) -> tuple[tuple[int, ...] | None, ...]:
@@ -711,6 +777,46 @@ def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int
711777
"""Set spatial_shape explicitly from resolved shape."""
712778
heatmap.meta["spatial_shape"] = tuple(int(v) for v in spatial_shape)
713779

780+
def _convert_points(self, points: Any, reference: Any, coordinate_space: str) -> Any:
781+
if coordinate_space == "voxel":
782+
return points
783+
784+
affine = self._get_reference_affine(reference)
785+
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
786+
if points_t.ndim != 2:
787+
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
788+
789+
if isinstance(points, MetaTensor):
790+
points_to_transform = points.unsqueeze(0)
791+
else:
792+
points_to_transform = points_t.unsqueeze(0)
793+
converted = self.world_to_voxel(points_to_transform, affine).squeeze(0)
794+
return converted
795+
796+
def _get_reference_affine(self, reference: Any) -> torch.Tensor:
797+
if reference is None:
798+
raise ValueError("coordinate_space='world' requires ref_image_keys or a reference affine.")
799+
affine = getattr(reference, "affine", None)
800+
if affine is not None:
801+
return affine
802+
if isinstance(reference, (torch.Tensor, np.ndarray)) and reference.shape in ((3, 3), (4, 4)):
803+
return reference
804+
raise ValueError("coordinate_space='world' requires reference data with an affine matrix.")
805+
806+
def _compute_visibility(self, points: Any, spatial_shape: tuple[int, ...]) -> torch.Tensor:
807+
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
808+
if points_t.ndim != 2:
809+
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
810+
bounds = torch.as_tensor(spatial_shape, dtype=points_t.dtype, device=points_t.device)
811+
return torch.isfinite(points_t).all(dim=1) & (points_t >= 0).all(dim=1) & (points_t < bounds).all(dim=1)
812+
813+
def _convert_visibility(self, visibility: torch.Tensor, points: Any) -> NdarrayOrTensor:
814+
if isinstance(points, (MetaTensor, torch.Tensor)):
815+
return visibility.to(device=points.device, dtype=torch.bool)
816+
if isinstance(points, np.ndarray):
817+
return visibility.cpu().numpy().astype(bool)
818+
return visibility.to(dtype=torch.bool)
819+
714820

715821
GenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd
716822

tests/transforms/test_generate_heatmapd.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
from monai.transforms.post.dictionary import GenerateHeatmapd
2222
from tests.test_utils import assert_allclose
2323

24+
25+
def _peak_coord(channel: torch.Tensor) -> torch.Tensor:
26+
idx = torch.argmax(channel)
27+
return torch.stack(torch.unravel_index(idx, channel.shape))
28+
29+
2430
# Test cases for dictionary transforms with reference image
2531
# Only test with non-MetaTensor types to avoid affine conflicts
2632
TEST_CASES_WITH_REF = [
@@ -220,6 +226,107 @@ def test_metatensor_points_with_ref(self):
220226
# Heatmap should inherit affine from the reference image
221227
assert_allclose(heatmap.affine, image.affine, type_test=False)
222228

229+
def test_world_points_with_reference_affine_and_visibility(self):
230+
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
231+
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
232+
image.meta["spatial_shape"] = (8, 8, 8)
233+
points = torch.tensor(
234+
[
235+
[4.0, 6.0, 8.0], # voxel coordinate [2, 3, 4]
236+
[20.0, 0.0, 0.0], # out of bounds after affine conversion
237+
[float("nan"), 0.0, 0.0],
238+
],
239+
dtype=torch.float32,
240+
)
241+
242+
transform = GenerateHeatmapd(
243+
keys="points",
244+
heatmap_keys="heatmap",
245+
ref_image_keys="image",
246+
coordinate_space="world",
247+
visibility_keys="visible",
248+
sigma=1.0,
249+
)
250+
result = transform({"points": points, "image": image})
251+
252+
heatmap = result["heatmap"]
253+
self.assertIsInstance(heatmap, MetaTensor)
254+
self.assertEqual(tuple(heatmap.shape), (3, 8, 8, 8))
255+
assert_allclose(_peak_coord(heatmap[0]), torch.tensor([2, 3, 4]), type_test=False)
256+
self.assertTrue(torch.equal(result["visible"], torch.tensor([True, False, False])))
257+
self.assertGreater(heatmap[0].max(), 0.99)
258+
self.assertEqual(float(heatmap[1].max()), 0.0)
259+
self.assertEqual(float(heatmap[2].max()), 0.0)
260+
261+
def test_world_points_with_translated_rotated_affine(self):
262+
affine = torch.tensor(
263+
[
264+
[0.0, -2.0, 0.0, 10.0],
265+
[3.0, 0.0, 0.0, 20.0],
266+
[0.0, 0.0, 4.0, 30.0],
267+
[0.0, 0.0, 0.0, 1.0],
268+
],
269+
dtype=torch.float32,
270+
)
271+
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)
272+
image.meta["spatial_shape"] = (8, 8, 8)
273+
voxel_point = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32)
274+
world_point = affine[:3, :3] @ voxel_point + affine[:3, 3]
275+
276+
transform = GenerateHeatmapd(
277+
keys="points",
278+
heatmap_keys="heatmap",
279+
ref_image_keys="image",
280+
coordinate_space="world",
281+
visibility_keys="visible",
282+
sigma=1.0,
283+
)
284+
result = transform({"points": world_point[None], "image": image})
285+
286+
assert_allclose(_peak_coord(result["heatmap"][0]), voxel_point.to(torch.long), type_test=False)
287+
self.assertTrue(torch.equal(result["visible"], torch.tensor([True])))
288+
289+
def test_world_metatensor_points_use_point_affine(self):
290+
image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=torch.eye(4))
291+
image.meta["spatial_shape"] = (8, 8, 8)
292+
points_affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
293+
points = MetaTensor(torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), affine=points_affine)
294+
295+
transform = GenerateHeatmapd(
296+
keys="points",
297+
heatmap_keys="heatmap",
298+
ref_image_keys="image",
299+
coordinate_space="world",
300+
visibility_keys="visible",
301+
sigma=1.0,
302+
)
303+
result = transform({"points": points, "image": image})
304+
305+
assert_allclose(_peak_coord(result["heatmap"][0]), torch.tensor([2, 4, 6]), type_test=False)
306+
self.assertIsInstance(result["visible"], torch.Tensor)
307+
self.assertNotIsInstance(result["visible"], MetaTensor)
308+
self.assertTrue(bool(result["visible"][0]))
309+
310+
def test_world_points_require_reference_affine(self):
311+
transform = GenerateHeatmapd(
312+
keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8, 8), coordinate_space="world"
313+
)
314+
with self.assertRaisesRegex(ValueError, "reference|affine|ref_image_keys"):
315+
transform({"points": torch.zeros((1, 3), dtype=torch.float32)})
316+
317+
def test_invalid_coordinate_space_raises(self):
318+
with self.assertRaisesRegex(ValueError, "coordinate_space"):
319+
GenerateHeatmapd(keys="points", heatmap_keys="heatmap", spatial_shape=(8, 8), coordinate_space="scanner")
320+
321+
def test_visibility_key_length_mismatch_raises(self):
322+
with self.assertRaises(ValueError):
323+
GenerateHeatmapd(
324+
keys=["pts1", "pts2"],
325+
heatmap_keys=["hm1", "hm2"],
326+
visibility_keys=["visible1", "visible2", "visible3"],
327+
spatial_shape=(8, 8),
328+
)
329+
223330

224331
if __name__ == "__main__":
225332
unittest.main()

0 commit comments

Comments
 (0)