diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index d48d5fc878..a731546a9e 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Hashable, Mapping from copy import deepcopy -from typing import Any +from typing import Any, cast import numpy as np import torch @@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe start = time.time() image_tensor = d[self.image_key] label_tensor = d[self.label_key] + # Check if either tensor is on CUDA to determine if we should move both to CUDA for processing using_cuda = any( isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor) ) @@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe label_tensor, (MetaTensor, torch.Tensor) ): if label_tensor.device != image_tensor.device: - label_tensor = label_tensor.to(image_tensor.device) # type: ignore + if using_cuda: + # Move both tensors to CUDA when mixing devices + cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device + image_tensor = cast(MetaTensor, image_tensor.to(cuda_device)) + label_tensor = cast(MetaTensor, label_tensor.to(cuda_device)) + else: + label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device)) ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py index 2159265873..6c840e944d 100644 --- a/tests/apps/test_auto3dseg.py +++ b/tests/apps/test_auto3dseg.py @@ -393,6 +393,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params): result = analyzer({"image": image_tensor, "label": label_tensor}) report = result["label_stats"] + # Verify report format and computation succeeded despite mixed/unified devices assert verify_report_format(report, analyzer.get_report_format()) assert report[LabelStatsKeys.LABEL_UID] == [0, 1]