Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):

self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misleading test name: "fp16" actually uses bfloat16

The method is named test_grouped_gemm_fp16_with_bias, but the body hardcodes dtype = jnp.bfloat16 and the docstring explicitly states "bfloat16 only, since v2 is only active for bfloat16 no-scaling inputs." This makes the name actively confusing for anyone trying to find tests for fp16 or bfloat16 coverage.

Suggested change
def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
def test_grouped_gemm_bf16_with_bias(self, input_shape, layout):

"""Directly exercises the v2 grouped GEMM path with bias (bfloat16 only, since v2 is
only active for bfloat16 no-scaling inputs)."""
dtype = jnp.bfloat16
lhs, rhs, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype, input_shape, layout, with_bias=True
)
ref_out = self._ref_grouped_dense(lhs, rhs, bias, group_sizes, contracting_dims)

prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
lhs, rhs, group_sizes, contracting_dims, bias=bias
)

self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
Expand Down
18 changes: 14 additions & 4 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,17 @@ def impl(
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
use_v2_ffi=use_v2_ffi,
)
if use_v2_ffi and has_bias:
# The C++ FFI for v2 grouped GEMM does not support bias, so we apply it here in
# pure JAX. Groups are contiguous, so we build a per-token expert index via
# jnp.repeat and gather the corresponding bias row for each token.
num_groups = group_sizes.shape[0]
segment_ids = jnp.repeat(
jnp.arange(num_groups, dtype=jnp.int32),
group_sizes,
total_repeat_length=M,
)
out = out + bias[segment_ids].astype(out.dtype)
Comment on lines +1688 to +1698
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bias application is incorrect when is_grouped_dense_wgrad=True

When is_grouped_dense_wgrad=True (triggered when rhs is 2D — len(rhs_shape) == 2), the GEMM output shape is (num_groups, M, N) (see abstract: out_shape = (num_groups, M, N)). However, the bias addition code computes bias[segment_ids] with segment_ids.shape == (M,), yielding an (M, N) tensor. Broadcasting (M, N) onto (num_groups, M, N) does not raise an error at runtime; instead, the same token-mapped bias is silently added to every group's output matrix. That is semantically incorrect — in the weight-gradient path each group's result should not receive this per-token bias at all.

While this combination (is_grouped_dense_wgrad=True and has_bias=True and use_v2_ffi=True) is unlikely in the existing call sites (weight-gradient paths don't normally carry a bias), the code as written has no guard against it, and a future caller could hit this silently wrong behavior. Consider adding an assertion:

if use_v2_ffi and has_bias:
    assert not is_grouped_dense_wgrad, (
        "Bias is not supported for the grouped dense wgrad path with v2 FFI."
    )
    # ... rest of bias application

return (out,)


Expand Down Expand Up @@ -2008,18 +2019,17 @@ def grouped_gemm_copy_group_sizes(
def _can_use_v2_grouped_gemm(
scaling_mode: ScalingMode,
dtype: jnp.dtype,
has_bias: bool,
) -> bool:
"""Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters."""
# Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy
# nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay
# feature-compatible with the main branch.
# Bias can be supported in a kernel or in pure-JAX in the future.
# Bias is applied in pure JAX after the GEMM in GroupedGemmPrimitive.impl.

if not _v2_grouped_gemm_available:
return False

return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias
return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16


def grouped_gemm(
Expand Down Expand Up @@ -2205,7 +2215,7 @@ def grouped_gemm(
" and padded with zeros to not affect the result of the MoE block."
)

use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias)
use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype)
if use_v2_ffi:
num_gemms = group_sizes.shape[0]
additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha
Expand Down
Loading