diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 4163f719c1a8..213b5c659499 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -109,6 +109,17 @@ def run_grad_acc_post_hooks(self):
 
     def enter_backward(self):
         """Enter backward context. Call at the start of backward pass."""
+        # On first real backward entry of a step, reset counters that may have been
+        # polluted by pre-user-backward hooks (e.g. TiledFusedLogitsLoss calling
+        # torch.autograd.backward() from forward). Do NOT reset on reentrant
+        # phase re-entry (backward_seen_this_step == True) so phase-to-phase
+        # state remains intact.
+        if self.backward_active_depth == 0 and not self.backward_seen_this_step:
+            self.hooks_fired_this_backward = 0
+            self.max_expected_hooks_seen = 0
+            self.remaining_grad_acc_hooks = 0
+            self.post_backward_callback_queued = False
+            self.post_backward_callback_graph_task_id = None
         self.backward_active_depth += 1
         # Track that backward has been active at some point in this step.
         # This is used to detect subsequent gradient hook phases with reentrant checkpointing.
@@ -128,6 +139,22 @@ def reset_for_new_step(self):
         self.post_backward_callback_queued = False
         self.post_backward_callback_graph_task_id = None
 
+    def should_refresh_expected_hook_count(self):
+        """Return True when count_used_parameters_in_backward() should be re-evaluated.
+
+        Refresh is needed in two cases:
+        1. First hook of a backward (or backward phase): hooks_fired == 0.
+        2. A new reentrant phase started: remaining hooks exhausted, we exited
+           backward, but backward was active earlier this step.
+
+        The predicate must be evaluated BEFORE reenter_backward_if_needed()
+        because re-entering changes backward_active_depth and hides the
+        phase-boundary signal.
+        """
+        return (self.hooks_fired_this_backward == 0
+                or (self.remaining_grad_acc_hooks == 0 and self.backward_active_depth == 0
+                    and self.backward_seen_this_step))
+
     def reenter_backward_if_needed(self):
         """Re-enter backward context for subsequent phases in reentrant checkpointing.
 
@@ -401,6 +428,10 @@ def clear_backward_seen_flag(self):
         """Clear the backward seen flag and reset hook counters at the start of each step."""
         self._backward_hook_state.reset_for_new_step()
 
+    def should_refresh_expected_hook_count(self):
+        """Return True when count_used_parameters_in_backward() should be re-evaluated."""
+        return self._backward_hook_state.should_refresh_expected_hook_count()
+
     def reenter_backward_if_needed(self):
         """Re-enter backward context for subsequent phases in reentrant checkpointing."""
         self._backward_hook_state.reenter_backward_if_needed()
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 4f2a19e7431a..f7562b361aec 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1279,14 +1279,19 @@ def wrapper(param):
 
             @instrument_w_nvtx
             def reduce_partition_and_remove_grads(*notneeded):
+                # Evaluate refresh condition before reenter_backward_if_needed()
+                refresh_expected = self.should_refresh_expected_hook_count()
                 # Re-enter backward for subsequent phases in reentrant checkpointing
                 self.reenter_backward_if_needed()
                 self.reduce_ready_partitions_and_remove_grads(param)
 
                 # Update hook state and run epilogue if all expected hooks have fired
-                current_expected = count_used_parameters_in_backward(
-                    non_leaf_params_requiring_grad) + leaf_module_count
+                if refresh_expected:
+                    current_expected = count_used_parameters_in_backward(
+                        non_leaf_params_requiring_grad) + leaf_module_count
+                else:
+                    current_expected = self._max_expected_hooks_seen
                 self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
             self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
@@ -1303,6 +1308,8 @@ def reduce_partition_and_remove_grads(*notneeded):
         def make_hook(params):
 
             def reduce_leaf_module_grads(module, grad_input, grad_output):
+                # Evaluate refresh condition before reenter_backward_if_needed()
+                refresh_expected = self.should_refresh_expected_hook_count()
                 self.reenter_backward_if_needed()
 
                 for param in params:
@@ -1311,8 +1318,11 @@ def reduce_leaf_module_grads(module, grad_input, grad_output):
                         param.grad = torch.zeros_like(param)
                     self.reduce_ready_partitions_and_remove_grads(param)
 
-                current_expected = count_used_parameters_in_backward(
-                    non_leaf_params_requiring_grad) + leaf_module_count
+                if refresh_expected:
+                    current_expected = count_used_parameters_in_backward(
+                        non_leaf_params_requiring_grad) + leaf_module_count
+                else:
+                    current_expected = self._max_expected_hooks_seen
                 self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
             return reduce_leaf_module_grads
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 005853ebffc0..12f97348a21f 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1046,9 +1046,14 @@ def create_gradient_handling_hooks(self):
         def wrapper(param, i):
 
             def grad_handling_hook(*notneeded):
+                # Evaluate refresh condition before reenter_backward_if_needed()
+                refresh_expected = self.should_refresh_expected_hook_count()
                 self.reenter_backward_if_needed()
                 self.process_gradients(param, i)
-                current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+                if refresh_expected:
+                    current_expected = count_used_parameters_in_backward(all_params_requiring_grad)
+                else:
+                    current_expected = self._max_expected_hooks_seen
                 self.update_hook_state_and_maybe_run_epilogue(current_expected)
 
             self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook))
diff --git a/tests/unit/v1/zero/test_zero_hook_count_regression.py b/tests/unit/v1/zero/test_zero_hook_count_regression.py
new file mode 100644
index 000000000000..057afc458d2c
--- /dev/null
+++ b/tests/unit/v1/zero/test_zero_hook_count_regression.py
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Regression tests for count_used_parameters_in_backward() call count.
+
+Verifies fix for https://github.com/deepspeedai/DeepSpeed/issues/7885:
+count_used_parameters_in_backward() was called once per gradient hook
+(O(n) calls per backward) instead of once per backward phase (O(1)
+for non-reentrant, O(p) for reentrant with p phases).
+"""
+
+import pytest
+import torch
+from unittest.mock import patch
+
+import deepspeed
+from deepspeed.accelerator import get_accelerator
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel, random_dataloader
+
+
+def get_config_dict(zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 2,
+        "gradient_accumulation_steps": 1,
+        "steps_per_print": 1,
+        "zero_optimization": {
+            "stage": zero_stage,
+        },
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 1e-3
+            }
+        },
+    }
+
+    if zero_stage == 3:
+        config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0
+
+    if get_accelerator().is_bf16_supported():
+        config_dict["bf16"] = {"enabled": True}
+    elif get_accelerator().is_fp16_supported():
+        config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
+
+    return config_dict
+
+
+class TestHookCountRegression(DistributedTest):
+    """Test that count_used_parameters_in_backward is not called per-hook."""
+    world_size = 2
+
+    @pytest.mark.parametrize("zero_stage", [2, 3])
+    def test_non_reentrant_single_count_call(self, zero_stage):
+        """Non-reentrant backward should call count_used_parameters_in_backward exactly once."""
+        hidden_dim = 16
+        model = SimpleModel(hidden_dim)
+        config = get_config_dict(zero_stage)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
+
+        data_loader = random_dataloader(model=engine, total_samples=4, hidden_dim=hidden_dim, device=engine.device)
+
+        # Determine the correct module path to patch based on stage
+        if zero_stage == 2:
+            patch_target = "deepspeed.runtime.zero.stage_1_and_2.count_used_parameters_in_backward"
+        else:
+            patch_target = "deepspeed.runtime.zero.stage3.count_used_parameters_in_backward"
+
+        call_counts = []
+
+        for batch in data_loader:
+            with patch(patch_target, wraps=deepspeed.runtime.utils.count_used_parameters_in_backward) as mock_count:
+                loss = engine(batch[0], batch[1])
+                engine.backward(loss)
+                call_counts.append(mock_count.call_count)
+            engine.step()
+            break
+
+        # Non-reentrant: exactly 1 call per backward
+        assert call_counts[0] == 1, (f"Expected exactly 1 call to count_used_parameters_in_backward "
+                                     f"per backward, got {call_counts[0]}")
+
+    @pytest.mark.parametrize("zero_stage", [2, 3])
+    def test_training_step_succeeds_after_fix(self, zero_stage):
+        """Verify a full training step produces a finite loss after the caching fix."""
+        hidden_dim = 16
+        model = SimpleModel(hidden_dim)
+        config = get_config_dict(zero_stage)
+        engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
+
+        data_loader = random_dataloader(model=engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device)
+
+        losses = []
+        for i, batch in enumerate(data_loader):
+            loss = engine(batch[0], batch[1])
+            assert torch.isfinite(loss), f"Loss is not finite at step {i}: {loss.item()}"
+            losses.append(loss.item())
+            engine.backward(loss)
+            engine.step()
+            if i >= 1:
+                break
+
+        assert len(losses) >= 2, "Expected at least 2 training steps"
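
For reference, below is a minimal standalone sketch of the caching pattern the hooks above rely on. It is illustrative only, not DeepSpeed code: `Hooks`, `should_refresh`, `on_grad_hook`, and the `count_fn` stand-in are invented names. It shows why the refresh predicate must be evaluated before "re-entering" backward, and that the expensive count runs once per phase rather than once per hook.

class Hooks:
    """Toy stand-in for the backward-hook bookkeeping in base_optimizer.py."""

    def __init__(self):
        self.hooks_fired_this_backward = 0
        self.remaining_grad_acc_hooks = 0
        self.backward_active_depth = 0
        self.backward_seen_this_step = False
        self.max_expected_hooks_seen = 0

    def should_refresh(self):
        # Same two cases as the patch: first hook of a backward phase, or a
        # new reentrant phase after the previous one drained its hooks.
        return (self.hooks_fired_this_backward == 0
                or (self.remaining_grad_acc_hooks == 0 and self.backward_active_depth == 0
                    and self.backward_seen_this_step))

    def on_grad_hook(self, count_fn):
        # Evaluate BEFORE re-entering backward: re-entry bumps
        # backward_active_depth and would hide the phase boundary.
        refresh = self.should_refresh()
        self.backward_active_depth = max(self.backward_active_depth, 1)  # "re-enter"
        self.backward_seen_this_step = True
        if refresh:
            self.max_expected_hooks_seen = max(self.max_expected_hooks_seen, count_fn())
        expected = self.max_expected_hooks_seen
        self.hooks_fired_this_backward += 1
        self.remaining_grad_acc_hooks = expected - self.hooks_fired_this_backward
        return expected


if __name__ == "__main__":
    calls = []

    def count_fn():
        calls.append(1)  # record each (expensive, O(n)) evaluation
        return 5         # pretend 5 parameters produce gradients

    h = Hooks()
    for _ in range(5):   # five gradient hooks fire in one backward phase
        h.on_grad_hook(count_fn)
    assert len(calls) == 1, "counted once per phase, not once per hook"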