fix: restore cached AMP step context after no_grad workaround#21616

Merged
deependujha merged 5 commits into Lightning-AI:master from littlebullGit:fix/21611-autocast-cache-enabled on Apr 1, 2026
Conversation

@littlebullGit
Contributor

@littlebullGit littlebullGit commented Mar 26, 2026

What does this PR do?

Fixes #21611

This PR addresses the AMP memory regression introduced when #20921 changed MixedPrecision.autocast_context_manager() to use
cache_enabled=False globally as a workaround for the nested no_grad() autocast cache-poisoning bug reported in #20644.

That workaround fixed the original correctness issue, but it also forced Lightning's step execution onto the uncached autocast
path. In workloads that call the same module repeatedly within one training_step (for example iterative decoding / RL-style
loops), this can cause repeated recasting of the same parameters and significant memory growth, which is the regression
reported in #21611.
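To illustrate the regression (a minimal sketch, not Lightning code): with cache_enabled=False, every call inside the loop re-casts the same fp32 parameters to the autocast dtype, whereas the default cached path computes the cast once per autocast region and reuses it:

```python
import torch

lin = torch.nn.Linear(64, 64)  # fp32 parameters
x = torch.randn(8, 64)

# cache_enabled=False: each call re-casts lin.weight to the autocast dtype,
# so an inner loop (iterative decoding, RL-style rollouts) pays the cast
# (and its allocation) on every iteration.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    for _ in range(4):
        y = lin(x)

# cache_enabled=True (the default): the cast is computed once and reused
# for the lifetime of the autocast region.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=True):
    for _ in range(4):
        y = lin(x)
```

Either way the op runs in the lower precision; the difference is only whether the cast weight is cached across calls within the region.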

This PR keeps the existing public behavior of MixedPrecision.autocast_context_manager() unchanged for compatibility, but
narrows the workaround in Lightning's internal runtime path:

  • autocast_context_manager() still uses cache_enabled=False
  • forward_context() now uses the cached autocast path again
  • when nested torch.no_grad() or torch.inference_mode() exits inside forward_context(), the autocast cache is cleared,
    so the original #20644 bug ("Computation graph not being built") remains fixed

In short, this restores the cached AMP behavior for normal Lightning step execution while preserving the original nested
no_grad() safeguard.

Changes

  • keep MixedPrecision.autocast_context_manager() behavior unchanged
  • move the narrower workaround into MixedPrecision.forward_context()
  • clear the autocast cache after nested no_grad() / inference_mode() exits within forward_context()
  • add regression tests covering:
    • raw PyTorch autocast behavior with cache_enabled=True vs False
    • nested no_grad() inside Lightning AMP forward_context()
    • nested inference_mode() inside Lightning AMP forward_context()
    • restoration of patched grad-mode context managers after exiting forward_context()

📚 Documentation preview 📚: https://pytorch-lightning--21616.org.readthedocs.build/en/21616/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Mar 26, 2026
@littlebullGit littlebullGit force-pushed the fix/21611-autocast-cache-enabled branch from 1511c8a to af1af99 on March 26, 2026 at 21:32
@codecov

codecov bot commented Mar 26, 2026

Codecov Report

❌ Patch coverage is 96.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 79%. Comparing base (612ab08) to head (167a244).
⚠️ Report is 2 commits behind head on master.
✅ All tests successful. No failed tests found.

❗ There is a different number of reports uploaded between BASE (612ab08) and HEAD (167a244).

HEAD has 920 fewer uploads than BASE
Flag BASE (612ab08) HEAD (167a244)
cpu 251 42
lightning_fabric 80 0
pytest 125 0
python3.12 72 12
python 18 3
lightning 90 15
python3.11 36 6
python3.13 53 9
python3.12.7 54 9
python3.10 18 3
pytorch_lightning 81 27
pytorch2.7 9 3
pytest-full 126 42
pytorch2.1 18 6
pytorch2.4.1 9 3
pytorch2.5.1 9 3
pytorch2.2.2 9 3
pytorch2.9 18 6
pytorch2.10 18 6
pytorch2.8 18 6
pytorch2.3 9 3
pytorch2.6 9 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21616     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       23898    23866     -32     
=========================================
- Hits        20678    18791   -1887     
- Misses       3220     5075   +1855     

Collaborator

@deependujha deependujha left a comment


Is there any specific reason for not having CUDA tests? Adding them would be valuable.

@littlebullGit
Contributor Author

> Is there any specific reason for not having CUDA tests? Adding them would be valuable.

Added CUDA coverage in the latest push.

There are now two CUDA-specific tests in tests/tests_pytorch/plugins/precision/test_amp.py:

  • a raw PyTorch regression test documenting the cache_enabled=True/False behavior on CUDA
  • a Lightning MixedPrecision.forward_context() test covering the nested no_grad() workaround on CUDA
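A sketch of what such a CUDA-gated regression test might look like (the test name and body here are hypothetical; the actual tests live in tests/tests_pytorch/plugins/precision/test_amp.py):

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_autocast_nested_no_grad_cache_cuda():
    """Hypothetical sketch: after a forward inside nested no_grad() under
    autocast, clearing the autocast cache restores grad history for
    subsequent forwards in the same region."""
    lin = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        with torch.no_grad():
            lin(x)  # may cache a cast weight with no grad_fn
        torch.clear_autocast_cache()  # the narrowed workaround
        out = lin(x)
    assert out.grad_fn is not None
```

The skipif marker keeps the test green on CPU-only CI while still exercising the cache-clearing path on CUDA runners.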

@deependujha deependujha merged commit 4a548c9 into Lightning-AI:master Apr 1, 2026
99 checks passed
bhimrazy pushed a commit to bhimrazy/pytorch-lightning that referenced this pull request Apr 13, 2026
…ing-AI#21616)

* fix: restore cached AMP step context after no_grad workaround

* chore: trigger ci

* chore: trigger ci

* test: add CUDA coverage for AMP no_grad cache handling

Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

cache_enabled=False in autocast causes OOM regression for iterative decoding workloads

3 participants