[JAX] Add bias support for v2 grouped GEMM path #2744
base: main
```diff
@@ -1685,6 +1685,17 @@ def impl(
             use_async_d2h_group_sizes=use_async_d2h_group_sizes,
             use_v2_ffi=use_v2_ffi,
         )
+        if use_v2_ffi and has_bias:
+            # The C++ FFI for v2 grouped GEMM does not support bias, so we apply it here in
+            # pure JAX. Groups are contiguous, so we build a per-token expert index via
+            # jnp.repeat and gather the corresponding bias row for each token.
+            num_groups = group_sizes.shape[0]
+            segment_ids = jnp.repeat(
+                jnp.arange(num_groups, dtype=jnp.int32),
+                group_sizes,
+                total_repeat_length=M,
+            )
+            out = out + bias[segment_ids].astype(out.dtype)
```
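The gather in the hunk above is easy to check in isolation. Below is a minimal NumPy sketch of the same per-group bias application (the function name and shapes are illustrative, not TE code); the actual diff uses `jnp.repeat` with `total_repeat_length=M` so the output shape stays static under `jit`.

```python
import numpy as np

def apply_grouped_bias(out, bias, group_sizes):
    # out: [M, N] with rows grouped contiguously per expert; bias: [G, N].
    # Build a per-token expert index: group i contributes group_sizes[i]
    # copies of i, then gather that group's bias row for each token.
    segment_ids = np.repeat(np.arange(len(group_sizes)), group_sizes)
    return out + bias[segment_ids].astype(out.dtype)

out = np.zeros((5, 2), dtype=np.float32)
bias = np.array([[1.0, 1.0], [2.0, 2.0]], dtype=np.float32)
group_sizes = np.array([3, 2])
result = apply_grouped_bias(out, bias, group_sizes)
# rows 0..2 receive bias[0], rows 3..4 receive bias[1]
```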
**Comment on lines +1688 to +1698** (Contributor)

Bias application is incorrect when `is_grouped_dense_wgrad` is set: in the grouped dense wgrad path the output rows do not correspond one-to-one with tokens, so the per-token `segment_ids` gather would index the wrong shape. While this combination may be unlikely in practice, it is worth guarding against explicitly:

```python
if use_v2_ffi and has_bias:
    assert not is_grouped_dense_wgrad, (
        "Bias is not supported for the grouped dense wgrad path with v2 FFI."
    )
    # ... rest of bias application
```
```diff
 
         return (out,)
```
@@ -2008,18 +2019,17 @@ def grouped_gemm_copy_group_sizes( | |
| def _can_use_v2_grouped_gemm( | ||
| scaling_mode: ScalingMode, | ||
| dtype: jnp.dtype, | ||
| has_bias: bool, | ||
| ) -> bool: | ||
| """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" | ||
| # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy | ||
| # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay | ||
| # feature-compatible with the main branch. | ||
| # Bias can be supported in a kernel or in pure-JAX in the future. | ||
| # Bias is applied in pure JAX after the GEMM in GroupedGemmPrimitive.impl. | ||
|
|
||
| if not _v2_grouped_gemm_available: | ||
| return False | ||
|
|
||
| return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias | ||
| return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 | ||
|
|
||
|
|
||
| def grouped_gemm( | ||
|
|
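Stripped of the removed `has_bias` gate, the dispatch predicate reduces to a pure function of scaling mode and dtype. The sketch below mirrors that post-diff logic with hypothetical stand-ins for the real `ScalingMode` enum and the module-level availability flag (neither is reproduced from TE here):

```python
from enum import Enum

class ScalingMode(Enum):
    # Hypothetical stand-in for transformer_engine's ScalingMode enum.
    NO_SCALING = 0
    FP8 = 1

# Stand-in for the module-level flag set when the v2 FFI is importable.
_v2_grouped_gemm_available = True

def can_use_v2_grouped_gemm(scaling_mode, dtype_name):
    """v2 path only for plain bf16, non-quantized inputs; bias no longer gates it."""
    if not _v2_grouped_gemm_available:
        return False
    return scaling_mode == ScalingMode.NO_SCALING and dtype_name == "bfloat16"
```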
@@ -2205,7 +2215,7 @@ def grouped_gemm( | |
| " and padded with zeros to not affect the result of the MoE block." | ||
| ) | ||
|
|
||
| use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) | ||
| use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype) | ||
| if use_v2_ffi: | ||
| num_gemms = group_sizes.shape[0] | ||
| additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha | ||
|
|
||
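The `additional_arg_0` array above supplies a per-group GEMM scale (alpha), here all ones. A NumPy reference for what a grouped GEMM with per-group alpha computes (the helper name and shapes are illustrative, not the FFI's actual signature):

```python
import numpy as np

def grouped_gemm_ref(lhs, rhs, group_sizes, alpha):
    # lhs: [M, K] rows grouped contiguously; rhs: [G, K, N]; alpha: [G].
    # Each group g computes alpha[g] * (lhs_rows_of_group_g @ rhs[g]).
    outs, start = [], 0
    for g, size in enumerate(group_sizes):
        outs.append(alpha[g] * (lhs[start:start + size] @ rhs[g]))
        start += size
    return np.concatenate(outs, axis=0)

lhs = np.ones((5, 2), dtype=np.float32)
rhs = np.ones((2, 2, 3), dtype=np.float32)
group_sizes = [3, 2]
result = grouped_gemm_ref(lhs, rhs, group_sizes, np.ones(2, np.float32))
# every element is K = 2.0 since all inputs are ones and alpha = 1
```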
**Misleading test name: "fp16" actually uses bfloat16**

The method is named `test_grouped_gemm_fp16_with_bias`, but the body hardcodes `dtype = jnp.bfloat16`, and the docstring explicitly states "bfloat16 only, since v2 is only active for bfloat16 no-scaling inputs." This makes the name actively confusing for anyone trying to find tests for fp16 or bfloat16 coverage.