-
Notifications
You must be signed in to change notification settings - Fork 22
Enhance GroupedLinear with integrating AITER triton kernels #413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sudhu2k
wants to merge
6
commits into
dev
Choose a base branch
from
sudhu/aiter_grouped_gemm_integration
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
fa41505
Enhance GroupedLinear with Triton kernel support and update setup.py …
4c97e84
Added copyright and fixed env var for triton grouped gemm
f029d98
Added grouped_linear module test with triton
d50f86a
Fix for unit test, set back the env variable to 0
bf8e167
Update numerical test tolerances for float32 in grouped linear accura…
ba11350
Relaxed rtol tolerance for float32 in grouped linear accuracy test
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,7 +170,10 @@ def setup_requirements() -> Tuple[List[str], List[str]]: | |
| install_requires, test_requires = setup_requirements() | ||
| ext_modules = [setup_common_extension()] | ||
| cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} | ||
| package_data = {"": ["VERSION.txt"]} | ||
| package_data = { | ||
| "": ["VERSION.txt"], | ||
| "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They should be part of pytorch extension installation not TE core |
||
| } | ||
| include_package_data = True | ||
| extras_require = {"test": test_requires} | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| import pytest | ||
| import random | ||
|
|
||
| from triton_kernels.test_common import get_tolerances | ||
| import torch | ||
| import torch.nn as nn | ||
| from torch.nn import Parameter | ||
|
|
@@ -2016,6 +2017,118 @@ def _test_grouped_linear_accuracy( | |
| return outputs | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", param_types, ids=str) | ||
| @pytest.mark.parametrize("num_gemms", [3, 6]) | ||
| @pytest.mark.parametrize("bs", batch_sizes) | ||
| @pytest.mark.parametrize("model", ["126m"]) | ||
| @pytest.mark.parametrize("recipe", [None]) | ||
| @pytest.mark.parametrize("fp8_model_params", [False]) | ||
| @pytest.mark.parametrize("fuse_wgrad_accumulation", [False]) | ||
| @pytest.mark.parametrize("bias", all_boolean) | ||
| @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) | ||
| def test_grouped_linear_triton_accuracy( | ||
| dtype, | ||
| num_gemms, | ||
| bs, | ||
| model, | ||
| recipe, | ||
| fp8_model_params, | ||
| fuse_wgrad_accumulation, | ||
| bias, | ||
| delay_wgrad_compute, | ||
| parallel_mode=None, | ||
| ): | ||
| os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This env won't be cleared if the test is skipped of failed |
||
| fp8 = recipe is not None | ||
|
|
||
| if IS_HIP_EXTENSION: | ||
| if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: | ||
| pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") | ||
| if fp8 and not fp8_available: | ||
| pytest.skip(reason_for_no_fp8) | ||
| if fp8 and recipe.mxfp8() and not mxfp8_available: | ||
| pytest.skip(reason_for_no_mxfp8) | ||
| if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: | ||
| pytest.skip("FP8 parameters are not supported in debug mode.") | ||
| if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: | ||
| pytest.skip(reason_for_no_fp8_block_scaling) | ||
|
|
||
| config = model_configs[model] | ||
| if config.seq_len % 16 != 0 and fp8: | ||
| pytest.skip("FP8 requires sequence length to be divisible by 16.") | ||
|
|
||
| with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): | ||
| grouped_linear = GroupedLinear( | ||
| num_gemms, | ||
| config.hidden_size, | ||
| 4 * config.hidden_size, | ||
| bias=bias, | ||
| params_dtype=dtype, | ||
| parallel_mode=parallel_mode, | ||
| device="cuda", | ||
| fuse_wgrad_accumulation=fuse_wgrad_accumulation, | ||
| delay_wgrad_compute=delay_wgrad_compute, | ||
| save_original_input=False, | ||
| ).eval() | ||
| sequential_linear = torch.nn.ModuleList( | ||
| [ | ||
| Linear( | ||
| config.hidden_size, | ||
| 4 * config.hidden_size, | ||
| bias=bias, | ||
| params_dtype=dtype, | ||
| parallel_mode=parallel_mode, | ||
| device="cuda", | ||
| fuse_wgrad_accumulation=fuse_wgrad_accumulation, | ||
| ).eval() | ||
| for _ in range(num_gemms) | ||
| ] | ||
| ) | ||
|
|
||
| # Share params | ||
| with torch.no_grad(): | ||
| for i in range(num_gemms): | ||
| sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) | ||
| if bias: | ||
| sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) | ||
| if fuse_wgrad_accumulation: | ||
| weight_i = getattr(grouped_linear, f"weight{i}") | ||
| weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) | ||
| sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() | ||
|
|
||
| outputs_ref = _test_grouped_linear_accuracy( | ||
| sequential_linear, | ||
| num_gemms, | ||
| bs, | ||
| dtype, | ||
| config, | ||
| recipe, | ||
| fp8, | ||
| fuse_wgrad_accumulation, | ||
| delay_wgrad_compute, | ||
| ) | ||
| outputs = _test_grouped_linear_accuracy( | ||
| grouped_linear, | ||
| num_gemms, | ||
| bs, | ||
| dtype, | ||
| config, | ||
| recipe, | ||
| fp8, | ||
| fuse_wgrad_accumulation, | ||
| delay_wgrad_compute, | ||
| ) | ||
|
|
||
| # Shoule be bit-wise match | ||
| atol, rtol = get_tolerances(dtype) | ||
| if dtype == torch.float32: | ||
| atol = 2.6e-6 | ||
| rtol = 5e-2 | ||
| for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): | ||
| torch.testing.assert_close(o, o_ref, rtol=rtol, atol=atol) | ||
| os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "0" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", param_types, ids=str) | ||
| @pytest.mark.parametrize("num_gemms", [3, 6]) | ||
| @pytest.mark.parametrize("bs", batch_sizes) | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move it two lines higher, alphabetical sort helps to find tests