Skip to content

Commit d99cdbe

Browse files
committed
Enhance ExtractDataKeyFromMetaKeyd to work with MetaTensor
When LoadImaged is used with image_only=True (the default), the loaded data is a MetaTensor with metadata accessible via .meta attribute. Previously, ExtractDataKeyFromMetaKeyd could only extract keys from metadata dictionaries (image_only=False scenario). This change adds support for MetaTensor by detecting if the meta_key references a MetaTensor instance and extracting from its .meta attribute instead of treating it as a plain dictionary. Fixes #7562 Signed-off-by: haoyu-haoyu <haoyu-haoyu@users.noreply.github.com> Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com>
1 parent daaedaa commit d99cdbe

File tree

2 files changed

+147
-3
lines changed

2 files changed

+147
-3
lines changed

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
from monai.data.meta_tensor import MetaTensor
2324
from monai.transforms import InvertibleTransform
2425
from monai.transforms.croppad.array import SpatialCrop
2526
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334
Moves keys from meta to data. It is useful when a dataset of paired samples
3435
is loaded and certain keys should be moved from meta to data.
3536
37+
This transform supports two modes:
38+
39+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+
extracted directly from that dictionary.
42+
43+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+
extracted from its ``.meta`` attribute.
46+
3647
Args:
3748
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
49+
meta_key: the key in the data dictionary where the metadata source is
50+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951
allow_missing_keys: don't raise exception if key is missing
4052
4153
Example:
4254
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+
li = LoadImaged(keys="image") # image_only=True by default
63+
dat = li({"image": "image.nii"})
64+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+
dat = e(dat)
66+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567
"""
4668

4769
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880
the new data dictionary
5981
"""
6082
d = dict(data)
83+
meta_obj = d[self.meta_key]
84+
85+
# If meta_key references a MetaTensor, extract from its .meta attribute;
86+
# otherwise treat it as a metadata dictionary directly.
87+
if isinstance(meta_obj, MetaTensor):
88+
meta_dict = meta_obj.meta
89+
else:
90+
meta_dict = meta_obj
91+
6192
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
93+
if key in meta_dict:
94+
d[key] = meta_dict[key] # type: ignore
6495
elif not self.allow_missing_keys:
6596
raise KeyError(
6697
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd
19+
from monai.data import MetaTensor
20+
21+
22+
class TestExtractDataKeyFromMetaKeyd(unittest.TestCase):
23+
"""Tests for ExtractDataKeyFromMetaKeyd covering both dict-based and MetaTensor-based metadata."""
24+
25+
def test_extract_from_dict(self):
26+
"""Test extracting keys from a plain metadata dictionary (image_only=False scenario)."""
27+
data = {
28+
"image": torch.zeros(1, 2, 2),
29+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
30+
}
31+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image_meta_dict")
32+
result = transform(data)
33+
self.assertIn("filename_or_obj", result)
34+
self.assertEqual(result["filename_or_obj"], "image.nii")
35+
self.assertEqual(result["image_meta_dict"]["filename_or_obj"], result["filename_or_obj"])
36+
37+
def test_extract_from_metatensor(self):
38+
"""Test extracting keys from a MetaTensor's .meta attribute (image_only=True scenario)."""
39+
meta = {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]}
40+
mt = MetaTensor(torch.zeros(1, 2, 2), meta=meta)
41+
data = {"image": mt}
42+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
43+
result = transform(data)
44+
self.assertIn("filename_or_obj", result)
45+
self.assertEqual(result["filename_or_obj"], "image.nii")
46+
self.assertEqual(result["image"].meta["filename_or_obj"], result["filename_or_obj"])
47+
48+
def test_extract_multiple_keys_from_metatensor(self):
49+
"""Test extracting multiple keys from a MetaTensor."""
50+
meta = {"filename_or_obj": "image.nii", "spatial_shape": [2, 2], "affine": "identity"}
51+
mt = MetaTensor(torch.zeros(1, 2, 2), meta=meta)
52+
data = {"image": mt}
53+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image")
54+
result = transform(data)
55+
self.assertIn("filename_or_obj", result)
56+
self.assertIn("spatial_shape", result)
57+
self.assertEqual(result["filename_or_obj"], "image.nii")
58+
self.assertEqual(result["spatial_shape"], [2, 2])
59+
60+
def test_extract_multiple_keys_from_dict(self):
61+
"""Test extracting multiple keys from a plain dictionary."""
62+
data = {
63+
"image": torch.zeros(1, 2, 2),
64+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
65+
}
66+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image_meta_dict")
67+
result = transform(data)
68+
self.assertIn("filename_or_obj", result)
69+
self.assertIn("spatial_shape", result)
70+
self.assertEqual(result["filename_or_obj"], "image.nii")
71+
self.assertEqual(result["spatial_shape"], [2, 2])
72+
73+
def test_missing_key_raises(self):
74+
"""Test that a missing key raises KeyError when allow_missing_keys=False."""
75+
meta = {"filename_or_obj": "image.nii"}
76+
mt = MetaTensor(torch.zeros(1, 2, 2), meta=meta)
77+
data = {"image": mt}
78+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image")
79+
with self.assertRaises(KeyError):
80+
transform(data)
81+
82+
def test_missing_key_allowed_metatensor(self):
83+
"""Test that a missing key is silently skipped when allow_missing_keys=True with MetaTensor."""
84+
meta = {"filename_or_obj": "image.nii"}
85+
mt = MetaTensor(torch.zeros(1, 2, 2), meta=meta)
86+
data = {"image": mt}
87+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image", allow_missing_keys=True)
88+
result = transform(data)
89+
self.assertNotIn("nonexistent_key", result)
90+
91+
def test_missing_key_allowed_dict(self):
92+
"""Test that a missing key is silently skipped when allow_missing_keys=True with dict."""
93+
data = {"image": torch.zeros(1, 2, 2), "image_meta_dict": {"filename_or_obj": "image.nii"}}
94+
transform = ExtractDataKeyFromMetaKeyd(
95+
keys="nonexistent_key", meta_key="image_meta_dict", allow_missing_keys=True
96+
)
97+
result = transform(data)
98+
self.assertNotIn("nonexistent_key", result)
99+
100+
def test_original_data_preserved_metatensor(self):
101+
"""Test that the original MetaTensor remains in the data dictionary."""
102+
meta = {"filename_or_obj": "image.nii"}
103+
mt = MetaTensor(torch.ones(1, 2, 2), meta=meta)
104+
data = {"image": mt}
105+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
106+
result = transform(data)
107+
self.assertIn("image", result)
108+
self.assertIsInstance(result["image"], MetaTensor)
109+
self.assertTrue(torch.equal(result["image"], mt))
110+
111+
112+
if __name__ == "__main__":
113+
unittest.main()

0 commit comments

Comments
 (0)