diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index a731546a9e..a7781f00c3 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -344,28 +344,30 @@ def __call__(self, data: Mapping) -> dict: restore_grad_state = torch.is_grad_enabled() torch.set_grad_enabled(False) - ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] - ndas_label = d[self.label_key] # (H,W,D) + try: + ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] + ndas_label = d[self.label_key] # (H,W,D) - if ndas_label.shape != ndas[0].shape: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + if ndas_label.shape != ndas[0].shape: + raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") - nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] - nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] + nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] + nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] - # perform calculation - report = deepcopy(self.get_report_format()) + # perform calculation + report = deepcopy(self.get_report_format()) - report[ImageStatsKeys.INTENSITY] = [ - self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds - ] + report[ImageStatsKeys.INTENSITY] = [ + self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds + ] - if not verify_report_format(report, self.get_report_format()): - raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") + if not verify_report_format(report, self.get_report_format()): + raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") - d[self.stats_name] = report + d[self.stats_name] = report + finally: + torch.set_grad_enabled(restore_grad_state) - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get foreground image stats spent {time.time() - start}") return d @@ -477,78 +479,80 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe restore_grad_state = torch.is_grad_enabled() torch.set_grad_enabled(False) - if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( - label_tensor, (MetaTensor, torch.Tensor) - ): - if label_tensor.device != image_tensor.device: - if using_cuda: - # Move both tensors to CUDA when mixing devices - cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device - image_tensor = cast(MetaTensor, image_tensor.to(cuda_device)) - label_tensor = cast(MetaTensor, label_tensor.to(cuda_device)) - else: - label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device)) - - ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore - ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) - - if ndas_label.shape != ndas[0].shape: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") - - nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] - nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] - - unique_label = unique(ndas_label) - if isinstance(ndas_label, (MetaTensor, torch.Tensor)): - unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment] - - unique_label = unique_label.astype(np.int16).tolist() - - label_substats = [] # each element is one label - pixel_sum = 0 - pixel_arr = [] - for index in unique_label: - start_label = time.time() - label_dict: dict[str, Any] = {} - mask_index = ndas_label == index - - nda_masks = [nda[mask_index] for nda in ndas] - label_dict[LabelStatsKeys.IMAGE_INTST] = [ - self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks - ] - - pixel_count = sum(mask_index) - pixel_arr.append(pixel_count) - pixel_sum += pixel_count - if self.do_ccp: # apply connected component - if using_cuda: - # The back end of get_label_ccp is CuPy - # which is unable to automatically release CUDA GPU memory held by PyTorch - del nda_masks - torch.cuda.empty_cache() - shape_list, ncomponents = get_label_ccp(mask_index) - label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list - label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents - - label_substats.append(label_dict) - logger.debug(f" label {index} stats takes {time.time() - start_label}") - - for i, _ in enumerate(unique_label): - label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)}) + try: + if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( + label_tensor, (MetaTensor, torch.Tensor) + ): + if label_tensor.device != image_tensor.device: + if using_cuda: + # Move both tensors to CUDA when mixing devices + cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device + image_tensor = cast(MetaTensor, image_tensor.to(cuda_device)) + label_tensor = cast(MetaTensor, label_tensor.to(cuda_device)) + else: + label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device)) + + ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore + ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) + + if ndas_label.shape != ndas[0].shape: + raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + + nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] + nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] + + unique_label = unique(ndas_label) + if isinstance(ndas_label, (MetaTensor, torch.Tensor)): + unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment] + + unique_label = unique_label.astype(np.int16).tolist() + + label_substats = [] # each element is one label + pixel_sum = 0 + pixel_arr = [] + for index in unique_label: + start_label = time.time() + label_dict: dict[str, Any] = {} + mask_index = ndas_label == index + + nda_masks = [nda[mask_index] for nda in ndas] + label_dict[LabelStatsKeys.IMAGE_INTST] = [ + self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks + ] - report = deepcopy(self.get_report_format()) - report[LabelStatsKeys.LABEL_UID] = unique_label - report[LabelStatsKeys.IMAGE_INTST] = [ - self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds - ] - report[LabelStatsKeys.LABEL] = label_substats + pixel_count = sum(mask_index) + pixel_arr.append(pixel_count) + pixel_sum += pixel_count + if self.do_ccp: # apply connected component + if using_cuda: + # The back end of get_label_ccp is CuPy + # which is unable to automatically release CUDA GPU memory held by PyTorch + del nda_masks + torch.cuda.empty_cache() + shape_list, ncomponents = get_label_ccp(mask_index) + label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list + label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents + + label_substats.append(label_dict) + logger.debug(f" label {index} stats takes {time.time() - start_label}") + + for i, _ in enumerate(unique_label): + label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)}) + + report = deepcopy(self.get_report_format()) + report[LabelStatsKeys.LABEL_UID] = unique_label + report[LabelStatsKeys.IMAGE_INTST] = [ + self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds + ] + report[LabelStatsKeys.LABEL] = label_substats - if not verify_report_format(report, self.get_report_format()): - raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") + if not verify_report_format(report, self.get_report_format()): + raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") - d[self.stats_name] = report # type: ignore[assignment] + d[self.stats_name] = report # type: ignore[assignment] + finally: + torch.set_grad_enabled(restore_grad_state) - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get label stats spent {time.time() - start}") return d # type: ignore[return-value]