Skip to content

Commit e3b6c38

Browse files
committed
Enhance ExtractDataKeyFromMetaKeyd to work with MetaTensor
When LoadImaged is used with image_only=True (the default), the loaded data is a MetaTensor with metadata accessible via .meta attribute. Previously, ExtractDataKeyFromMetaKeyd could only extract keys from metadata dictionaries (image_only=False scenario). This change adds support for MetaTensor by detecting if the meta_key references a MetaTensor instance and extracting from its .meta attribute instead of treating it as a plain dictionary. Fixes #7562 Signed-off-by: haoyu-haoyu <haoyu-haoyu@users.noreply.github.com> Signed-off-by: SexyERIC0723 <haoyuwang144@gmail.com>
1 parent daaedaa commit e3b6c38

File tree

3 files changed

+157
-3
lines changed

3 files changed

+157
-3
lines changed

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
2121
from monai.config import DtypeLike, KeysCollection
2222
from monai.config.type_definitions import NdarrayOrTensor
23+
2324
from monai.transforms import InvertibleTransform
2425
from monai.transforms.croppad.array import SpatialCrop
2526
from monai.transforms.intensity.array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334
Moves keys from meta to data. It is useful when a dataset of paired samples
3435
is loaded and certain keys should be moved from meta to data.
3536
37+
This transform supports two modes:
38+
39+
1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+
``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+
extracted directly from that dictionary.
42+
43+
2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+
``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+
extracted from its ``.meta`` attribute.
46+
3647
Args:
3748
keys: keys to be transferred from meta to data
38-
meta_key: the meta key where all the meta-data is stored
49+
meta_key: the key in the data dictionary where the metadata source is
50+
stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951
allow_missing_keys: don't raise exception if key is missing
4052
4153
Example:
4254
When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355
but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456
In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+
When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+
data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+
set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+
li = LoadImaged(keys="image") # image_only=True by default
63+
dat = li({"image": "image.nii"})
64+
e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+
dat = e(dat)
66+
assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567
"""
4668

4769
def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:
@@ -58,9 +80,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880
the new data dictionary
5981
"""
6082
d = dict(data)
83+
meta_obj = d[self.meta_key]
84+
85+
# MetaTensor.__contains__ and __getitem__ delegate string keys to
86+
# .meta, so both MetaTensor and plain dict objects work here without
87+
# an isinstance check.
6188
for key in self.keys:
62-
if key in d[self.meta_key]:
63-
d[key] = d[self.meta_key][key] # type: ignore
89+
if key in meta_obj:
90+
d[key] = meta_obj[key] # type: ignore
6491
elif not self.allow_missing_keys:
6592
raise KeyError(
6693
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"

monai/data/meta_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ def __init__(
171171
if MetaKeys.SPACE not in self.meta:
172172
self.meta[MetaKeys.SPACE] = SpaceKeys.RAS # defaulting to the right-anterior-superior space
173173

174+
def __contains__(self, key):
175+
"""Allow string-key lookups to check ``.meta``, e.g. ``"filename_or_obj" in meta_tensor``."""
176+
if isinstance(key, str):
177+
return key in self.meta
178+
return super().__contains__(key)
179+
180+
def __getitem__(self, key):
181+
"""Allow string-key indexing to access ``.meta``, e.g. ``meta_tensor["filename_or_obj"]``."""
182+
if isinstance(key, str):
183+
return self.meta[key]
184+
return super().__getitem__(key)
185+
174186
@staticmethod
175187
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
176188
"""
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
18+
from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd
19+
from monai.data import MetaTensor
20+
21+
22+
class TestExtractDataKeyFromMetaKeyd(unittest.TestCase):
23+
"""Tests for ExtractDataKeyFromMetaKeyd covering both dict-based and MetaTensor-based metadata."""
24+
25+
def test_extract_from_dict(self):
26+
"""Test extracting keys from a plain metadata dictionary (image_only=False scenario)."""
27+
data = {
28+
"image": torch.zeros(1, 2, 2),
29+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
30+
}
31+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image_meta_dict")
32+
result = transform(data)
33+
self.assertIn("filename_or_obj", result)
34+
self.assertEqual(result["filename_or_obj"], "image.nii")
35+
self.assertEqual(result["image_meta_dict"]["filename_or_obj"], result["filename_or_obj"])
36+
37+
def test_extract_from_metatensor(self):
38+
"""Test extracting keys from a MetaTensor's .meta attribute (image_only=True scenario)."""
39+
mt = MetaTensor(torch.zeros(1, 2, 2))
40+
mt.meta["filename_or_obj"] = "image.nii"
41+
mt.meta["spatial_shape"] = [2, 2]
42+
data = {"image": mt}
43+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
44+
result = transform(data)
45+
self.assertIn("filename_or_obj", result)
46+
self.assertEqual(result["filename_or_obj"], "image.nii")
47+
self.assertEqual(result["image"].meta["filename_or_obj"], result["filename_or_obj"])
48+
49+
def test_extract_multiple_keys_from_metatensor(self):
50+
"""Test extracting multiple keys from a MetaTensor."""
51+
mt = MetaTensor(torch.zeros(1, 2, 2))
52+
mt.meta["filename_or_obj"] = "image.nii"
53+
mt.meta["spatial_shape"] = [2, 2]
54+
data = {"image": mt}
55+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image")
56+
result = transform(data)
57+
self.assertIn("filename_or_obj", result)
58+
self.assertIn("spatial_shape", result)
59+
self.assertEqual(result["filename_or_obj"], "image.nii")
60+
self.assertEqual(result["spatial_shape"], [2, 2])
61+
62+
def test_extract_multiple_keys_from_dict(self):
63+
"""Test extracting multiple keys from a plain dictionary."""
64+
data = {
65+
"image": torch.zeros(1, 2, 2),
66+
"image_meta_dict": {"filename_or_obj": "image.nii", "spatial_shape": [2, 2]},
67+
}
68+
transform = ExtractDataKeyFromMetaKeyd(keys=["filename_or_obj", "spatial_shape"], meta_key="image_meta_dict")
69+
result = transform(data)
70+
self.assertIn("filename_or_obj", result)
71+
self.assertIn("spatial_shape", result)
72+
self.assertEqual(result["filename_or_obj"], "image.nii")
73+
self.assertEqual(result["spatial_shape"], [2, 2])
74+
75+
def test_missing_key_raises(self):
76+
"""Test that a missing key raises KeyError when allow_missing_keys=False."""
77+
mt = MetaTensor(torch.zeros(1, 2, 2))
78+
mt.meta["filename_or_obj"] = "image.nii"
79+
data = {"image": mt}
80+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image")
81+
with self.assertRaises(KeyError):
82+
transform(data)
83+
84+
def test_missing_key_allowed_metatensor(self):
85+
"""Test that a missing key is silently skipped when allow_missing_keys=True with MetaTensor."""
86+
mt = MetaTensor(torch.zeros(1, 2, 2))
87+
mt.meta["filename_or_obj"] = "image.nii"
88+
data = {"image": mt}
89+
transform = ExtractDataKeyFromMetaKeyd(keys="nonexistent_key", meta_key="image", allow_missing_keys=True)
90+
result = transform(data)
91+
self.assertNotIn("nonexistent_key", result)
92+
93+
def test_missing_key_allowed_dict(self):
94+
"""Test that a missing key is silently skipped when allow_missing_keys=True with dict."""
95+
data = {"image": torch.zeros(1, 2, 2), "image_meta_dict": {"filename_or_obj": "image.nii"}}
96+
transform = ExtractDataKeyFromMetaKeyd(
97+
keys="nonexistent_key", meta_key="image_meta_dict", allow_missing_keys=True
98+
)
99+
result = transform(data)
100+
self.assertNotIn("nonexistent_key", result)
101+
102+
def test_original_data_preserved_metatensor(self):
103+
"""Test that the original MetaTensor remains in the data dictionary."""
104+
mt = MetaTensor(torch.ones(1, 2, 2))
105+
mt.meta["filename_or_obj"] = "image.nii"
106+
data = {"image": mt}
107+
transform = ExtractDataKeyFromMetaKeyd(keys="filename_or_obj", meta_key="image")
108+
result = transform(data)
109+
self.assertIn("image", result)
110+
self.assertIsInstance(result["image"], MetaTensor)
111+
self.assertTrue(torch.equal(result["image"], mt))
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)