fix: restore cached AMP step context after no_grad workaround #21616
Conversation
1511c8a to af1af99
Codecov Report ❌ Patch coverage is

Additional details and impacted files

@@           Coverage Diff            @@
##           master   #21616    +/-   ##
=========================================
- Coverage      87%      79%      -8%
=========================================
  Files         270      267       -3
  Lines       23898    23866      -32
=========================================
- Hits        20678    18791    -1887
- Misses       3220     5075    +1855
deependujha left a comment
Is there any specific reason for not having CUDA tests? Adding them would be valuable.
Added CUDA coverage in the latest push. There are now two CUDA-specific tests in
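For context, a hypothetical sketch of what such a CUDA test can look like (this is not the PR's actual test code; the test name and the explicit cache-clearing call are illustrative assumptions based on the description below):

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_no_grad_inside_autocast_still_builds_graph():
    # Pattern from #20644: exiting a no_grad() block inside an autocast
    # region must not leave stale cached casts that break graph building.
    model = torch.nn.Linear(4, 4).cuda()
    x = torch.randn(2, 4, device="cuda")

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        with torch.no_grad():
            model(x)  # may populate the autocast cast cache without grad
        torch.clear_autocast_cache()  # what the fix does on no_grad() exit
        out = model(x)  # fresh casts participate in autograd again

    out.float().sum().backward()
    assert model.weight.grad is not None
```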
…ing-AI#21616)

* fix: restore cached AMP step context after no_grad workaround
* chore: trigger ci
* chore: trigger ci
* test: add CUDA coverage for AMP no_grad cache handling
What does this PR do?
Fixes #21611
This PR addresses the AMP memory regression introduced when #20921 changed
`MixedPrecision.autocast_context_manager()` to use `cache_enabled=False` globally as a workaround for the nested `no_grad()` autocast cache-poisoning bug reported in #20644. That workaround fixed the original correctness issue, but it also forced Lightning's step execution onto the uncached autocast path. In workloads that call the same module repeatedly within one `training_step` (for example, iterative decoding / RL-style loops), this can cause repeated recasting of the same parameters and significant memory growth, which is the regression reported in #21611.
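To illustrate the uncached vs. cached behavior, here is a minimal sketch (not code from this PR; it assumes a CUDA device and fp16 autocast):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

# Uncached path (the post-#20921 behavior): every forward under autocast
# re-casts the fp32 weights to fp16, so a loop that reuses the same module
# keeps producing fresh half-precision copies.
with torch.autocast(device_type="cuda", dtype=torch.float16, cache_enabled=False):
    for _ in range(100):  # e.g. iterative decoding / RL-style inner loop
        out = model(x)

# Cached path (what this PR restores for step execution): the fp16 cast of
# each parameter is computed once and reused within the autocast region.
with torch.autocast(device_type="cuda", dtype=torch.float16, cache_enabled=True):
    for _ in range(100):
        out = model(x)
```

The growth becomes pronounced once the per-iteration casts are kept alive, for example when each iteration's autograd graph retains them for backward.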
This PR keeps the existing public behavior of
`MixedPrecision.autocast_context_manager()` unchanged for compatibility, but narrows the workaround in Lightning's internal runtime path:

* `autocast_context_manager()` still uses `cache_enabled=False`
* `forward_context()` now uses the cached autocast path again
* When `torch.no_grad()` or `torch.inference_mode()` exits inside `forward_context()`, the autocast cache gets cleared (see the sketch after this list), so the original "Computation graph not being built" bug (#20644) remains fixed
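A rough sketch of that narrowed workaround (an approximation of the mechanism described above, not the PR's actual implementation; the name is illustrative, and for brevity the cache is cleared on context exit rather than exactly at the `no_grad()` exit):

```python
import contextlib

import torch


@contextlib.contextmanager
def forward_context_sketch(device_type: str, dtype: torch.dtype):
    # Step execution runs on the cached autocast path again
    # (torch.autocast defaults to cache_enabled=True).
    with torch.autocast(device_type=device_type, dtype=dtype):
        try:
            yield
        finally:
            # If a no_grad()/inference_mode() block exited inside this
            # context, cached casts may have been created without grad
            # tracking; dropping the cache keeps the #20644 fix intact.
            torch.clear_autocast_cache()
```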
In short, this restores the cached AMP behavior for normal Lightning step execution while preserving the original nested `no_grad()` safeguard.

Changes
* `MixedPrecision.autocast_context_manager()` behavior unchanged
* `MixedPrecision.forward_context()` now clears the autocast cache when `no_grad()` / `inference_mode()` exits within `forward_context()`
* Tests covering `cache_enabled=True` vs `False`, `no_grad()` inside Lightning AMP `forward_context()`, and `inference_mode()` inside Lightning AMP `forward_context()` (an illustrative memory probe for the cache comparison follows below)

📚 Documentation preview 📚: https://pytorch-lightning--21616.org.readthedocs.build/en/21616/
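As an illustrative probe of the `cache_enabled=True` vs `False` memory behavior (a sketch under assumed conditions, not a test from this PR; it requires a CUDA device, and the numbers will vary by hardware):

```python
import torch


def peak_mem_mib(cache_enabled: bool) -> float:
    model = torch.nn.Linear(2048, 2048).cuda()
    x = torch.randn(16, 2048, device="cuda")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    outs = []  # keep graphs alive so per-iteration casts cannot be freed
    with torch.autocast(device_type="cuda", dtype=torch.float16,
                        cache_enabled=cache_enabled):
        for _ in range(50):
            outs.append(model(x))
    return torch.cuda.max_memory_allocated() / 2**20


print(f"cache_enabled=True : {peak_mem_mib(True):8.1f} MiB peak")
print(f"cache_enabled=False: {peak_mem_mib(False):8.1f} MiB peak")
```

With the cache disabled, each iteration's saved-for-backward weight cast is a distinct fp16 copy, so peak memory scales with the loop length; with the cache enabled, the iterations reuse one cached cast.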