Skip to content

Commit c43e423

Browse files
committed
Add comment clarifying refresh-before-reenter ordering
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent e1b41ee commit c43e423

2 files changed

Lines changed: 3 additions & 0 deletions

File tree

deepspeed/runtime/zero/stage3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,7 @@ def wrapper(param):
12791279

12801280
@instrument_w_nvtx
12811281
def reduce_partition_and_remove_grads(*notneeded):
1282+
# Evaluate refresh condition before reenter_backward_if_needed()
12821283
refresh_expected = self.should_refresh_expected_hook_count()
12831284
# Re-enter backward for subsequent phases in reentrant checkpointing
12841285
self.reenter_backward_if_needed()
@@ -1307,6 +1308,7 @@ def reduce_partition_and_remove_grads(*notneeded):
13071308
def make_hook(params):
13081309

13091310
def reduce_leaf_module_grads(module, grad_input, grad_output):
1311+
# Evaluate refresh condition before reenter_backward_if_needed()
13101312
refresh_expected = self.should_refresh_expected_hook_count()
13111313
self.reenter_backward_if_needed()
13121314

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,7 @@ def create_gradient_handling_hooks(self):
10461046
def wrapper(param, i):
10471047

10481048
def grad_handling_hook(*notneeded):
1049+
# Evaluate refresh condition before reenter_backward_if_needed()
10491050
refresh_expected = self.should_refresh_expected_hook_count()
10501051
self.reenter_backward_if_needed()
10511052
self.process_gradients(param, i)

0 commit comments

Comments
 (0)