-
Notifications
You must be signed in to change notification settings - Fork 4.7k
Fix hook count performance regression from v0.18.5 #7886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tohtana
merged 5 commits into
deepspeedai:master
from
tohtana:tohtana/fix-perf-regression
Mar 5, 2026
+155
−5
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
e1b99f7
Add should_refresh_expected_hook_count and harden enter_backward
tohtana 91b591b
Fix ZeRO-2 hook to cache count_used_parameters_in_backward result
tohtana d3a6e99
Fix ZeRO-3 hooks to cache count_used_parameters_in_backward result
tohtana e1b41ee
Add regression tests for hook count performance fix
tohtana c43e423
Add comment clarifying refresh-before-reenter ordering
tohtana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
| """Regression tests for count_used_parameters_in_backward() call count. | ||
|
|
||
| Verifies fix for https://github.com/deepspeedai/DeepSpeed/issues/7885: | ||
| count_used_parameters_in_backward() was called once per gradient hook | ||
| (O(n) calls per backward) instead of once per backward phase (O(1) | ||
| for non-reentrant, O(p) for reentrant with p phases). | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
| from unittest.mock import patch | ||
|
|
||
| import deepspeed | ||
| from deepspeed.accelerator import get_accelerator | ||
| from unit.common import DistributedTest | ||
| from unit.simple_model import SimpleModel, random_dataloader | ||
|
|
||
|
|
||
| def get_config_dict(zero_stage): | ||
| config_dict = { | ||
| "train_micro_batch_size_per_gpu": 2, | ||
| "gradient_accumulation_steps": 1, | ||
| "steps_per_print": 1, | ||
| "zero_optimization": { | ||
| "stage": zero_stage, | ||
| }, | ||
| "optimizer": { | ||
| "type": "Adam", | ||
| "params": { | ||
| "lr": 1e-3 | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| if zero_stage == 3: | ||
| config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0 | ||
|
|
||
| if get_accelerator().is_bf16_supported(): | ||
| config_dict["bf16"] = {"enabled": True} | ||
| elif get_accelerator().is_fp16_supported(): | ||
| config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} | ||
|
|
||
| return config_dict | ||
|
|
||
|
|
||
| class TestHookCountRegression(DistributedTest): | ||
| """Test that count_used_parameters_in_backward is not called per-hook.""" | ||
| world_size = 2 | ||
|
|
||
| @pytest.mark.parametrize("zero_stage", [2, 3]) | ||
| def test_non_reentrant_single_count_call(self, zero_stage): | ||
| """Non-reentrant backward should call count_used_parameters_in_backward exactly once.""" | ||
| hidden_dim = 16 | ||
| model = SimpleModel(hidden_dim) | ||
| config = get_config_dict(zero_stage) | ||
| engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) | ||
|
|
||
| data_loader = random_dataloader(model=engine, total_samples=4, hidden_dim=hidden_dim, device=engine.device) | ||
|
|
||
| # Determine the correct module path to patch based on stage | ||
| if zero_stage == 2: | ||
| patch_target = "deepspeed.runtime.zero.stage_1_and_2.count_used_parameters_in_backward" | ||
| else: | ||
| patch_target = "deepspeed.runtime.zero.stage3.count_used_parameters_in_backward" | ||
|
|
||
| call_counts = [] | ||
|
|
||
| for batch in data_loader: | ||
| with patch(patch_target, wraps=deepspeed.runtime.utils.count_used_parameters_in_backward) as mock_count: | ||
| loss = engine(batch[0], batch[1]) | ||
| engine.backward(loss) | ||
| call_counts.append(mock_count.call_count) | ||
| engine.step() | ||
| break | ||
|
|
||
| # Non-reentrant: exactly 1 call per backward | ||
| assert call_counts[0] == 1, (f"Expected exactly 1 call to count_used_parameters_in_backward " | ||
| f"per backward, got {call_counts[0]}") | ||
|
|
||
| @pytest.mark.parametrize("zero_stage", [2, 3]) | ||
| def test_training_step_succeeds_after_fix(self, zero_stage): | ||
| """Verify a full training step produces a finite loss after the caching fix.""" | ||
| hidden_dim = 16 | ||
| model = SimpleModel(hidden_dim) | ||
| config = get_config_dict(zero_stage) | ||
| engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) | ||
|
|
||
| data_loader = random_dataloader(model=engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device) | ||
|
|
||
| losses = [] | ||
| for i, batch in enumerate(data_loader): | ||
| loss = engine(batch[0], batch[1]) | ||
| assert torch.isfinite(loss), f"Loss is not finite at step {i}: {loss.item()}" | ||
| losses.append(loss.item()) | ||
| engine.backward(loss) | ||
| engine.step() | ||
| if i >= 1: | ||
| break | ||
|
|
||
| assert len(losses) >= 2, "Expected at least 2 training steps" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @tohtana the
refresh_expected = self.should_refresh_expected_hook_count()line is used on L1289, why it is placed before L1284 (reenter_backward_if_needed()) and L1286 (reduce_ready_partitions_and_remove_grads()). Is there implicit dependency here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for checking, @delock!
Yes, there's actually a dependency.
should_refresh_expected_hook_count()detects reentrant phase boundaries by checkingbackward_active_depth == 0.reenter_backward_if_needed()incrementsbackward_active_depthfrom 0 → 1 when it detects a new phase. If we calledreenter_backwardfirst, it would setbackward_active_depth = 1before the predicate runs, making the condition always false.I added a comment to clarify it.