Fix #5889: restore grad state on exception in FgImageStats and LabelStats #8810
williams145 wants to merge 1 commit into Project-MONAI:dev
Conversation
…ts and LabelStats

FgImageStats.__call__ and LabelStats.__call__ both saved and disabled the torch grad state on entry but only restored it via a plain assignment before return. Any exception raised between the disable and the restore (e.g. a shape mismatch, or a RuntimeError from verify_report_format) left torch.is_grad_enabled() permanently False for the remainder of the process, silently breaking all subsequent gradient computations. Wrapped the computation body in try/finally in both methods so the grad state is guaranteed to be restored regardless of how the function exits.

Fixes Project-MONAI#5889
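The fix described above follows a standard save/disable/restore pattern. As a minimal, hedged sketch of that pattern (using a stand-in module-level flag rather than the real torch grad state, so all names here are illustrative only):

```python
# Minimal sketch of the exception-safe save/disable/restore pattern.
# `_grad_enabled` is a stand-in for torch's global grad state; the real
# fix uses torch.is_grad_enabled() / torch.set_grad_enabled().

_grad_enabled = True  # stand-in global, analogous to torch.is_grad_enabled()

def is_grad_enabled() -> bool:
    return _grad_enabled

def set_grad_enabled(mode: bool) -> None:
    global _grad_enabled
    _grad_enabled = mode

def call_with_guard(compute):
    """Run `compute` with grads disabled, restoring the prior state even on error."""
    restore = is_grad_enabled()
    set_grad_enabled(False)
    try:
        return compute()
    finally:
        # runs on normal return AND on exception, so the state cannot leak
        set_grad_enabled(restore)

def failing_compute():
    # stands in for a shape mismatch or a failed verify_report_format check
    raise RuntimeError("bad report")

try:
    call_with_guard(failing_compute)
except RuntimeError:
    pass

assert is_grad_enabled()  # state restored despite the exception
```

With the pre-fix version (a bare `set_grad_enabled(restore)` after the computation instead of a `finally` block), the final assertion would fail after the exception, which is exactly the leak this PR closes.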
📝 Walkthrough — Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes. 🚥 Pre-merge checks: ✅ 3 passed.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/auto3dseg/analyzer.py (1)
255-287: ⚠️ Potential issue | 🟠 Major
ImageStats.__call__ has the same grad-state leak vulnerability. Lines 255-287 follow the identical pattern: save grad state, disable, compute, restore. The restore at line 287 is not exception-safe. The same try/finally fix should be applied for consistency.
Proposed fix
```diff
     restore_grad_state = torch.is_grad_enabled()
     torch.set_grad_enabled(False)
-    ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-    if "nda_croppeds" not in d:
-        nda_croppeds = [get_foreground_image(nda) for nda in ndas]
-
-    # perform calculation
-    report = deepcopy(self.get_report_format())
-
-    report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
-    report[ImageStatsKeys.CHANNELS] = len(ndas)
-    report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
-    report[ImageStatsKeys.SPACING] = (
-        affine_to_spacing(data[self.image_key].affine).tolist()
-        if isinstance(data[self.image_key], MetaTensor)
-        else [1.0] * min(3, data[self.image_key].ndim)
-    )
-
-    report[ImageStatsKeys.SIZEMM] = [
-        a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
-    ]
-
-    report[ImageStatsKeys.INTENSITY] = [
-        self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
-    ]
-
-    if not verify_report_format(report, self.get_report_format()):
-        raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
-
-    d[self.stats_name] = report
-
-    torch.set_grad_enabled(restore_grad_state)
+    try:
+        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+        if "nda_croppeds" not in d:
+            nda_croppeds = [get_foreground_image(nda) for nda in ndas]
+
+        # perform calculation
+        report = deepcopy(self.get_report_format())
+
+        report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
+        report[ImageStatsKeys.CHANNELS] = len(ndas)
+        report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
+        report[ImageStatsKeys.SPACING] = (
+            affine_to_spacing(data[self.image_key].affine).tolist()
+            if isinstance(data[self.image_key], MetaTensor)
+            else [1.0] * min(3, data[self.image_key].ndim)
+        )
+
+        report[ImageStatsKeys.SIZEMM] = [
+            a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
+        ]
+
+        report[ImageStatsKeys.INTENSITY] = [
+            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
+        ]
+
+        if not verify_report_format(report, self.get_report_format()):
+            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+        d[self.stats_name] = report
+    finally:
+        torch.set_grad_enabled(restore_grad_state)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `monai/auto3dseg/analyzer.py` around lines 255-287, the grad-state save/restore in ImageStats.__call__ is not exception-safe: it captures restore_grad_state = torch.is_grad_enabled() and then disables grads, but the closing torch.set_grad_enabled(restore_grad_state) can be skipped if an exception occurs. Wrap the whole computation (everything from building ndas/nda_croppeds through assigning d[self.stats_name]) in try/finally and move torch.set_grad_enabled(restore_grad_state) into the finally block so the original grad state is always restored, matching the pattern used elsewhere in the codebase.
📒 Files selected for processing (1)
monai/auto3dseg/analyzer.py
Problem

FgImageStats.__call__ and LabelStats.__call__ disable PyTorch gradient computation on entry but only restore it via a plain assignment before return. If any exception is raised mid-execution (e.g. a shape mismatch, or a RuntimeError from verify_report_format), the restore line is skipped and torch.is_grad_enabled() stays False for the rest of the process, silently breaking all subsequent gradient computations.

Root Cause

The grad-state restore had no exception safety: torch.set_grad_enabled(restore_grad_state) was a bare statement that was skipped entirely if anything above it raised.

Fix

Wrapped the computation body in try/finally in both FgImageStats.__call__ and LabelStats.__call__ so the restore is guaranteed regardless of how the function exits.

Testing