Skip to content

Commit 853f702

Browse files
benediktjohannes, ericspod, garciadias
authored
Fixes #8697 GPU memory leak by checking both image and label tensors for CUDA device (#8708)
- Modified device detection to check BOTH the image and label tensors.
- torch.cuda.empty_cache() is now called if EITHER tensor is on the GPU.
- Prevents GPU memory leaks in mixed-device scenarios.

---------

Signed-off-by: benediktjohannes <benedikt.johannes.hofer@gmail.com>
Signed-off-by: Benedikt Johannes <benedikt.johannes.hofer@gmail.com>
Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent d30de96 commit 853f702

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

monai/auto3dseg/analyzer.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
468468
"""
469469
d: dict[Hashable, MetaTensor] = dict(data)
470470
start = time.time()
471-
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
472-
using_cuda = True
473-
else:
474-
using_cuda = False
471+
image_tensor = d[self.image_key]
472+
label_tensor = d[self.label_key]
473+
using_cuda = any(
474+
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475+
)
475476
restore_grad_state = torch.is_grad_enabled()
476477
torch.set_grad_enabled(False)
477478

478-
ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
479-
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
479+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
480+
label_tensor, (MetaTensor, torch.Tensor)
481+
):
482+
if label_tensor.device != image_tensor.device:
483+
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
484+
485+
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486+
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
480487

481488
if ndas_label.shape != ndas[0].shape:
482489
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
483490

484491
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
485-
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
492+
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
486493

487494
unique_label = unique(ndas_label)
488495
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):

requirements-dev.txt

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,9 @@ mccabe
1414
pep8-naming
1515
pycodestyle
1616
pyflakes
17-
black>=25.1.0
18-
isort>=5.1, !=6.0.0
19-
ruff
17+
black==25.1.0
18+
isort>=5.1, <6, !=6.0.0
19+
ruff>=0.14.11,<0.15
2020
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
2121
types-setuptools
2222
mypy>=1.5.0, <1.12.0

tests/apps/test_auto3dseg.py

Lines changed: 52 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353
SqueezeDimd,
5454
ToDeviced,
5555
)
56-
from monai.utils.enums import DataStatsKeys
56+
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
5757
from tests.test_utils import skip_if_no_cuda
5858

5959
device = "cpu"
@@ -78,6 +78,13 @@
7878

7979
SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]
8080

81+
LABEL_STATS_DEVICE_TEST_CASES = [
82+
[{"image_device": "cpu", "label_device": "cpu", "image_meta": False}],
83+
[{"image_device": "cuda", "label_device": "cuda", "image_meta": True}],
84+
[{"image_device": "cpu", "label_device": "cuda", "image_meta": True}],
85+
[{"image_device": "cuda", "label_device": "cpu", "image_meta": False}],
86+
]
87+
8188

8289
def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
8390
"""
@@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self):
360367
report_format = analyzer.get_report_format()
361368
assert verify_report_format(d["label_stats"], report_format)
362369

370+
@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
371+
def test_label_stats_mixed_device_analyzer(self, input_params):
372+
image_device = torch.device(input_params["image_device"])
373+
label_device = torch.device(input_params["label_device"])
374+
375+
if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available():
376+
self.skipTest("CUDA is not available for mixed-device LabelStats tests.")
377+
378+
analyzer = LabelStats(image_key="image", label_key="label")
379+
380+
image_tensor = torch.tensor(
381+
[
382+
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
383+
[[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]],
384+
],
385+
dtype=torch.float32,
386+
).to(image_device)
387+
label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device)
388+
389+
if input_params["image_meta"]:
390+
image_tensor = MetaTensor(image_tensor)
391+
label_tensor = MetaTensor(label_tensor)
392+
393+
result = analyzer({"image": image_tensor, "label": label_tensor})
394+
report = result["label_stats"]
395+
396+
assert verify_report_format(report, analyzer.get_report_format())
397+
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]
398+
399+
label_stats = report[LabelStatsKeys.LABEL]
400+
self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5)
401+
self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5)
402+
403+
label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST]
404+
label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST]
405+
self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25)
406+
self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75)
407+
self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25)
408+
self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75)
409+
410+
foreground_stats = report[LabelStatsKeys.IMAGE_INTST]
411+
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
412+
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
413+
363414
def test_filename_case_analyzer(self):
364415
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
365416
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)

0 commit comments

Comments
 (0)