File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1279,6 +1279,7 @@ def wrapper(param):
12791279
12801280 @instrument_w_nvtx
12811281 def reduce_partition_and_remove_grads (* notneeded ):
1282+ # Evaluate refresh condition before reenter_backward_if_needed()
12821283 refresh_expected = self .should_refresh_expected_hook_count ()
12831284 # Re-enter backward for subsequent phases in reentrant checkpointing
12841285 self .reenter_backward_if_needed ()
@@ -1307,6 +1308,7 @@ def reduce_partition_and_remove_grads(*notneeded):
13071308 def make_hook (params ):
13081309
13091310 def reduce_leaf_module_grads (module , grad_input , grad_output ):
1311+ # Evaluate refresh condition before reenter_backward_if_needed()
13101312 refresh_expected = self .should_refresh_expected_hook_count ()
13111313 self .reenter_backward_if_needed ()
13121314
Original file line number Diff line number Diff line change @@ -1046,6 +1046,7 @@ def create_gradient_handling_hooks(self):
10461046 def wrapper (param , i ):
10471047
10481048 def grad_handling_hook (* notneeded ):
1049+ # Evaluate refresh condition before reenter_backward_if_needed()
10491050 refresh_expected = self .should_refresh_expected_hook_count ()
10501051 self .reenter_backward_if_needed ()
10511052 self .process_gradients (param , i )
You can’t perform that action at this time.
0 commit comments