diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..bc9d52f428 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1799,6 +1799,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) + @pytest_parametrize_wrapper("layout", ["NN"]) + def test_grouped_gemm_fp16_with_bias(self, input_shape, layout): + """Directly exercises the v2 grouped GEMM path with bias (bfloat16 only, since v2 is + only active for bfloat16 no-scaling inputs).""" + dtype = jnp.bfloat16 + lhs, rhs, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( + dtype, input_shape, layout, with_bias=True + ) + ref_out = self._ref_grouped_dense(lhs, rhs, bias, group_sizes, contracting_dims) + + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs, rhs, group_sizes, contracting_dims, bias=bias + ) + + self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ab2be7f799..daf53a8067 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1685,6 +1685,17 @@ def impl( use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) + if use_v2_ffi and has_bias: + # The C++ FFI for v2 grouped GEMM does not support bias, so we apply it here in + # pure JAX. Groups are contiguous, so we build a per-token expert index via + # jnp.repeat and gather the corresponding bias row for each token.
+ num_groups = group_sizes.shape[0] + segment_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_sizes, + total_repeat_length=M, + ) + out = out + bias[segment_ids].astype(out.dtype) return (out,) @@ -2008,18 +2019,17 @@ def grouped_gemm_copy_group_sizes( def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, - has_bias: bool, ) -> bool: """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay # feature-compatible with the main branch. - # Bias can be supported in a kernel or in pure-JAX in the future. + # Bias is applied in pure JAX after the GEMM in GroupedGemmPrimitive.impl. if not _v2_grouped_gemm_available: return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 def grouped_gemm( @@ -2205,7 +2215,7 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype) if use_v2_ffi: num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha