Skip to content
Merged
11 changes: 9 additions & 2 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections.abc import Hashable, Mapping
from copy import deepcopy
from typing import Any
from typing import Any, cast

import numpy as np
import torch
Expand Down Expand Up @@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
start = time.time()
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
Expand All @@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
if using_cuda:
# Move both tensors to CUDA when mixing devices
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
else:
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))

ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
Expand Down
1 change: 1 addition & 0 deletions tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
result = analyzer({"image": image_tensor, "label": label_tensor})
report = result["label_stats"]

# Verify report format and computation succeeded despite mixed/unified devices
assert verify_report_format(report, analyzer.get_report_format())
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]

Expand Down
Loading