Skip to content

Commit b7c5c85

Browse files
committed
Enhance ExtractDataKeyFromMetaKeyd to work with MetaTensor
When LoadImaged is used with image_only=True (the default), the loaded data is a MetaTensor with metadata accessible via .meta attribute. Previously, ExtractDataKeyFromMetaKeyd could only extract keys from metadata dictionaries (image_only=False scenario). This change adds support for MetaTensor by detecting if the meta_key references a MetaTensor instance and extracting from its .meta attribute instead of treating it as a plain dictionary. Fixes #7562 Signed-off-by: haoyu-haoyu <haoyu-haoyu@users.noreply.github.com> Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com>
1 parent daaedaa commit b7c5c85

File tree

2 files changed

+150
-3
lines changed

2 files changed

+150
-3
lines changed

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
from monai.data.meta_tensor import MetaTensor
24+
2325
from monai.transforms import InvertibleTransform
2426
from monai.transforms.croppad.array import SpatialCrop
2527
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +35,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3335
Moves keys from meta to data. It is useful when a dataset of paired samples
3436
is loaded and certain keys should be moved from meta to data.
3537
38+
This transform supports two modes:
39+
40+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
41+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
42+
extracted directly from that dictionary.
43+
44+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
45+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
46+
extracted from its ``.meta`` attribute.
47+
3648
Args:
3749
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
50+
meta_key: the key in the data dictionary where the metadata source is
51+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3952
allow_missing_keys: don't raise exception if key is missing
4053
4154
Example:
4255
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4356
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4457
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
58+
59+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
60+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
61+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
62+
63+
li = LoadImaged(keys="image") # image_only=True by default
64+
dat = li({"image": "image.nii"})
65+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
66+
dat = e(dat)
67+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4568
"""
4669

4770
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +81,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5881
the new data dictionary
5982
"""
6083
d = dict(data)
84+
meta_obj = d[self.meta_key]
85+
86+
# If meta_key references a MetaTensor, extract from its .meta attribute;
87+
# otherwise treat it as a metadata dictionary directly.
88+
if isinstance(meta_obj, MetaTensor):
89+
meta_dict: dict = meta_obj.meta
90+
else:
91+
meta_dict = dict(meta_obj)
92+
6193
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
94+
if key in meta_dict:
95+
d[key] = meta_dict[key] # type: ignore
6496
elif not self.allow_missing_keys:
6597
raise KeyError(
6698
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd
19+
from monai.data import MetaTensor
20+
21+
22+
class TestExtractDataKeyFromMetaKeyd(unittest.TestCase):
23+
"""Tests for ExtractDataKeyFromMetaKeyd covering both dict-based and MetaTensor-based metadata."""
24+
25+
def test_extract_from_dict(self):
26+
"""Test extracting keys from a plain metadata dictionary (image_only=False scenario)."""
27+
data = {
28+
"image": torch.zeros(1, 2, 2),
29+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
30+
}
31+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image_meta_dict")
32+
result = transform(data)
33+
self.assertIn("filename_or_obj", result)
34+
self.assertEqual(result["filename_or_obj"], "image.nii")
35+
self.assertEqual(result["image_meta_dict"]["filename_or_obj"], result["filename_or_obj"])
36+
37+
def test_extract_from_metatensor(self):
38+
"""Test extracting keys from a MetaTensor's .meta attribute (image_only=True scenario)."""
39+
mt = MetaTensor(torch.zeros(1, 2, 2))
40+
mt.meta["filename_or_obj"] = "image.nii"
41+
mt.meta["spatial_shape"] = [2, 2]
42+
data = {"image": mt}
43+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
44+
result = transform(data)
45+
self.assertIn("filename_or_obj", result)
46+
self.assertEqual(result["filename_or_obj"], "image.nii")
47+
self.assertEqual(result["image"].meta["filename_or_obj"], result["filename_or_obj"])
48+
49+
def test_extract_multiple_keys_from_metatensor(self):
50+
"""Test extracting multiple keys from a MetaTensor."""
51+
mt = MetaTensor(torch.zeros(1, 2, 2))
52+
mt.meta["filename_or_obj"] = "image.nii"
53+
mt.meta["spatial_shape"] = [2, 2]
54+
data = {"image": mt}
55+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image")
56+
result = transform(data)
57+
self.assertIn("filename_or_obj", result)
58+
self.assertIn("spatial_shape", result)
59+
self.assertEqual(result["filename_or_obj"], "image.nii")
60+
self.assertEqual(result["spatial_shape"], [2, 2])
61+
62+
def test_extract_multiple_keys_from_dict(self):
63+
"""Test extracting multiple keys from a plain dictionary."""
64+
data = {
65+
"image": torch.zeros(1, 2, 2),
66+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
67+
}
68+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image_meta_dict")
69+
result = transform(data)
70+
self.assertIn("filename_or_obj", result)
71+
self.assertIn("spatial_shape", result)
72+
self.assertEqual(result["filename_or_obj"], "image.nii")
73+
self.assertEqual(result["spatial_shape"], [2, 2])
74+
75+
def test_missing_key_raises(self):
76+
"""Test that a missing key raises KeyError when allow_missing_keys=False."""
77+
mt = MetaTensor(torch.zeros(1, 2, 2))
78+
mt.meta["filename_or_obj"] = "image.nii"
79+
data = {"image": mt}
80+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image")
81+
with self.assertRaises(KeyError):
82+
transform(data)
83+
84+
def test_missing_key_allowed_metatensor(self):
85+
"""Test that a missing key is silently skipped when allow_missing_keys=True with MetaTensor."""
86+
mt = MetaTensor(torch.zeros(1, 2, 2))
87+
mt.meta["filename_or_obj"] = "image.nii"
88+
data = {"image": mt}
89+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image", allow_missing_keys=True)
90+
result = transform(data)
91+
self.assertNotIn("nonexistent_key", result)
92+
93+
def test_missing_key_allowed_dict(self):
94+
"""Test that a missing key is silently skipped when allow_missing_keys=True with dict."""
95+
data = {"image": torch.zeros(1, 2, 2), "image_meta_dict": {"filename_or_obj": "image.nii"}}
96+
transform = ExtractDataKeyFromMetaKeyd(
97+
keys="nonexistent_key", meta_key="image_meta_dict", allow_missing_keys=True
98+
)
99+
result = transform(data)
100+
self.assertNotIn("nonexistent_key", result)
101+
102+
def test_original_data_preserved_metatensor(self):
103+
"""Test that the original MetaTensor remains in the data dictionary."""
104+
mt = MetaTensor(torch.ones(1, 2, 2))
105+
mt.meta["filename_or_obj"] = "image.nii"
106+
data = {"image": mt}
107+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
108+
result = transform(data)
109+
self.assertIn("image", result)
110+
self.assertIsInstance(result["image"], MetaTensor)
111+
self.assertTrue(torch.equal(result["image"], mt))
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)