Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions deepspeed/runtime/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def run_grad_acc_post_hooks(self):

def enter_backward(self):
"""Enter backward context. Call at the start of backward pass."""
# On first real backward entry of a step, reset counters that may have been
# polluted by pre-user-backward hooks (e.g. TiledFusedLogitsLoss calling
# torch.autograd.backward() from forward). Do NOT reset on reentrant
# phase re-entry (backward_seen_this_step == True) so phase-to-phase
# state remains intact.
if self.backward_active_depth == 0 and not self.backward_seen_this_step:
self.hooks_fired_this_backward = 0
self.max_expected_hooks_seen = 0
self.remaining_grad_acc_hooks = 0
self.post_backward_callback_queued = False
self.post_backward_callback_graph_task_id = None
self.backward_active_depth += 1
# Track that backward has been active at some point in this step.
# This is used to detect subsequent gradient hook phases with reentrant checkpointing.
Expand All @@ -128,6 +139,22 @@ def reset_for_new_step(self):
self.post_backward_callback_queued = False
self.post_backward_callback_graph_task_id = None

def should_refresh_expected_hook_count(self):
    """Decide whether count_used_parameters_in_backward() must be re-evaluated.

    A refresh is required in exactly two situations:
    1. The very first hook of a backward (or of a backward phase):
       no hooks have fired yet this backward.
    2. A new reentrant-checkpointing phase just began: every previously
       expected hook has fired, we have left the backward context, yet
       backward was already active earlier in this step.

    NOTE: evaluate this BEFORE reenter_backward_if_needed() — re-entering
    moves backward_active_depth away from 0 and would hide the
    phase-boundary signal this predicate relies on.
    """
    first_hook_of_phase = self.hooks_fired_this_backward == 0
    new_reentrant_phase = (self.remaining_grad_acc_hooks == 0
                           and self.backward_active_depth == 0
                           and self.backward_seen_this_step)
    return first_hook_of_phase or new_reentrant_phase

def reenter_backward_if_needed(self):
"""Re-enter backward context for subsequent phases in reentrant checkpointing.

Expand Down Expand Up @@ -401,6 +428,10 @@ def clear_backward_seen_flag(self):
"""Clear the backward seen flag and reset hook counters at the start of each step."""
self._backward_hook_state.reset_for_new_step()

def should_refresh_expected_hook_count(self):
    """Return True when count_used_parameters_in_backward() should be re-evaluated."""
    # Delegate to the backward-hook state tracker, which owns the counters.
    state = self._backward_hook_state
    return state.should_refresh_expected_hook_count()

def reenter_backward_if_needed(self):
    """Re-enter backward context for subsequent phases in reentrant checkpointing."""
    # Pure delegation; the state object decides whether re-entry is needed.
    state = self._backward_hook_state
    state.reenter_backward_if_needed()
Expand Down
18 changes: 14 additions & 4 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,14 +1279,19 @@ def wrapper(param):

@instrument_w_nvtx
def reduce_partition_and_remove_grads(*notneeded):
# Evaluate refresh condition before reenter_backward_if_needed()
refresh_expected = self.should_refresh_expected_hook_count()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tohtana, the `refresh_expected = self.should_refresh_expected_hook_count()` line is used on L1289, so why is it placed before L1284 (`reenter_backward_if_needed()`) and L1286 (`reduce_ready_partitions_and_remove_grads()`)? Is there an implicit dependency here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for checking, @delock!

Yes, there's actually a dependency. should_refresh_expected_hook_count() detects reentrant phase boundaries by checking backward_active_depth == 0. reenter_backward_if_needed() increments backward_active_depth from 0 → 1 when it detects a new phase. If we called reenter_backward first, it would set backward_active_depth = 1 before the predicate runs, making the condition always false.

I added a comment to clarify it.

# Re-enter backward for subsequent phases in reentrant checkpointing
self.reenter_backward_if_needed()

self.reduce_ready_partitions_and_remove_grads(param)

# Update hook state and run epilogue if all expected hooks have fired
current_expected = count_used_parameters_in_backward(
non_leaf_params_requiring_grad) + leaf_module_count
if refresh_expected:
current_expected = count_used_parameters_in_backward(
non_leaf_params_requiring_grad) + leaf_module_count
else:
current_expected = self._max_expected_hooks_seen
self.update_hook_state_and_maybe_run_epilogue(current_expected)

self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
Expand All @@ -1303,6 +1308,8 @@ def reduce_partition_and_remove_grads(*notneeded):
def make_hook(params):

def reduce_leaf_module_grads(module, grad_input, grad_output):
# Evaluate refresh condition before reenter_backward_if_needed()
refresh_expected = self.should_refresh_expected_hook_count()
self.reenter_backward_if_needed()

for param in params:
Expand All @@ -1311,8 +1318,11 @@ def reduce_leaf_module_grads(module, grad_input, grad_output):
param.grad = torch.zeros_like(param)
self.reduce_ready_partitions_and_remove_grads(param)

current_expected = count_used_parameters_in_backward(
non_leaf_params_requiring_grad) + leaf_module_count
if refresh_expected:
current_expected = count_used_parameters_in_backward(
non_leaf_params_requiring_grad) + leaf_module_count
else:
current_expected = self._max_expected_hooks_seen
self.update_hook_state_and_maybe_run_epilogue(current_expected)

return reduce_leaf_module_grads
Expand Down
7 changes: 6 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,9 +1046,14 @@ def create_gradient_handling_hooks(self):
def wrapper(param, i):

def grad_handling_hook(*notneeded):
# Evaluate refresh condition before reenter_backward_if_needed()
refresh_expected = self.should_refresh_expected_hook_count()
self.reenter_backward_if_needed()
self.process_gradients(param, i)
current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
if refresh_expected:
current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
else:
current_expected = self._max_expected_hooks_seen
self.update_hook_state_and_maybe_run_epilogue(current_expected)

self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))
Expand Down
104 changes: 104 additions & 0 deletions tests/unit/v1/zero/test_zero_hook_count_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Regression tests for count_used_parameters_in_backward() call count.

Verifies fix for https://github.com/deepspeedai/DeepSpeed/issues/7885:
count_used_parameters_in_backward() was called once per gradient hook
(O(n) calls per backward) instead of once per backward phase (O(1)
for non-reentrant, O(p) for reentrant with p phases).
"""

import pytest
import torch
from unittest.mock import patch

import deepspeed
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader


def get_config_dict(zero_stage):
    """Build a minimal DeepSpeed config for the requested ZeRO stage.

    Enables bf16 when the accelerator supports it, otherwise fp16 when
    supported; stage 3 additionally disables parameter persistence so
    every parameter goes through the gather/partition path.
    """
    config_dict = {
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": zero_stage,
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
    }

    if zero_stage == 3:
        # Force all params through partitioning, even tiny ones.
        config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0

    accelerator = get_accelerator()
    if accelerator.is_bf16_supported():
        config_dict["bf16"] = {"enabled": True}
    elif accelerator.is_fp16_supported():
        config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

    return config_dict


class TestHookCountRegression(DistributedTest):
    """Test that count_used_parameters_in_backward is not called per-hook."""
    # Two ranks so partitioning/reduction paths are actually exercised.
    world_size = 2

    @pytest.mark.parametrize("zero_stage", [2, 3])
    def test_non_reentrant_single_count_call(self, zero_stage):
        """Non-reentrant backward should call count_used_parameters_in_backward exactly once."""
        hidden_dim = 16
        model = SimpleModel(hidden_dim)
        config = get_config_dict(zero_stage)
        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)

        data_loader = random_dataloader(model=engine, total_samples=4, hidden_dim=hidden_dim, device=engine.device)

        # Determine the correct module path to patch based on stage:
        # stage 1/2 and stage 3 each import the helper into their own module.
        if zero_stage == 2:
            patch_target = "deepspeed.runtime.zero.stage_1_and_2.count_used_parameters_in_backward"
        else:
            patch_target = "deepspeed.runtime.zero.stage3.count_used_parameters_in_backward"

        call_counts = []

        for batch in data_loader:
            # Wrap (not replace) the real helper so training behaves normally
            # while the mock records how many times it was invoked.
            with patch(patch_target, wraps=deepspeed.runtime.utils.count_used_parameters_in_backward) as mock_count:
                loss = engine(batch[0], batch[1])
                engine.backward(loss)
                call_counts.append(mock_count.call_count)
            engine.step()
            break  # one backward pass is enough for the regression check

        # Non-reentrant: exactly 1 call per backward
        assert call_counts[0] == 1, (f"Expected exactly 1 call to count_used_parameters_in_backward "
                                     f"per backward, got {call_counts[0]}")

    @pytest.mark.parametrize("zero_stage", [2, 3])
    def test_training_step_succeeds_after_fix(self, zero_stage):
        """Verify a full training step produces a finite loss after the caching fix."""
        hidden_dim = 16
        model = SimpleModel(hidden_dim)
        config = get_config_dict(zero_stage)
        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)

        data_loader = random_dataloader(model=engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device)

        losses = []
        for i, batch in enumerate(data_loader):
            loss = engine(batch[0], batch[1])
            # Catch NaN/Inf early with a per-step message rather than at the end.
            assert torch.isfinite(loss), f"Loss is not finite at step {i}: {loss.item()}"
            losses.append(loss.item())
            engine.backward(loss)
            engine.step()
            if i >= 1:
                break  # two full optimizer steps give sufficient coverage

        assert len(losses) >= 2, "Expected at least 2 training steps"
Loading