Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 87 additions & 83 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,28 +344,30 @@ def __call__(self, data: Mapping) -> dict:
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
ndas_label = d[self.label_key] # (H,W,D)
try:
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
ndas_label = d[self.label_key] # (H,W,D)

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

# perform calculation
report = deepcopy(self.get_report_format())
# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report
finally:
torch.set_grad_enabled(restore_grad_state)

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get foreground image stats spent {time.time() - start}")
return d

Expand Down Expand Up @@ -477,78 +479,80 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
if using_cuda:
# Move both tensors to CUDA when mixing devices
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
else:
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))

ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]

unique_label = unique_label.astype(np.int16).tolist()

label_substats = [] # each element is one label
pixel_sum = 0
pixel_arr = []
for index in unique_label:
start_label = time.time()
label_dict: dict[str, Any] = {}
mask_index = ndas_label == index

nda_masks = [nda[mask_index] for nda in ndas]
label_dict[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
]

pixel_count = sum(mask_index)
pixel_arr.append(pixel_count)
pixel_sum += pixel_count
if self.do_ccp: # apply connected component
if using_cuda:
# The back end of get_label_ccp is CuPy
# which is unable to automatically release CUDA GPU memory held by PyTorch
del nda_masks
torch.cuda.empty_cache()
shape_list, ncomponents = get_label_ccp(mask_index)
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents

label_substats.append(label_dict)
logger.debug(f" label {index} stats takes {time.time() - start_label}")

for i, _ in enumerate(unique_label):
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
try:
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
if using_cuda:
# Move both tensors to CUDA when mixing devices
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
else:
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))

ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]

unique_label = unique_label.astype(np.int16).tolist()

label_substats = [] # each element is one label
pixel_sum = 0
pixel_arr = []
for index in unique_label:
start_label = time.time()
label_dict: dict[str, Any] = {}
mask_index = ndas_label == index

nda_masks = [nda[mask_index] for nda in ndas]
label_dict[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
]

report = deepcopy(self.get_report_format())
report[LabelStatsKeys.LABEL_UID] = unique_label
report[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[LabelStatsKeys.LABEL] = label_substats
pixel_count = sum(mask_index)
pixel_arr.append(pixel_count)
pixel_sum += pixel_count
if self.do_ccp: # apply connected component
if using_cuda:
# The back end of get_label_ccp is CuPy
# which is unable to automatically release CUDA GPU memory held by PyTorch
del nda_masks
torch.cuda.empty_cache()
shape_list, ncomponents = get_label_ccp(mask_index)
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents

label_substats.append(label_dict)
logger.debug(f" label {index} stats takes {time.time() - start_label}")

for i, _ in enumerate(unique_label):
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})

report = deepcopy(self.get_report_format())
report[LabelStatsKeys.LABEL_UID] = unique_label
report[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[LabelStatsKeys.LABEL] = label_substats

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report # type: ignore[assignment]
d[self.stats_name] = report # type: ignore[assignment]
finally:
torch.set_grad_enabled(restore_grad_state)

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get label stats spent {time.time() - start}")
return d # type: ignore[return-value]

Expand Down
Loading