Add fsdp2 fp8 unit tests TE 2.10 #492
Open

sudhu2k wants to merge 10 commits into dev from sudhu/FSDP2_unit_tests_fix_2.10
Changes from all commits (10 commits, all by sudhu2k):
- e8e63b1: Initial commit
- c3e33e3: Updated test functions to include new parameters for FP8 autocastin…
- db36143: Refactor quantizer state checks and optimize tensor initialization in…
- 13b4007: Refactor assertion function in FP8 tests to use relative and absolute…
- d91241f: Update test tolerances for FP8 configurations to account for potentia…
- 2b8818d: Ensure proper newline at the end of the test_torch_fsdp2_fp8.py file…
- 8964d56: Refactor tolerance calculations.
- 54938d9: Refactor model initialization and autocasting logic in FSDP2 FP8 test…
- f771955: Merge remote-tracking branch 'origin/dev' into sudhu/FSDP2_unit_tests…
- c1949d3: Fix FusedAdam DTensor state initialization for FSDP2
File: test_torch_fsdp2_fp8.py

```diff
@@ -17,49 +17,82 @@
 NUM_PROCS: int = torch.cuda.device_count()


-def assertEqual(
-    l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
-    """Ensures two lists are exactly equal."""
+def assert_allclose(
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+) -> bool:
+    """Ensures two lists are equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
+    tols = dict(atol=atol)
+    tols["rtol"] = rtol if rtol is not None else 0
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        result = torch.allclose(t1, t2, atol=0, rtol=0)
+        result = torch.allclose(t1, t2, **tols)
         if not result:
             diff = torch.abs(t1 - t2)
-            exceed_mask = diff > 0
-            if exceed_mask.any():
-                indices = torch.nonzero(exceed_mask, as_tuple=True)
-                max_diff = diff[exceed_mask].max()
-                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
-                max_location = [idx[max_idx].item() for idx in indices]
-                msg = (
-                    f"Outputs not close enough in tensor at idx={i}. "
-                    f"Maximum difference at location {max_location} "
-                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
-                    f"(diff {max_diff.item()})."
-                )
+            # Elementwise tolerance, matching torch.allclose semantics.
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
+            if diff.dim() == 0:
+                max_diff = diff
+                max_location = []
+                msg = (
+                    f"Outputs not close enough in scalar tensor at idx={i}. "
+                    f"Difference: {max_diff.item()}."
+                )
+            else:
+                exceed_mask = diff > tol
+                if exceed_mask.any():
+                    indices = torch.nonzero(exceed_mask, as_tuple=True)
+                    max_diff = diff[exceed_mask].max()
+                    max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
+                    max_location = [idx[max_idx].item() for idx in indices]
+                    msg = (
+                        f"Outputs not close enough in tensor at idx={i}. "
+                        f"Maximum difference at location {max_location} "
+                        f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
+                        f"(diff {max_diff.item()})."
+                    )
             raise AssertionError(msg)


-def _run_test(fp_init, recipe):
+def _run_test(quantized_init, autocast, recipe):
     test_dir = Path(__file__).parent.resolve()
     fsdp_script = test_dir / "run_fsdp2_fp8_model.py"

     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)]

-    if fp_init:
-        test_cmd += ["--fp8-init"]
-    test_cmd += ["--recipe", recipe]
+    if quantized_init:
+        test_cmd += ["--quantized-init"]
+    if autocast:
+        test_cmd += ["--autocast"]
+    if autocast or quantized_init:
+        test_cmd += ["--recipe", recipe]

     subprocess.run(test_cmd + ['--use-fsdp2', '--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
     subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)

     # Load outputs
     output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
     output_dp = torch.load("all_iters_dp.pt", map_location="cpu")
+    atol = 0
+    rtol = 0
+    # Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical:
+    #
+    # - quantized_init=True: after each optimizer step, FP8 weights are re-quantized
+    #   from FP32 master weights. Hence we use a relaxed tolerance.
+    #
+    # - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs
+    #   (all-reduce vs reduce-scatter), so float non-associativity produces last-bit
+    #   differences in the reduced gradients and updated weights. Hence we use a
+    #   relaxed tolerance.
+    #
+    # When autocast=True and quantized_init=False, FP8 quantization happens after the
+    # FSDP2 AllGather reconstructs the full weight, so both paths compute identical
+    # scales and produce bit-identical FP8 GEMMs; strict tolerance (0) is used.
+    if quantized_init or (not quantized_init and not autocast):
+        atol = 1e-6
+        rtol = 5e-5
```
Collaborator (inline comment on the tolerance logic above): If our reference is DDP with the same FP8 primary weight, then the same cast from FP32 master weight to FP8 happens in both target and reference flow. Then we will have exact match?
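The reduction-order argument in the tolerance comment is easy to demonstrate in isolation. A minimal standalone sketch (not part of the PR) showing that float32 summation order changes the result:

```python
import torch

torch.manual_seed(0)
vals = torch.randn(10_000, dtype=torch.float32)

# Same values, different accumulation order: float addition is not
# associative, so the two sums can differ in the last bits. This is why
# DDP's all-reduce and FSDP2's reduce-scatter can produce gradients that
# differ by ~1 ulp even from identical inputs.
forward_sum = vals.sum()
reverse_sum = vals.flip(0).sum()
print((forward_sum - reverse_sum).item())  # typically small but nonzero
```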
```diff
     for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):

         print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
-        assertEqual(te_output_no_cache[1], te_output_cache[1])  # expects exact match
+        assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol)
         print(f"Tensor at index {idx} passed comparison.")
```
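For illustration, a usage sketch of the new helper with the relaxed tolerances chosen above (assumes `assert_allclose` as defined in this diff):

```python
import torch

a = [torch.tensor([1.0, 2.0, 3.0])]
b = [torch.tensor([1.0, 2.0, 3.0 + 1e-7])]

# Passes: the elementwise difference of 1e-7 is within
# atol + rtol * |b| = 1e-6 + 5e-5 * 3.0.
assert_allclose(a, b, atol=1e-6, rtol=5e-5)

# A mismatch beyond the tolerance raises AssertionError and reports the
# location and magnitude of the largest difference:
# assert_allclose(a, [torch.tensor([1.0, 2.0, 4.0])], atol=1e-6, rtol=5e-5)
```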
```diff
@@ -70,13 +103,24 @@ def cleanup_artifacts():
     if os.path.exists(fname):
         os.remove(fname)

+# Define test cases explicitly
+test_cases = []
+# All FP8 enabled cases (all recipes)
+for quantized_init in [True, False]:
+    for autocast in [True, False]:
+        if quantized_init or autocast:
+            for recipe in ["delayed", "current", "mxfp8"]:
+                test_cases.append((quantized_init, autocast, recipe))
+# FP8 disabled case (only once)
+test_cases.append((False, False, "delayed"))


 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
 @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
-@pytest.mark.parametrize("fp8_init", ([False]))
-@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"]))
+@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases)
 @pytest.mark.usefixtures("cleanup_artifacts")
-def test_distributed(fp8_init, recipe):
+def test_distributed(quantized_init, autocast, recipe):

     batch_size = 2048
     input_size = 2048
```
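The loop above expands to ten parametrized cases: three (quantized_init, autocast) combinations with FP8 enabled, each across three recipes, plus one FP8-disabled case. A quick standalone check of the count:

```python
# Standalone check mirroring the test_cases loop above.
cases = [
    (qi, ac, r)
    for qi in (True, False)
    for ac in (True, False)
    if qi or ac
    for r in ("delayed", "current", "mxfp8")
] + [(False, False, "delayed")]
assert len(cases) == 10  # 3 FP8-enabled combos x 3 recipes + 1 disabled case
```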
```diff
@@ -96,12 +140,12 @@ def test_distributed(fp8_init, recipe):
     if torch.cuda.device_count() < 4:
         pytest.skip("FSDP2 test requires at least 4 GPUs")

-    if fp8_init and not fp8_available:
+    if quantized_init and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if recipe == "mxfp8" and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)

-    _run_test(fp8_init, recipe)
+    _run_test(quantized_init, autocast, recipe)


 def test_dummy() -> None:
```
Reviewer: Does `with te.fp8_autocast(enabled=args.fp8_autocast, ...)` do the same?

sudhu2k: It does do the same, but since TE v2.10 replaces `te.fp8_autocast` with `te.autocast`, I've made the change to be consistent.

Reviewer: So will `with te.autocast(enabled=args.fp8_autocast, recipe=...)` do the same as the if/else?

sudhu2k: Yes, it should. I'll make the changes.

sudhu2k: Done.
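For context, a minimal sketch of the pattern the reviewer suggests, assuming `te.autocast` accepts an `enabled` flag as stated in the thread; the function and argument names here are illustrative, not the PR's exact code:

```python
import transformer_engine.pytorch as te

def forward(model, inp, use_autocast, recipe):
    # One code path instead of an if/else: enabled=False makes the context
    # manager a no-op, so the non-FP8 case falls through to a plain
    # high-precision forward pass.
    with te.autocast(enabled=use_autocast, recipe=recipe):
        return model(inp)
```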