
Fix #5889: restore grad state on exception in FgImageStats and LabelStats#8810

Open
williams145 wants to merge 1 commit into Project-MONAI:dev from
williams145:fix/issue-5889-grad-state-leak-analyzers

Conversation

@williams145 williams145 commented Apr 8, 2026

Problem

FgImageStats.__call__ and LabelStats.__call__ disable PyTorch gradient computation on entry but only restore it via a plain assignment before return. If any exception is raised mid-execution (e.g. shape mismatch, RuntimeError from verify_report_format), the restore line is skipped and torch.is_grad_enabled() stays False for the rest of the process, silently breaking all subsequent gradient computations.
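The failure mode can be shown with a minimal, self-contained sketch; `leaky_call` and `failing_body` are illustrative stand-ins for the pre-fix analyzer code, not MONAI functions:

```python
import torch

def failing_body():
    # Stands in for any mid-computation error (shape mismatch, report
    # verification failure, etc.).
    raise RuntimeError("shape mismatch")

def leaky_call():
    # Pre-fix pattern: save grad state, disable, compute, restore via a
    # plain assignment at the end.
    restore_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    failing_body()                               # raises, so...
    torch.set_grad_enabled(restore_grad_state)   # ...this line never runs

torch.set_grad_enabled(True)
try:
    leaky_call()
except RuntimeError:
    pass
print(torch.is_grad_enabled())  # False: the disabled state leaks to the caller
```

Because the restore line is plain straight-line code, any exception above it propagates past it, and the process-wide grad mode stays off.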

Root Cause

The grad-state restore had no exception safety: torch.set_grad_enabled(restore_grad_state) was a bare statement that was skipped entirely if anything above it raised.

Fix

Wrapped the computation body in try/finally in both FgImageStats.__call__ and LabelStats.__call__ so the restore is guaranteed regardless of how the function exits.
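The applied pattern can be sketched as the standard try/finally idiom; `safe_call` and `failing_body` below are hypothetical stand-ins for the analyzer body, not the actual MONAI code:

```python
import torch

def failing_body():
    raise RuntimeError("simulated mid-computation failure")

def safe_call():
    # Post-fix pattern: the finally block guarantees the restore runs no
    # matter how the body exits.
    restore_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    try:
        failing_body()
    finally:
        torch.set_grad_enabled(restore_grad_state)  # always executes

torch.set_grad_enabled(True)
try:
    safe_call()
except RuntimeError:
    pass
assert torch.is_grad_enabled()  # state restored despite the exception
```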

Testing

import torch
from monai.auto3dseg import FgImageStats
from monai.data.meta_tensor import MetaTensor

analyzer = FgImageStats(image_key="image", label_key="label")
torch.set_grad_enabled(True)
try:
    analyzer({"image": MetaTensor(torch.rand(1,4,4,4)), "label": MetaTensor(torch.rand(5,5,5))})
except Exception:
    pass
assert torch.is_grad_enabled()  # was False before this fix
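As a side note, `torch.no_grad()` used as a context manager gives the same guarantee without an explicit try/finally, since its `__exit__` restores the previous grad mode even when the body raises. A minimal sketch (not the approach taken in this PR, which keeps the existing save/restore structure):

```python
import torch

def call_with_no_grad():
    # The context manager restores the prior grad mode on exit,
    # including exceptional exit.
    with torch.no_grad():
        raise RuntimeError("simulated failure")

torch.set_grad_enabled(True)
try:
    call_with_no_grad()
except RuntimeError:
    pass
assert torch.is_grad_enabled()  # context manager restored the state
```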

…ts and LabelStats

FgImageStats.__call__ and LabelStats.__call__ both saved and disabled
torch grad state on entry but only restored it via a plain assignment
before return. Any exception raised between the disable and the restore
(e.g. shape mismatch, RuntimeError from verify_report_format) left
torch.is_grad_enabled() permanently False for the remainder of the
process, silently breaking all subsequent gradient computations.

Wrapped the computation body in try/finally in both methods so the grad
state is guaranteed to be restored regardless of how the function exits.

Fixes Project-MONAI#5889

coderabbitai bot commented Apr 8, 2026

📝 Walkthrough

Walkthrough

The FgImageStats.__call__ and LabelStats.__call__ methods in monai/auto3dseg/analyzer.py were refactored to wrap their primary computation blocks in try/finally constructs. This ensures that the autograd state restoration call torch.set_grad_enabled(restore_grad_state) executes regardless of whether exceptions occur during shape validation, foreground computation, connected-component processing, report construction, or report format verification. The functional output and logic paths remain unchanged under normal execution.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — Passed: the title directly describes the main fix, grad state restoration on exceptions in two analyzer classes.
  • Docstring Coverage — Passed: docstring coverage is 100.00%, above the required 80.00% threshold.
  • Description check — Passed: the PR description is well-structured with a clear problem statement, root cause, fix explanation, and testing code.



@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/auto3dseg/analyzer.py (1)

255-287: ⚠️ Potential issue | 🟠 Major

ImageStats.__call__ has the same grad-state leak vulnerability.

Lines 255-287 follow the identical pattern: save grad state, disable, compute, restore. The restore at line 287 is not exception-safe; the same try/finally fix should be applied for consistency.

Proposed fix
         restore_grad_state = torch.is_grad_enabled()
         torch.set_grad_enabled(False)
 
-        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        if "nda_croppeds" not in d:
-            nda_croppeds = [get_foreground_image(nda) for nda in ndas]
-
-        # perform calculation
-        report = deepcopy(self.get_report_format())
-
-        report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
-        report[ImageStatsKeys.CHANNELS] = len(ndas)
-        report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
-        report[ImageStatsKeys.SPACING] = (
-            affine_to_spacing(data[self.image_key].affine).tolist()
-            if isinstance(data[self.image_key], MetaTensor)
-            else [1.0] * min(3, data[self.image_key].ndim)
-        )
-
-        report[ImageStatsKeys.SIZEMM] = [
-            a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
-        ]
-
-        report[ImageStatsKeys.INTENSITY] = [
-            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
-        ]
-
-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
-
-        d[self.stats_name] = report
-
-        torch.set_grad_enabled(restore_grad_state)
+        try:
+            ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+            if "nda_croppeds" not in d:
+                nda_croppeds = [get_foreground_image(nda) for nda in ndas]
+
+            # perform calculation
+            report = deepcopy(self.get_report_format())
+
+            report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
+            report[ImageStatsKeys.CHANNELS] = len(ndas)
+            report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
+            report[ImageStatsKeys.SPACING] = (
+                affine_to_spacing(data[self.image_key].affine).tolist()
+                if isinstance(data[self.image_key], MetaTensor)
+                else [1.0] * min(3, data[self.image_key].ndim)
+            )
+
+            report[ImageStatsKeys.SIZEMM] = [
+                a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
+            ]
+
+            report[ImageStatsKeys.INTENSITY] = [
+                self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
+            ]
+
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+            d[self.stats_name] = report
+        finally:
+            torch.set_grad_enabled(restore_grad_state)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/auto3dseg/analyzer.py` around lines 255 - 287, The grad-state
save/restore in ImageStats.__call__ is not exception-safe: capture
restore_grad_state = torch.is_grad_enabled(), then disable grads, but the
current torch.set_grad_enabled(restore_grad_state) at the end can be skipped if
an exception occurs; wrap the whole computation (everything from building
ndas/nda_croppeds through assigning d[self.stats_name]) in a try/finally and
move torch.set_grad_enabled(restore_grad_state) into the finally block so the
original grad state is always restored; locate this logic inside the
ImageStats.__call__ implementation in monai/auto3dseg/analyzer.py and apply the
same pattern as used elsewhere in the codebase.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6ca1fa10-1cd5-4361-9419-8769707e2adb

📥 Commits

Reviewing files that changed from the base of the PR and between 8d39519 and 4ec0d4d.

📒 Files selected for processing (1)
  • monai/auto3dseg/analyzer.py

