[JAX] Add bias support for v2 grouped GEMM path #2744

jberchtold-nvidia wants to merge 1 commit into NVIDIA:main
Conversation
The cuda-graphable (v2) grouped GEMM FFI does not natively support bias. This change applies bias in pure JAX after the GEMM in GroupedGemmPrimitive.impl, using a per-token expert index built from group_sizes to gather the correct bias row for each token. A dedicated unit test (test_grouped_gemm_fp16_with_bias) is added to directly exercise the v2 path with a non-None bfloat16 bias. Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
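The per-token expert index described above can be sketched in isolation as follows. This is a standalone illustration with made-up shapes and values, not the primitive's actual signature; only the `jnp.repeat(..., total_repeat_length=...)` pattern is taken from the change itself:

```python
import jax.numpy as jnp

# Hypothetical example: 3 expert groups, 6 tokens total, hidden size 4.
group_sizes = jnp.array([2, 3, 1], dtype=jnp.int32)  # tokens per expert group
M = 6                                                # total tokens (sum of group_sizes)
N = 4                                                # output feature dimension
bias = jnp.arange(3 * N, dtype=jnp.float32).reshape(3, N)  # one bias row per expert

# Groups are contiguous along the token axis, so repeating each expert index
# by its group size yields the expert id for every token. total_repeat_length
# keeps the output shape static, which keeps this jit/cuda-graph friendly.
segment_ids = jnp.repeat(
    jnp.arange(group_sizes.shape[0], dtype=jnp.int32),
    group_sizes,
    total_repeat_length=M,
)
# segment_ids is [0, 0, 1, 1, 1, 2]

per_token_bias = bias[segment_ids]  # shape (M, N): one bias row per token
```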
/te-ci L0 jax
Greptile Summary

This PR adds bias support to the cuda-graphable v2 grouped GEMM path in JAX by applying the bias in pure JAX after the GEMM kernel, since the v2 C++ FFI does not natively support it.

Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm()"] --> B{"_can_use_v2_grouped_gemm?\nBF16 + NO_SCALING"}
    B -- "No (FP8, MXFP8, fp16...)" --> C["Legacy v1 FFI\nte_grouped_gemm_ffi\n(native bias support)"]
    B -- "Yes" --> D["v2 FFI\nte_grouped_gemm_v2_ffi\n(no native bias)"]
    D --> E["GroupedGemmPrimitive.impl\ninner_primitive.bind(...)"]
    E --> F{"use_v2_ffi\nand has_bias?"}
    F -- "No" --> G["Return GEMM output"]
    F -- "Yes" --> H["Build segment_ids\njnp.repeat(arange(G), group_sizes, M)"]
    H --> I["Gather bias rows\nbias[segment_ids] → shape (M, N)"]
    I --> J["out = out + bias[segment_ids].astype(out.dtype)"]
    J --> G
    C --> G
```
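The dispatch step at the top of the flowchart can be sketched as a simple predicate. The function body below is a hypothetical reconstruction from the flowchart's "BF16 + NO_SCALING" condition; the real `_can_use_v2_grouped_gemm` helper and Transformer Engine's actual scaling-mode enum are not shown in this PR excerpt, so the names here are illustrative stand-ins:

```python
import jax.numpy as jnp

# Illustrative stand-in for TE's scaling-mode enum value.
NO_SCALING = "NO_SCALING"

def can_use_v2_grouped_gemm(lhs_dtype, rhs_dtype, scaling_mode):
    # Per the flowchart, the cuda-graphable v2 FFI is only taken for
    # bfloat16 operands with no quantization scaling; everything else
    # (FP8, MXFP8, fp16, ...) falls back to the legacy v1 FFI.
    return (
        lhs_dtype == jnp.bfloat16
        and rhs_dtype == jnp.bfloat16
        and scaling_mode == NO_SCALING
    )
```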
Last reviewed commit: 1d5faa2
```python
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
```
Misleading test name: "fp16" actually uses bfloat16
The method is named test_grouped_gemm_fp16_with_bias, but the body hardcodes dtype = jnp.bfloat16 and the docstring explicitly states "bfloat16 only, since v2 is only active for bfloat16 no-scaling inputs." This makes the name actively confusing for anyone trying to find tests for fp16 or bfloat16 coverage.
```diff
-def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
+def test_grouped_gemm_bf16_with_bias(self, input_shape, layout):
```
```python
if use_v2_ffi and has_bias:
    # The C++ FFI for v2 grouped GEMM does not support bias, so we apply it here in
    # pure JAX. Groups are contiguous, so we build a per-token expert index via
    # jnp.repeat and gather the corresponding bias row for each token.
    num_groups = group_sizes.shape[0]
    segment_ids = jnp.repeat(
        jnp.arange(num_groups, dtype=jnp.int32),
        group_sizes,
        total_repeat_length=M,
    )
    out = out + bias[segment_ids].astype(out.dtype)
```
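The vectorized gather-and-add above can be cross-checked against a straightforward per-group loop. The example below uses hypothetical small shapes and stand-in values for `out` and `bias`; only the `segment_ids` construction and the `bias[segment_ids].astype(out.dtype)` addition mirror the PR's code:

```python
import jax.numpy as jnp

# Hypothetical shapes: 3 groups of sizes [2, 3, 1], 6 tokens, hidden size 4.
group_sizes = jnp.array([2, 3, 1], dtype=jnp.int32)
M, N, G = 6, 4, 3
out = jnp.ones((M, N), dtype=jnp.bfloat16)                 # stand-in GEMM output
bias = jnp.arange(G * N, dtype=jnp.float32).reshape(G, N)  # one bias row per group

# Vectorized path, as in the change under review.
segment_ids = jnp.repeat(
    jnp.arange(G, dtype=jnp.int32), group_sizes, total_repeat_length=M
)
vectorized = out + bias[segment_ids].astype(out.dtype)

# Reference: walk the token axis group by group and add each group's bias row.
rows = []
start = 0
for g, size in enumerate(group_sizes.tolist()):
    rows.append(out[start:start + size] + bias[g].astype(out.dtype))
    start += size
reference = jnp.concatenate(rows, axis=0)
```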
Bias application is incorrect when is_grouped_dense_wgrad=True

When `is_grouped_dense_wgrad=True` (triggered when `rhs` is 2D, i.e. `len(rhs_shape) == 2`), the GEMM output shape is `(num_groups, M, N)` (see the abstract: `out_shape = (num_groups, M, N)`). However, the bias addition computes `bias[segment_ids]` with `segment_ids.shape == (M,)`, yielding an `(M, N)` tensor. Broadcasting `(M, N)` onto `(num_groups, M, N)` does not raise an error at runtime; instead it silently adds the same token-mapped bias to every group's output matrix, which is semantically wrong for per-group weight-gradient results.
While this combination (`is_grouped_dense_wgrad=True`, `has_bias=True`, and `use_v2_ffi=True`) is unlikely at the existing call sites (weight-gradient paths do not normally carry a bias), the code as written has no guard against it, and a future caller could silently hit this incorrect behavior. Consider adding an assertion:
```python
if use_v2_ffi and has_bias:
    assert not is_grouped_dense_wgrad, (
        "Bias is not supported for the grouped dense wgrad path with v2 FFI."
    )
    # ... rest of bias application
```
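The silent-broadcast hazard the comment describes is easy to reproduce in isolation. The snippet below uses hypothetical shapes to show that adding an `(M, N)` token-mapped bias to a `(num_groups, M, N)` wgrad-shaped output raises no error and instead applies the same bias to every group:

```python
import jax.numpy as jnp

# Hypothetical wgrad-shaped output: one (M, N) matrix per group.
num_groups, M, N = 3, 6, 4
out_3d = jnp.zeros((num_groups, M, N))

# A token-mapped bias of shape (M, N), as bias[segment_ids] would produce.
token_bias = jnp.ones((M, N))

# NumPy-style broadcasting aligns trailing axes, so (M, N) is added to each
# of the num_groups leading slices. No error is raised, which is exactly why
# the guard assertion above is worth having.
result = out_3d + token_bias  # shape (num_groups, M, N)
```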