diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index a731546a9e..2f6e3857ef 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -255,36 +255,46 @@ def __call__(self, data):
         restore_grad_state = torch.is_grad_enabled()
         torch.set_grad_enabled(False)
 
-        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        if "nda_croppeds" not in d:
-            nda_croppeds = [get_foreground_image(nda) for nda in ndas]
-
-        # perform calculation
-        report = deepcopy(self.get_report_format())
-
-        report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
-        report[ImageStatsKeys.CHANNELS] = len(ndas)
-        report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
-        report[ImageStatsKeys.SPACING] = (
-            affine_to_spacing(data[self.image_key].affine).tolist()
-            if isinstance(data[self.image_key], MetaTensor)
-            else [1.0] * min(3, data[self.image_key].ndim)
-        )
+        try:
+            ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+            if "nda_croppeds" not in d:
+                nda_croppeds = [get_foreground_image(nda) for nda in ndas]
+            else:
+                nda_croppeds = d["nda_croppeds"]
+                if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
+                    raise ValueError(
+                        f"Pre-computed 'nda_croppeds' must be a list with one entry per image channel "
+                        f"(expected {len(ndas)}, got "
+                        f"{len(nda_croppeds) if isinstance(nda_croppeds, (list, tuple)) else type(nda_croppeds).__name__})."
+                    )
+
+            # perform calculation
+            report = deepcopy(self.get_report_format())
+
+            report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
+            report[ImageStatsKeys.CHANNELS] = len(ndas)
+            report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
+            report[ImageStatsKeys.SPACING] = (
+                affine_to_spacing(data[self.image_key].affine).tolist()
+                if isinstance(data[self.image_key], MetaTensor)
+                else [1.0] * min(3, data[self.image_key].ndim)
+            )
 
-        report[ImageStatsKeys.SIZEMM] = [
-            a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
-        ]
+            report[ImageStatsKeys.SIZEMM] = [
+                a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
+            ]
 
-        report[ImageStatsKeys.INTENSITY] = [
-            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
-        ]
+            report[ImageStatsKeys.INTENSITY] = [
+                self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
+            ]
 
-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
 
-        d[self.stats_name] = report
+            d[self.stats_name] = report
+        finally:
+            torch.set_grad_enabled(restore_grad_state)
 
-        torch.set_grad_enabled(restore_grad_state)
 
         logger.debug(f"Get image stats spent {time.time() - start}")
         return d
diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py
index 6c840e944d..1b41e84ef4 100644
--- a/tests/apps/test_auto3dseg.py
+++ b/tests/apps/test_auto3dseg.py
@@ -539,6 +539,43 @@ def test_seg_summarizer(self):
         assert str(DataStatsKeys.FG_IMAGE_STATS) in report
         assert str(DataStatsKeys.LABEL_STATS) in report
 
+    def test_image_stats_precomputed_nda_croppeds(self):
+        """Verify ImageStats handles pre-populated nda_croppeds without crashing.
+
+        Previously raised UnboundLocalError because nda_croppeds was only assigned
+        inside the ``if "nda_croppeds" not in d`` branch but used unconditionally.
+        """
+        analyzer = ImageStats(image_key="image")
+        image = torch.rand(1, 10, 10, 10)
+        precomputed = [np.random.rand(8, 8, 8)]  # simulated pre-cropped foreground
+        data = {"image": MetaTensor(image), "nda_croppeds": precomputed}
+        result = analyzer(data)
+        assert "image_stats" in result
+        assert verify_report_format(result["image_stats"], analyzer.get_report_format())
+
+    def test_analyzer_grad_state_restored_after_call(self):
+        """Verify ImageStats restores torch grad-enabled state on both normal and disabled entry.
+
+        Checks that the try/finally guard correctly restores the state regardless of
+        whether grad was enabled or disabled before the call.
+        """
+        analyzer = ImageStats(image_key="image")
+        image = torch.rand(1, 10, 10, 10)
+        data = {"image": MetaTensor(image)}
+
+        # grad enabled before call → must still be enabled after
+        torch.set_grad_enabled(True)
+        analyzer(data)
+        assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
+
+        # grad disabled before call → must still be disabled after
+        torch.set_grad_enabled(False)
+        try:
+            analyzer(data)
+            assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
+        finally:
+            torch.set_grad_enabled(True)  # always restore for subsequent tests
+
     def tearDown(self) -> None:
         self.test_dir.cleanup()