From 2560cf45a97d591fbf066648f25e34137ee87bf0 Mon Sep 17 00:00:00 2001 From: rraminen Date: Mon, 2 Mar 2026 21:32:11 +0000 Subject: [PATCH 1/3] Fix grad_handling_hook --- deepspeed/runtime/zero/stage_1_and_2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 005853ebffc0..80183160f7a6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1048,8 +1048,9 @@ def wrapper(param, i): def grad_handling_hook(*notneeded): self.reenter_backward_if_needed() self.process_gradients(param, i) - current_expected = count_used_parameters_in_backward(all_params_requiring_grad) - self.update_hook_state_and_maybe_run_epilogue(current_expected) + if self._hooks_fired_this_backward == 0: + self.current_expected_hooks = count_used_parameters_in_backward(all_params_requiring_grad) + self.update_hook_state_and_maybe_run_epilogue(self.current_expected_hooks) self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook)) From 51bec2ba769c2bc82a88e3b5f572aa6eb1d266fc Mon Sep 17 00:00:00 2001 From: rraminen Date: Tue, 3 Mar 2026 21:05:46 +0000 Subject: [PATCH 2/3] Update if condition --- deepspeed/runtime/zero/stage_1_and_2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 80183160f7a6..84241da33898 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1048,9 +1048,10 @@ def wrapper(param, i): def grad_handling_hook(*notneeded): self.reenter_backward_if_needed() self.process_gradients(param, i) - if self._hooks_fired_this_backward == 0: + if self._remaining_grad_acc_hooks == 0: self.current_expected_hooks = count_used_parameters_in_backward(all_params_requiring_grad) self.update_hook_state_and_maybe_run_epilogue(self.current_expected_hooks) + self._remaining_grad_acc_hooks -= 1 self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook)) From 8e4da49a9d1c0ceb6069a4d2724de15a34e70f38 Mon Sep 17 00:00:00 2001 From: rraminen Date: Wed, 4 Mar 2026 16:55:45 +0000 Subject: [PATCH 3/3] Use _remaining_grad_acc_hooks and _hooks_fired_this_backward --- deepspeed/runtime/zero/stage_1_and_2.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 84241da33898..5bc831219345 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1046,12 +1046,14 @@ def create_gradient_handling_hooks(self): def wrapper(param, i): def grad_handling_hook(*notneeded): - self.reenter_backward_if_needed() - self.process_gradients(param, i) if self._remaining_grad_acc_hooks == 0: - self.current_expected_hooks = count_used_parameters_in_backward(all_params_requiring_grad) - self.update_hook_state_and_maybe_run_epilogue(self.current_expected_hooks) - self._remaining_grad_acc_hooks -= 1 + self.reenter_backward_if_needed() + self.process_gradients(param, i) + if self._hooks_fired_this_backward == 0: + current_expected = count_used_parameters_in_backward(all_params_requiring_grad) + else: + current_expected = self._max_expected_hooks_seen + self.update_hook_state_and_maybe_run_epilogue(current_expected) self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))