NOTE(review): line breaks were restored from the hunk header counts (9 old / 13 new lines).
NOTE(review): leading whitespace is inferred from Python nesting; confirm against the target file before applying.
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 005853ebffc0..5bc831219345 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1046,9 +1046,13 @@ def create_gradient_handling_hooks(self):
                     def wrapper(param, i):
 
                         def grad_handling_hook(*notneeded):
-                            self.reenter_backward_if_needed()
+                            if self._remaining_grad_acc_hooks == 0:
+                                self.reenter_backward_if_needed()
                             self.process_gradients(param, i)
-                            current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+                            if self._hooks_fired_this_backward == 0:
+                                current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+                            else:
+                                current_expected = self._max_expected_hooks_seen
                             self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
                         self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))