8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1046,9 +1046,13 @@ def create_gradient_handling_hooks(self):
 def wrapper(param, i):
 
     def grad_handling_hook(*notneeded):
-        self.reenter_backward_if_needed()
+        if self._remaining_grad_acc_hooks == 0:
+            self.reenter_backward_if_needed()
         self.process_gradients(param, i)
-        current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+        if self._hooks_fired_this_backward == 0:
+            current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+        else:
+            current_expected = self._max_expected_hooks_seen
Comment on lines +1052 to +1055

P1: Recompute expected hooks at each reentrant backward phase

When queue_post_backward_callback() cannot queue a callback (the fallback path in BackwardHookStateManager.update_hook_state_and_maybe_run_epilogue), ZeRO relies on _max_expected_hooks_seen growing as new reentrant phases appear. This change calls count_used_parameters_in_backward() only on the very first hook (_hooks_fired_this_backward == 0) and reuses _max_expected_hooks_seen afterwards, so a hook count in phase 2 or later that exceeds the phase-1 count is never observed. In reentrant checkpointing, that can make the fallback path conclude that backward has finished too early and run the epilogue/post-hooks before the later-phase hooks fire, leading to premature grad reduction/cleanup.
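
To make the concern concrete, here is a toy model of the fallback path. It is a sketch, not DeepSpeed code: `hooks_fired_at_epilogue` and the `observed_counts` values are hypothetical, and it assumes the epilogue runs as soon as the number of fired hooks reaches the largest expected count seen so far.

```python
def hooks_fired_at_epilogue(observed_counts, recompute_every_hook):
    """Toy model (hypothetical): return how many hooks had fired when the epilogue ran.

    observed_counts[k] stands in for what a counting function such as
    count_used_parameters_in_backward() would report when hook k fires.
    """
    fired = 0
    max_expected = 0
    for observed in observed_counts:
        if recompute_every_hook or fired == 0:
            # Old behaviour: always recount. New behaviour: recount on the first hook only.
            current_expected = observed
        else:
            # New behaviour after the first hook: reuse the cached maximum.
            current_expected = max_expected
        fired += 1
        max_expected = max(max_expected, current_expected)
        if fired >= max_expected:
            # Assumed fallback rule: no hooks remaining means backward is done.
            return fired
    return None


# Made-up scenario: phase 1 reports 2 expected hooks, a later reentrant phase reports 4.
observed = [2, 4, 4, 4]
always = hooks_fired_at_epilogue(observed, recompute_every_hook=True)
first_only = hooks_fired_at_epilogue(observed, recompute_every_hook=False)
print(always, first_only)
```

Under these assumptions, recounting on every hook delays the epilogue until all four hooks have fired, while counting only on the first hook runs it after two.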

         self.update_hook_state_and_maybe_run_epilogue(current_expected)

     self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))