Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,36 +255,46 @@ def __call__(self, data):
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)
try:
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
else:
nda_croppeds = d["nda_croppeds"]
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
raise ValueError(
f"Pre-computed 'nda_croppeds' must be a list with one entry per image channel "
f"(expected {len(ndas)}, got "
f"{len(nda_croppeds) if isinstance(nda_croppeds, (list, tuple)) else type(nda_croppeds).__name__})."
)

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)

report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]
report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report
finally:
torch.set_grad_enabled(restore_grad_state)

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time() - start}")
return d

Expand Down
37 changes: 37 additions & 0 deletions tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,43 @@ def test_seg_summarizer(self):
assert str(DataStatsKeys.FG_IMAGE_STATS) in report
assert str(DataStatsKeys.LABEL_STATS) in report

def test_image_stats_precomputed_nda_croppeds(self):
"""Verify ImageStats handles pre-populated nda_croppeds without crashing.

Previously raised UnboundLocalError because nda_croppeds was only assigned
inside the ``if "nda_croppeds" not in d`` branch but used unconditionally.
"""
analyzer = ImageStats(image_key="image")
image = torch.rand(1, 10, 10, 10)
precomputed = [np.random.rand(8, 8, 8)] # simulated pre-cropped foreground
data = {"image": MetaTensor(image), "nda_croppeds": precomputed}
result = analyzer(data)
assert "image_stats" in result
assert verify_report_format(result["image_stats"], analyzer.get_report_format())

def test_analyzer_grad_state_restored_after_call(self):
"""Verify ImageStats restores torch grad-enabled state on both normal and disabled entry.

Checks that the try/finally guard correctly restores the state regardless of
whether grad was enabled or disabled before the call.
"""
analyzer = ImageStats(image_key="image")
image = torch.rand(1, 10, 10, 10)
data = {"image": MetaTensor(image)}

# grad enabled before call → must still be enabled after
torch.set_grad_enabled(True)
analyzer(data)
assert torch.is_grad_enabled(), "grad state was not restored after ImageStats call"

# grad disabled before call → must still be disabled after
torch.set_grad_enabled(False)
try:
analyzer(data)
assert not torch.is_grad_enabled(), "grad state was not restored after ImageStats call"
finally:
torch.set_grad_enabled(True) # always restore for subsequent tests

def tearDown(self) -> None:
self.test_dir.cleanup()

Expand Down
Loading