From e40b710bfac1f6a63c116dcc66c3dd71556a7006 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 25 Feb 2026 15:17:33 -0800 Subject: [PATCH 1/9] initial draft Signed-off-by: tdophung --- tests/jax/test_distributed_router.py | 497 ++++++++++ tests/jax/test_fused_router.py | 521 ++++++++++ .../transformer_engine/transformer_engine.h | 8 +- .../jax/cpp_extensions/__init__.py | 1 + .../jax/cpp_extensions/router.py | 891 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 8 + .../jax/csrc/extensions/pybind.cpp | 14 + .../jax/csrc/extensions/router.cpp | 337 +++++++ transformer_engine/jax/router.py | 358 +++++++ 9 files changed, 2633 insertions(+), 2 deletions(-) create mode 100644 tests/jax/test_distributed_router.py create mode 100644 tests/jax/test_fused_router.py create mode 100644 transformer_engine/jax/cpp_extensions/router.py create mode 100644 transformer_engine/jax/csrc/extensions/router.cpp create mode 100644 transformer_engine/jax/router.py diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py new file mode 100644 index 0000000000..1d1d060711 --- /dev/null +++ b/tests/jax/test_distributed_router.py @@ -0,0 +1,497 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed/sharded execution of fused MoE router primitives. + +Testing Strategy: +================= +Router operations process each token independently (1 warp per token), so +sharded execution on the token dimension should produce identical results +to processing each shard independently with the reference implementation. + +For fused_topk_with_score_function and fused_compute_score_for_moe_aux_loss: +- Input logits [num_tokens, num_experts] are sharded on num_tokens (DP axis) +- Expert dimension is replicated +- Each GPU processes its local tokens independently +- We verify sharded output matches per-shard reference, concatenated + +For fused_moe_aux_loss: +- This is a global reduction to a scalar +- All inputs and outputs are replicated (partition function forces this) +- We verify the op works correctly under a mesh context + +These tests exercise: partition, infer_sharding_from_operands, batcher, +and shardy_sharding_rule from the router primitives. +""" + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + +from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_compute_score_for_moe_aux_loss, + fused_moe_aux_loss, +) + +from test_fused_router import ( + reference_topk_softmax_sigmoid, + reference_compute_scores_for_aux_loss, + reference_aux_loss, + make_logits, +) + +# (num_tokens, num_experts, topk) +ALL_TOPK_CASES = [ + (128, 32, 4), + (2048, 128, 8), +] +TOPK_CASES = { + "L0": ALL_TOPK_CASES[0:1], + "L2": ALL_TOPK_CASES, +} + +ALL_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), +] +AUX_LOSS_CASES = { + "L0": ALL_AUX_LOSS_CASES[0:1], + "L2": ALL_AUX_LOSS_CASES, +} + + +class TestDistributedFusedTopk: + """Test distributed execution of fused_topk_with_score_function. + + Shards logits on the token dimension. Each GPU independently runs the + fused kernel on its local tokens. We compare against the reference + implementation run per-shard and concatenated. 
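+    Both partitioners are covered: the main test runs with the Shardy
+    partitioner (use_shardy=True) and the *_gspmd variant repeats a single
+    case with GSPMD (use_shardy=False).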
+ """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_topk_with_score_function( + x, topk=topk, score_function=score_function, + ) + + target_probs, target_routing_map = target_fwd(logits_sharded) + + logits_shards = jnp.reshape( + logits, (num_dp_devices, local_num_tokens, num_experts) + ) + ref_fwd_fn = jax.jit(lambda x: reference_topk_softmax_sigmoid( + x, topk=topk, score_function=score_function, + )) + ref_probs_list = [] + ref_routing_list = [] + for i in range(num_dp_devices): + p, rm = ref_fwd_fn(logits_shards[i]) + ref_probs_list.append(p) + ref_routing_list.append(rm) + + ref_probs = jnp.concatenate(ref_probs_list, axis=0) + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + + assert_allclose( + jax.device_get(target_probs), ref_probs, dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), ref_routing, + ), "Routing map mismatch in distributed fused_topk" + + # === Backward === + def target_loss(x): + p, _ = fused_topk_with_score_function( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(p) + + def ref_chunk_loss(x_chunk): + p, _ = reference_topk_softmax_sigmoid( + x_chunk, topk=topk, score_function=score_function, + ) + return jnp.sum(p) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + @pytest.mark.parametrize("use_shardy", [True]) + def test_distributed_topk( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_topk_gspmd( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens=128, + num_experts=32, + topk=4, + score_function=score_function, + use_shardy=False, + ) + + +class TestDistributedScoreForAuxLoss: + """Test distributed execution of fused_compute_score_for_moe_aux_loss. 
+ + Same sharding strategy as fused_topk: shard on token dim, replicate experts. + Each GPU independently computes scores and routing map for its local tokens. + """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_compute_score_for_moe_aux_loss( + x, topk=topk, score_function=score_function, + ) + + target_routing_map, target_scores = target_fwd(logits_sharded) + + logits_shards = jnp.reshape( + logits, (num_dp_devices, local_num_tokens, num_experts) + ) + ref_fwd_fn = jax.jit(lambda x: reference_compute_scores_for_aux_loss( + x, topk=topk, score_function=score_function, + )) + ref_routing_list = [] + ref_scores_list = [] + for i in range(num_dp_devices): + rm, s = ref_fwd_fn(logits_shards[i]) + ref_routing_list.append(rm) + ref_scores_list.append(s) + + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + ref_scores = jnp.concatenate(ref_scores_list, axis=0) + + assert_allclose( + jax.device_get(target_scores), ref_scores, dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), ref_routing, + ), "Routing map mismatch in distributed score_for_aux_loss" + + # === Backward === + def target_loss(x): + _, s = fused_compute_score_for_moe_aux_loss( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(s) + + def ref_chunk_loss(x_chunk): + _, s = reference_compute_scores_for_aux_loss( + x_chunk, topk=topk, score_function=score_function, + ) + return jnp.sum(s) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + @pytest.mark.parametrize("use_shardy", [True]) + def test_distributed_score_for_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + use_shardy, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_score_for_aux_loss_gspmd( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + 
mesh_resource, + num_tokens=128, + num_experts=32, + topk=4, + score_function=score_function, + use_shardy=False, + ) + + +class TestDistributedMoEAuxLoss: + """Test distributed execution of fused_moe_aux_loss. + + Aux loss is a global reduction to a scalar. The partition function forces + all inputs to be replicated. We verify the op produces correct results + under a mesh context with replicated sharding, testing both forward + (scalar loss) and backward (gradient w.r.t. probs). + """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + use_shardy, + ): + jax.config.update("jax_use_shardy_partitioner", use_shardy) + + key = jax.random.PRNGKey(42) + key, subkey1, subkey2 = jax.random.split(key, 3) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=jnp.float32) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=jnp.float32) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + + tokens_per_expert = jax.random.randint( + subkey1, (num_experts,), 1, 1000 + ).astype(jnp.int32) + coeff = 0.01 + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + replicated_2d_pspec = PartitionSpec(None, None) + replicated_1d_pspec = PartitionSpec(None) + + with mesh: + probs_sharding = NamedSharding(mesh, replicated_2d_pspec) + tpe_sharding = NamedSharding(mesh, replicated_1d_pspec) + + probs_dev = jax.device_put(probs, probs_sharding) + tpe_dev = jax.device_put(tokens_per_expert, tpe_sharding) + + # === Forward === + @jax.jit + def target_fwd(p, tpe): + return fused_moe_aux_loss( + p, tpe, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + target_loss = target_fwd(probs_dev, tpe_dev) + + ref_fwd_fn = jax.jit(lambda p: reference_aux_loss( + p, tokens_per_expert, num_tokens, topk, num_experts, coeff, + )) + ref_loss = ref_fwd_fn(probs) + + assert_allclose( + jax.device_get(target_loss), ref_loss, dtype=jnp.float32, + ) + + # === Backward === + def target_loss_fn(p): + return fused_moe_aux_loss( + p, tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + def ref_loss_fn(p): + return reference_aux_loss( + p, tokens_per_expert, num_tokens, topk, num_experts, coeff, + ) + + target_grad = jax.jit(jax.grad(target_loss_fn))(probs_dev) + ref_grad = jax.jit(jax.grad(ref_loss_fn))(probs) + + assert_allclose( + jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", AUX_LOSS_CASES, + ) + @pytest.mark.parametrize("use_shardy", [True]) + def test_distributed_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + use_shardy, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + use_shardy, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + def test_distributed_aux_loss_gspmd( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens=128, + num_experts=32, + topk=4, + use_shardy=False, + ) diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py new file mode 
100644 index 0000000000..2a4aa4672f --- /dev/null +++ b/tests/jax/test_fused_router.py @@ -0,0 +1,521 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for fused MoE router CUDA kernels (JAX wrappers).""" + +from functools import partial +from typing import Optional + +import jax +import jax.numpy as jnp +import pytest + +from utils import pytest_parametrize_wrapper + +from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_compute_score_for_moe_aux_loss, + fused_moe_aux_loss, +) + +# ============================================================================= +# Test case definitions (L0 = fast smoke, L2 = comprehensive) +# ============================================================================= + +# (num_tokens, num_experts, topk) +ALL_TOPK_CASES = [ + (128, 32, 4), + (2048, 32, 4), + (2048, 128, 8), + (7168, 128, 4), + (7168, 32, 8), +] +TOPK_CASES = { + "L0": ALL_TOPK_CASES[0:2], + "L2": ALL_TOPK_CASES, +} + +ALL_GROUP_TOPK_OPTIONS = [None, 4] +GROUP_TOPK_OPTIONS = { + "L0": [None], + "L2": ALL_GROUP_TOPK_OPTIONS, +} + +ALL_SCALING_FACTOR_OPTIONS = [None, 1.2] +SCALING_FACTOR_OPTIONS = { + "L0": [None], + "L2": ALL_SCALING_FACTOR_OPTIONS, +} + +ALL_ENABLE_BIAS_OPTIONS = [True, False] +ENABLE_BIAS_OPTIONS = { + "L0": [False], + "L2": ALL_ENABLE_BIAS_OPTIONS, +} + +ALL_USE_PRE_SOFTMAX_OPTIONS = [True, False] +USE_PRE_SOFTMAX_OPTIONS = { + "L0": [False], + "L2": ALL_USE_PRE_SOFTMAX_OPTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_SCORE_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 8), + (7168, 128, 8), + (7168, 32, 4), +] +SCORE_AUX_LOSS_CASES = { + "L0": ALL_SCORE_AUX_LOSS_CASES[0:2], + "L2": ALL_SCORE_AUX_LOSS_CASES, +} + +ALL_SCORE_FUNCTIONS = ["softmax", "sigmoid"] +SCORE_FUNCTIONS = { + "L0": ["softmax"], + "L2": ALL_SCORE_FUNCTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 4), + (7168, 128, 4), + (7168, 32, 4), +] +AUX_LOSS_CASES = { + "L0": ALL_AUX_LOSS_CASES[0:2], + "L2": ALL_AUX_LOSS_CASES, +} + +ALL_DTYPES = [jnp.float32] +DTYPES = { + "L0": [jnp.float32], + "L2": ALL_DTYPES, +} + +seed = 42 +key = jax.random.PRNGKey(seed) + + +# ============================================================================= +# Reference Implementations +# ============================================================================= + + +def reference_group_limited_topk( + scores: jnp.ndarray, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + """Reference implementation for grouped top-k. + + Only valid when num_groups and group_topk are both positive integers. + For plain top-k without grouping, use jax.lax.top_k directly. + """ + assert num_groups is not None and num_groups > 0, ( + "reference_group_limited_topk requires valid num_groups > 0. " + "For plain top-k, use jax.lax.top_k directly." + ) + assert group_topk is not None and group_topk > 0, ( + "reference_group_limited_topk requires valid group_topk > 0." 
+ ) + assert num_experts % num_groups == 0, ( + f"num_experts ({num_experts}) must be divisible by num_groups ({num_groups})" + ) + group_size = num_experts // num_groups + experts_per_group = topk // group_topk + + group_scores = ( + scores.reshape(num_tokens, num_groups, group_size) + .sort(axis=-1)[..., -experts_per_group:] + .sum(axis=-1) + ) + group_idx = jax.lax.top_k(group_scores, k=group_topk)[1] + group_mask = jnp.zeros_like(group_scores).at[ + jnp.arange(num_tokens)[:, None], group_idx + ].set(1) + + score_mask = ( + group_mask[:, :, None] + * jnp.ones((num_tokens, num_groups, group_size)) + ).reshape(num_tokens, -1) + + masked_scores = jnp.where(score_mask.astype(bool), scores, -jnp.inf) + probs, top_indices = jax.lax.top_k(masked_scores, k=topk) + return probs, top_indices + + +def reference_topk_softmax_sigmoid( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[jnp.ndarray] = None, +): + """Reference implementation for topk + softmax/sigmoid.""" + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return reference_group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return jax.lax.top_k(scores, k=topk) + + if score_function == "softmax": + if use_pre_softmax: + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(logits.dtype) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)).astype(logits.dtype) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = jnp.take_along_axis(scores, top_indices, axis=1).astype(logits.dtype) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + topk_masked_gates = jnp.zeros_like(logits).at[ + jnp.arange(num_tokens)[:, None], top_indices + ].set(probs) + topk_map = jnp.zeros_like(logits, dtype=jnp.bool_).at[ + jnp.arange(num_tokens)[:, None], top_indices + ].set(True) + + return topk_masked_gates, topk_map + + +def reference_compute_scores_for_aux_loss( + logits: jnp.ndarray, topk: int, score_function: str +): + """Reference implementation for computing routing scores for aux loss.""" + if score_function == "softmax": + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)) + scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + _, top_indices = jax.lax.top_k(scores, k=topk) + num_tokens = logits.shape[0] + routing_map = jnp.zeros_like(logits, dtype=jnp.bool_).at[ + jnp.arange(num_tokens)[:, None], top_indices + ].set(True) + 
return routing_map, scores + + +def reference_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + topk: int, + num_experts: int, + moe_aux_loss_coeff: float, +): + """Reference implementation for MoE auxiliary loss.""" + aggregated_probs_per_expert = probs.sum(axis=0) + aux_loss = jnp.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens) + ) + return aux_loss + + +# ============================================================================= +# Helper: logits generation +# ============================================================================= + + +def make_logits(num_tokens, num_experts, score_function, dtype=jnp.float32): + """Create deterministic logits for testing.""" + if score_function == "sigmoid": + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + logits = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + logits = logits[None, :].repeat(num_tokens, axis=0) + offset[:, None] + else: + logits = ( + jnp.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + dtype=dtype, + ) + * 1e-4 + ) + logits = logits.reshape(num_tokens, num_experts) + return logits + + +# ============================================================================= +# Test: Fused Top-K with Score Function +# ============================================================================= + + +def run_topk_comparison( + dtype, + num_tokens, + num_experts, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + enable_bias, +): + """Compare fused vs reference top-k implementation, both jitted.""" + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + if enable_bias and score_function == "sigmoid": + expert_bias = jnp.arange(num_experts, dtype=jnp.float32) * 0.1 + expert_bias = jnp.flip(expert_bias) + else: + expert_bias = None + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit(partial( + reference_topk_softmax_sigmoid, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + )) + probs_ref, routing_map_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit(partial( + fused_topk_with_score_function, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups if num_groups else -1, + group_topk=group_topk if group_topk else -1, + scaling_factor=scaling_factor if scaling_factor else 1.0, + score_function=score_function, + expert_bias=expert_bias, + )) + probs_fused, routing_map_fused = fused_fwd_fn(logits) + + assert jnp.allclose(probs_ref, probs_fused, atol=1e-5, rtol=1e-5), \ + f"Probs mismatch: max diff = {jnp.abs(probs_ref - probs_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward: reference (jitted) + def loss_ref(logits_): + p, _ = reference_topk_softmax_sigmoid( + logits_, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, + ) + return p.sum() + + def loss_fused(logits_): + p, _ = fused_topk_with_score_function( + logits_, topk, use_pre_softmax, + num_groups if num_groups else -1, + group_topk if group_topk else -1, + scaling_factor if scaling_factor else 1.0, + score_function, expert_bias, + ) + return p.sum() + + grad_ref = jax.jit(jax.grad(loss_ref))(logits) + grad_fused = 
jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ + f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", TOPK_CASES, +) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS) +def test_topk_sigmoid( + dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sigmoid", + enable_bias=enable_bias, + ) + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", TOPK_CASES, +) +@pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +def test_topk_softmax( + dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="softmax", + enable_bias=False, + ) + + +# ============================================================================= +# Test: Fused Score for MoE Aux Loss +# ============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", SCORE_AUX_LOSS_CASES, +) +@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit(partial( + reference_compute_scores_for_aux_loss, + topk=topk, + score_function=score_function, + )) + routing_map_ref, scores_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit(partial( + fused_compute_score_for_moe_aux_loss, + topk=topk, + score_function=score_function, + )) + routing_map_fused, scores_fused = fused_fwd_fn(logits) + + assert jnp.allclose(scores_ref, scores_fused, atol=1e-5, rtol=1e-5), \ + f"Scores mismatch: max diff = {jnp.abs(scores_ref - scores_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward (jitted) + def loss_ref(logits_): + _, s = reference_compute_scores_for_aux_loss(logits_, topk, score_function) + return s.sum() + + def loss_fused(logits_): + _, s = fused_compute_score_for_moe_aux_loss(logits_, topk, score_function) + return s.sum() + + grad_ref = jax.jit(jax.grad(loss_ref))(logits) + grad_fused = jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ + f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +# ============================================================================= +# Test: Fused MoE Aux Loss +# 
============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", AUX_LOSS_CASES, +) +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): + global key + key, subkey1 = jax.random.split(key) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + probs = probs.reshape(num_tokens, num_experts) + + tokens_per_expert = jax.random.randint(subkey1, (num_experts,), 1, 1000).astype(jnp.int32) + coeff = 0.01 + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit(partial( + reference_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + )) + aux_loss_ref = ref_fwd_fn(probs) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit(partial( + fused_moe_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + )) + aux_loss_fused = fused_fwd_fn(probs) + + assert jnp.allclose(aux_loss_ref, aux_loss_fused, atol=1e-5, rtol=1e-5), \ + f"Aux loss mismatch: ref={aux_loss_ref}, fused={aux_loss_fused}" + + # Backward (jitted) + def loss_ref_fn(probs_): + return reference_aux_loss(probs_, tokens_per_expert, num_tokens, topk, num_experts, coeff) + + def loss_fused_fn(probs_): + return fused_moe_aux_loss(probs_, tokens_per_expert, num_tokens, num_experts, topk, coeff) + + grad_ref = jax.jit(jax.grad(loss_ref_fn))(probs) + grad_fused = jax.jit(jax.grad(loss_fused_fn))(probs) + assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ + f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +if __name__ == "__main__": + test_topk_softmax( + dtype=jnp.float32, + num_tokens=128, + num_experts=32, + topk=4, + use_pre_softmax=False, + group_topk=None, + scaling_factor=None, + ) + print("All tests passed!") diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..f73af31d74 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -608,8 +608,12 @@ class TensorWrapper { * \param[in] scale_inv_dptr Pointer to the inverse of scale value. * \param[in] scaling_mode Tensor data format. 
*/ - TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, - float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, + TensorWrapper(void *dptr, + const NVTEShape &shape, + const DType dtype, + float *amax_dptr = nullptr, + float *scale_dptr = nullptr, + float *scale_inv_dptr = nullptr, NVTEShape scale_inv_shape = defaultShape, const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { tensor_ = nvte_create_tensor(scaling_mode); diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 6a2f9b7378..d203fcea9d 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -9,3 +9,4 @@ from .quantization import * from .softmax import * from .gemm import * +from .router import * diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py new file mode 100644 index 0000000000..fbf2a04285 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -0,0 +1,891 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for fused MoE router""" +import warnings +from functools import partial + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import PartitionSpec, NamedSharding + +from .base import BasePrimitive, register_primitive +from .misc import get_padded_spec + +__all__ = [ + "fused_topk_with_score_function_fwd", + "fused_topk_with_score_function_bwd", + "fused_score_for_moe_aux_loss_fwd", + "fused_score_for_moe_aux_loss_bwd", + "fused_moe_aux_loss_fwd", + "fused_moe_aux_loss_bwd", +] + +SCORE_FUNCTION_MAP = {"sigmoid": 0, "softmax": 1} + + +# =========================================== ================================== +# Fused Top-K with Score Function - Forward +# ============================================================================= + +class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Forward Primitive. + Computes score_function(logits) -> top-k -> probs, routing_map. 
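+    Outputs probs [num_tokens, num_experts] in the logits dtype, a boolean
+    routing_map of the same shape, and an intermediate buffer that is saved
+    for the backward primitive.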
+ """ + + name = "te_fused_topk_with_score_function_forward_ffi" + multiple_results = True + impl_static_args = (2, 3, 4, 5, 6, 7) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + logits_aval, + expert_bias_aval, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + ): + """Abstract evaluation: describe output shapes and dtypes.""" + del topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + i_shape = logits_aval.shape + assert len(i_shape) == 2, f"logits must be 2D [num_tokens, num_experts], got {i_shape}" + probs_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) + intermediate_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + return probs_aval, routing_map_aval, intermediate_aval + + @staticmethod + def lowering( + ctx, + logits, + expert_bias, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)( + ctx, + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + ) + + @staticmethod + def impl( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + ): + assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None + return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind( + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + ): + assert FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is not None + logits, expert_bias = batched_args + logits_bdim, _ = batch_dims + return ( + FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + ), + (logits_bdim, logits_bdim, logits_bdim), + ) + + @staticmethod + def infer_sharding_from_operands( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + mesh, + arg_infos, + result_infos, + ): + del ( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + result_infos, + ) + logits_spec = get_padded_spec(arg_infos[0]) + if logits_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedTopkWithScoreFunctionFwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim, which might introduce extra " + "collective ops and hurt performance." 
+ ) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) + return out_sharding, out_sharding, out_sharding + + @staticmethod + def partition( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + mesh, + arg_infos, + result_infos, + ): + del result_infos + logits_spec = get_padded_spec(arg_infos[0]) + if logits_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedTopkWithScoreFunctionFwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim, which might introduce extra " + "collective ops and hurt performance." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) + logits_sharding = out_sharding + bias_sharding = NamedSharding(mesh, PartitionSpec(None)) + arg_shardings = (logits_sharding, bias_sharding) + impl = partial( + FusedTopkWithScoreFunctionFwdPrimitive.impl, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + ) + return mesh, impl, (out_sharding, out_sharding, out_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_experts -> num_tokens num_experts, num_tokens num_experts, num_tokens num_experts" + + +register_primitive(FusedTopkWithScoreFunctionFwdPrimitive) + + +# ============================================================================= +# Fused Top-K with Score Function - Backward +# ============================================================================= + + +class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Backward Primitive. + """ + + name = "te_fused_topk_with_score_function_backward_ffi" + multiple_results = False + impl_static_args = (3, 4, 5, 6) # topk, use_pre_softmax, scaling_factor, score_function + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + routing_map_aval, + intermediate_aval, + grad_probs_aval, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ): + del topk, use_pre_softmax, scaling_factor, score_function, routing_map_aval + return intermediate_aval.update( + shape=intermediate_aval.shape, + dtype=dtypes.canonicalize_dtype(grad_probs_aval.dtype), + ) + + @staticmethod + def lowering( + ctx, + routing_map, + intermediate, + grad_probs, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)( + ctx, + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + ) + + @staticmethod + def impl( + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ): + assert FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is not None + return FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ): + assert FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is not None + routing_map, intermediate, grad_probs = batched_args + _, _, grad_probs_bdim = batch_dims + return ( + 
FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + ), + grad_probs_bdim, + ) + + @staticmethod + def infer_sharding_from_operands( + topk, use_pre_softmax, scaling_factor, score_function, mesh, arg_infos, result_infos + ): + del topk, use_pre_softmax, scaling_factor, score_function, result_infos + grad_spec = get_padded_spec(arg_infos[2]) + if grad_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedTopkWithScoreFunctionBwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + return NamedSharding(mesh, PartitionSpec(*grad_spec[:-1], None)) + + @staticmethod + def partition( + topk, use_pre_softmax, scaling_factor, score_function, mesh, arg_infos, result_infos + ): + del result_infos + grad_spec = get_padded_spec(arg_infos[2]) + if grad_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedTopkWithScoreFunctionBwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec[:-1], None)) + arg_shardings = tuple( + NamedSharding(mesh, PartitionSpec(*get_padded_spec(a)[:-1], None)) + for a in arg_infos + ) + impl = partial( + FusedTopkWithScoreFunctionBwdPrimitive.impl, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + ) + return mesh, impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens num_experts" + + +register_primitive(FusedTopkWithScoreFunctionBwdPrimitive) + + +# ============================================================================= +# Fused Score for MoE Aux Loss - Forward +# ============================================================================= + + +class FusedScoreForMoEAuxLossFwdPrimitive(BasePrimitive): + """ + Fused Score for MoE Aux Loss Forward Primitive. 
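+    Computes score_function(logits) and returns the scores, a boolean top-k
+    routing map, and an intermediate buffer consumed by the backward
+    primitive, all shaped [num_tokens, num_experts].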
+ """ + + name = "te_fused_score_for_moe_aux_loss_forward_ffi" + multiple_results = True + impl_static_args = (1, 2) # topk, score_function + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(logits_aval, topk, score_function): + del topk, score_function + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + i_shape = logits_aval.shape + scores_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) + intermediate_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + return scores_aval, routing_map_aval, intermediate_aval + + @staticmethod + def lowering(ctx, logits, *, topk, score_function): + return ffi.ffi_lowering(FusedScoreForMoEAuxLossFwdPrimitive.name)( + ctx, logits, topk=topk, score_function=score_function + ) + + @staticmethod + def impl(logits, topk, score_function): + assert FusedScoreForMoEAuxLossFwdPrimitive.inner_primitive is not None + return FusedScoreForMoEAuxLossFwdPrimitive.inner_primitive.bind( + logits, topk=topk, score_function=score_function + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, topk, score_function): + assert FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive is not None + (logits,) = batched_args + (logits_bdim,) = batch_dims + return ( + FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive.bind( + logits, topk=topk, score_function=score_function + ), + (logits_bdim, logits_bdim, logits_bdim), + ) + + @staticmethod + def infer_sharding_from_operands(topk, score_function, mesh, arg_infos, result_infos): + del topk, score_function, result_infos + logits_spec = get_padded_spec(arg_infos[0]) + if logits_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedScoreForMoEAuxLossFwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) + return out_sharding, out_sharding, out_sharding + + @staticmethod + def partition(topk, score_function, mesh, arg_infos, result_infos): + del result_infos + logits_spec = get_padded_spec(arg_infos[0]) + if logits_spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedScoreForMoEAuxLossFwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) + arg_shardings = (out_sharding,) + impl = partial( + FusedScoreForMoEAuxLossFwdPrimitive.impl, topk=topk, score_function=score_function + ) + return mesh, impl, (out_sharding, out_sharding, out_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts -> num_tokens num_experts, num_tokens num_experts, num_tokens num_experts" + + +register_primitive(FusedScoreForMoEAuxLossFwdPrimitive) + + +# ============================================================================= +# Fused Score for MoE Aux Loss - Backward +# ============================================================================= + + +class FusedScoreForMoEAuxLossBwdPrimitive(BasePrimitive): + """ + Fused Score for MoE Aux Loss Backward Primitive. 
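+    Consumes the saved intermediate buffer and grad_scores and returns
+    grad_logits with the same [num_tokens, num_experts] shape as grad_scores.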
+ """ + + name = "te_fused_score_for_moe_aux_loss_backward_ffi" + multiple_results = False + impl_static_args = (2, 3) # topk, score_function + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(intermediate_aval, grad_scores_aval, topk, score_function): + del topk, score_function, intermediate_aval + return grad_scores_aval.update( + shape=grad_scores_aval.shape, + dtype=dtypes.canonicalize_dtype(grad_scores_aval.dtype), + ) + + @staticmethod + def lowering(ctx, intermediate, grad_scores, *, topk, score_function): + return ffi.ffi_lowering(FusedScoreForMoEAuxLossBwdPrimitive.name)( + ctx, intermediate, grad_scores, topk=topk, score_function=score_function + ) + + @staticmethod + def impl(intermediate, grad_scores, topk, score_function): + assert FusedScoreForMoEAuxLossBwdPrimitive.inner_primitive is not None + return FusedScoreForMoEAuxLossBwdPrimitive.inner_primitive.bind( + intermediate, grad_scores, topk=topk, score_function=score_function + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, topk, score_function): + assert FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive is not None + intermediate, grad_scores = batched_args + _, grad_scores_bdim = batch_dims + return ( + FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive.bind( + intermediate, grad_scores, topk=topk, score_function=score_function + ), + grad_scores_bdim, + ) + + @staticmethod + def infer_sharding_from_operands(topk, score_function, mesh, arg_infos, result_infos): + del topk, score_function, result_infos + spec = get_padded_spec(arg_infos[1]) + if spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedScoreForMoEAuxLossBwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + return NamedSharding(mesh, PartitionSpec(*spec[:-1], None)) + + @staticmethod + def partition(topk, score_function, mesh, arg_infos, result_infos): + del result_infos + spec = get_padded_spec(arg_infos[1]) + if spec[-1] is not None: + warnings.warn( + f"Sharding the expert dimension is not supported in " + f"{FusedScoreForMoEAuxLossBwdPrimitive.name}! " + "Forcing XLA to not shard the expert dim." + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*spec[:-1], None)) + arg_shardings = tuple( + NamedSharding(mesh, PartitionSpec(*get_padded_spec(a)[:-1], None)) + for a in arg_infos + ) + impl = partial( + FusedScoreForMoEAuxLossBwdPrimitive.impl, topk=topk, score_function=score_function + ) + return mesh, impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_tokens num_experts -> num_tokens num_experts" + + +register_primitive(FusedScoreForMoEAuxLossBwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Forward +# ============================================================================= + + +class FusedMoEAuxLossFwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Forward Primitive. 
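+    Reduces probs [num_tokens, num_experts] and tokens_per_expert
+    [num_experts] to a scalar aux loss of shape (1,); the tests check it
+    against the reference
+    sum(probs.sum(0) * tokens_per_expert) * num_experts * coeff
+    / (topk * total_num_tokens**2).
+    Also emits a (1,) float32 const buffer reused by the backward primitive.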
+ """ + + name = "te_fused_moe_aux_loss_forward_ffi" + multiple_results = True + impl_static_args = (2, 3, 4, 5) # total_num_tokens, num_experts, topk, coeff + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(probs_aval, tokens_per_expert_aval, total_num_tokens, num_experts, topk, coeff): + del total_num_tokens, num_experts, topk, coeff, tokens_per_expert_aval + i_dtype = dtypes.canonicalize_dtype(probs_aval.dtype) + aux_loss_aval = probs_aval.update(shape=(1,), dtype=i_dtype) + const_buf_aval = probs_aval.update(shape=(1,), dtype=jnp.float32) + return aux_loss_aval, const_buf_aval + + @staticmethod + def lowering(ctx, probs, tokens_per_expert, *, total_num_tokens, num_experts, topk, coeff): + return ffi.ffi_lowering(FusedMoEAuxLossFwdPrimitive.name)( + ctx, + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def impl(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff): + assert FusedMoEAuxLossFwdPrimitive.inner_primitive is not None + return FusedMoEAuxLossFwdPrimitive.inner_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def batcher( + batched_args, batch_dims, *, total_num_tokens, num_experts, topk, coeff + ): + assert FusedMoEAuxLossFwdPrimitive.outer_primitive is not None + probs, tokens_per_expert = batched_args + probs_bdim, _ = batch_dims + return ( + FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ), + (probs_bdim, probs_bdim), + ) + + @staticmethod + def infer_sharding_from_operands( + total_num_tokens, num_experts, topk, coeff, mesh, arg_infos, result_infos + ): + del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos + replicated = NamedSharding(mesh, PartitionSpec()) + return replicated, replicated + + @staticmethod + def partition( + total_num_tokens, num_experts, topk, coeff, mesh, arg_infos, result_infos + ): + del result_infos, arg_infos + replicated = NamedSharding(mesh, PartitionSpec()) + # Global reduction: all inputs must be replicated + arg_shardings = (replicated, replicated) + impl = partial( + FusedMoEAuxLossFwdPrimitive.impl, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + return mesh, impl, (replicated, replicated), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_experts -> , " + + +register_primitive(FusedMoEAuxLossFwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Backward +# ============================================================================= + + +class FusedMoEAuxLossBwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Backward Primitive. 
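+    Expands the saved const buffer, tokens_per_expert, and the scalar
+    grad_aux_loss back into grad_probs of shape [num_rows, num_cols].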
+ """ + + name = "te_fused_moe_aux_loss_backward_ffi" + multiple_results = False + impl_static_args = (3, 4) # num_rows, num_cols + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(const_buf_aval, tokens_per_expert_aval, grad_aux_loss_aval, num_rows, num_cols): + del const_buf_aval, tokens_per_expert_aval + out_dtype = dtypes.canonicalize_dtype(grad_aux_loss_aval.dtype) + return grad_aux_loss_aval.update( + shape=(num_rows, num_cols), + dtype=out_dtype, + ) + + @staticmethod + def lowering(ctx, const_buf, tokens_per_expert, grad_aux_loss, *, num_rows, num_cols): + return ffi.ffi_lowering(FusedMoEAuxLossBwdPrimitive.name)( + ctx, + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows=num_rows, + num_cols=num_cols, + ) + + @staticmethod + def impl(const_buf, tokens_per_expert, grad_aux_loss, num_rows, num_cols): + assert FusedMoEAuxLossBwdPrimitive.inner_primitive is not None + return FusedMoEAuxLossBwdPrimitive.inner_primitive.bind( + const_buf, tokens_per_expert, grad_aux_loss, num_rows=num_rows, num_cols=num_cols + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, num_rows, num_cols): + assert FusedMoEAuxLossBwdPrimitive.outer_primitive is not None + const_buf, tokens_per_expert, grad_aux_loss = batched_args + _, _, grad_bdim = batch_dims + return ( + FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, tokens_per_expert, grad_aux_loss, num_rows=num_rows, num_cols=num_cols + ), + grad_bdim, + ) + + @staticmethod + def infer_sharding_from_operands(num_rows, num_cols, mesh, arg_infos, result_infos): + del num_rows, num_cols, result_infos, arg_infos + # Output is [num_rows, num_cols]; cannot infer token sharding from + # scalar/1D inputs, so replicate by default. + return NamedSharding(mesh, PartitionSpec(None, None)) + + @staticmethod + def partition(num_rows, num_cols, mesh, arg_infos, result_infos): + del result_infos, arg_infos + # All inputs are scalars or 1D vectors — replicate them. + # Output is [num_rows, num_cols] — replicate (no token sharding info + # available from scalar inputs). + replicated = NamedSharding(mesh, PartitionSpec()) + out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + arg_shardings = (replicated, replicated, replicated) + impl = partial(FusedMoEAuxLossBwdPrimitive.impl, num_rows=num_rows, num_cols=num_cols) + return mesh, impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return ", , -> num_tokens num_experts" + + +register_primitive(FusedMoEAuxLossBwdPrimitive) + + +# ============================================================================= +# Public API functions +# ============================================================================= + + +def fused_topk_with_score_function_fwd( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function: str, + expert_bias: jnp.ndarray, +): + """ + Fused top-k with score function forward pass. + + Parameters + ---------- + logits : jnp.ndarray + [num_tokens, num_experts] logits from gating GEMM. + topk : int + Number of top experts to select. + use_pre_softmax : bool + If True, apply softmax before top-k. + num_groups : int + Number of groups for grouped top-k (-1 to disable). + group_topk : int + Top-k at group level (-1 to disable). + scaling_factor : float + Scaling factor for output probs. + score_function : str + "softmax" or "sigmoid". + expert_bias : jnp.ndarray + Expert bias (only used with sigmoid). 
Pass empty array if unused. + + Returns + ------- + probs, routing_map, intermediate_output + """ + score_fn_int = SCORE_FUNCTION_MAP[score_function] + return FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + expert_bias, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + num_groups=int(num_groups), + group_topk=int(group_topk), + scaling_factor=float(scaling_factor), + score_function=int(score_fn_int), + ) + + +def fused_topk_with_score_function_bwd( + routing_map: jnp.ndarray, + intermediate_output: jnp.ndarray, + grad_probs: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + scaling_factor: float, + score_function: str, +): + """ + Fused top-k with score function backward pass. + """ + score_fn_int = SCORE_FUNCTION_MAP[score_function] + return FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + intermediate_output, + grad_probs, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + scaling_factor=float(scaling_factor), + score_function=int(score_fn_int), + ) + + +def fused_score_for_moe_aux_loss_fwd( + logits: jnp.ndarray, + topk: int, + score_function: str, +): + """ + Fused compute scores for MoE aux loss forward pass. + + Returns + ------- + scores, routing_map, intermediate_output + """ + score_fn_int = SCORE_FUNCTION_MAP[score_function] + return FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive.bind( + logits, + topk=int(topk), + score_function=int(score_fn_int), + ) + + +def fused_score_for_moe_aux_loss_bwd( + intermediate_output: jnp.ndarray, + grad_scores: jnp.ndarray, + topk: int, + score_function: str, +): + """ + Fused compute scores for MoE aux loss backward pass. + """ + score_fn_int = SCORE_FUNCTION_MAP[score_function] + return FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive.bind( + intermediate_output, + grad_scores, + topk=int(topk), + score_function=int(score_fn_int), + ) + + +def fused_moe_aux_loss_fwd( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +): + """ + Fused MoE aux loss forward pass. + + Returns + ------- + aux_loss, const_buf + """ + return FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=int(total_num_tokens), + num_experts=int(num_experts), + topk=int(topk), + coeff=float(coeff), + ) + + +def fused_moe_aux_loss_bwd( + const_buf: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + grad_aux_loss: jnp.ndarray, + num_rows: int, + num_cols: int, +): + """ + Fused MoE aux loss backward pass. 
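+
+    Returns
+    -------
+    grad_probs of shape [num_rows, num_cols]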
+ """ + return FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows=int(num_rows), + num_cols=int(num_cols), + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3fd086e257..575bba8a44 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -149,6 +149,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); // CuBLAS helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); +// Router +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedScoreForMoEAuxLossForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedScoreForMoEAuxLossBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index bd4b8fe2c2..7120b89ece 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -81,6 +81,20 @@ pybind11::dict Registrations() { pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + // Router + dict["te_fused_topk_with_score_function_forward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler); + dict["te_fused_topk_with_score_function_backward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler); + dict["te_fused_score_for_moe_aux_loss_forward_ffi"] = + EncapsulateFFI(FusedScoreForMoEAuxLossForwardHandler); + dict["te_fused_score_for_moe_aux_loss_backward_ffi"] = + EncapsulateFFI(FusedScoreForMoEAuxLossBackwardHandler); + dict["te_fused_moe_aux_loss_forward_ffi"] = + EncapsulateFFI(FusedMoEAuxLossForwardHandler); + dict["te_fused_moe_aux_loss_backward_ffi"] = + EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + return dict; } diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp new file mode 100644 index 0000000000..fb75745ae6 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -0,0 +1,337 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +// ============================================================================ +// Fused Top-K with Score Function - Forward +// ============================================================================ + +Error_Type FusedTopkWithScoreFunctionForwardFFI( + cudaStream_t stream, + Buffer_Type logits_buf, // [num_tokens, num_experts] + Buffer_Type expert_bias_buf, // [num_experts] or empty + Result_Type probs_buf, // [num_tokens, num_experts] + Result_Type routing_map_buf, // [num_tokens, num_experts] + Result_Type intermediate_buf, // [num_tokens, num_experts] + int64_t topk, + int64_t use_pre_softmax, + int64_t num_groups, + int64_t group_topk, + double scaling_factor, + int64_t score_function) { + auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); + auto dims = logits_buf.dimensions(); + auto num_tokens = static_cast(dims[0]); + auto num_experts = static_cast(dims[1]); + + auto *logits = logits_buf.untyped_data(); + auto *expert_bias = expert_bias_buf.untyped_data(); + auto *probs = probs_buf->untyped_data(); + auto *routing_map = routing_map_buf->untyped_data(); + auto *intermediate = intermediate_buf->untyped_data(); + + auto logits_shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; + auto logits_tensor = TensorWrapper(logits, logits_shape, dtype); + auto probs_tensor = TensorWrapper(probs, logits_shape, dtype); + auto routing_map_tensor = TensorWrapper(routing_map, logits_shape, DType::kByte); + auto intermediate_tensor = TensorWrapper(intermediate, logits_shape, dtype); + + // Expert bias: may be empty (dims will be 0) + auto bias_dims = expert_bias_buf.dimensions(); + auto expert_bias_tensor = + (bias_dims.size() > 0 && bias_dims[0] > 0) + ? 
TensorWrapper(expert_bias, + std::vector{static_cast(bias_dims[0])}, + convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) + : TensorWrapper(); + + nvte_fused_topk_with_score_function_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(use_pre_softmax), static_cast(num_groups), + static_cast(group_topk), static_cast(scaling_factor), + static_cast(score_function), expert_bias_tensor.data(), probs_tensor.data(), + routing_map_tensor.data(), intermediate_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedTopkWithScoreFunctionForwardHandler, FusedTopkWithScoreFunctionForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // logits + .Arg() // expert_bias + .Ret() // probs + .Ret() // routing_map + .Ret() // intermediate_output + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("num_groups") + .Attr("group_topk") + .Attr("scaling_factor") + .Attr("score_function"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused Top-K with Score Function - Backward +// ============================================================================ + +Error_Type FusedTopkWithScoreFunctionBackwardFFI( + cudaStream_t stream, + Buffer_Type routing_map_buf, // [num_tokens, num_experts] + Buffer_Type intermediate_buf, // [num_tokens, num_experts] + Buffer_Type grad_probs_buf, // [num_tokens, num_experts] + Result_Type grad_logits_buf, // [num_tokens, num_experts] + int64_t topk, + int64_t use_pre_softmax, + double scaling_factor, + int64_t score_function) { + auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); + auto dims = intermediate_buf.dimensions(); + auto num_tokens = static_cast(dims[0]); + auto num_experts = static_cast(dims[1]); + + auto shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; + + auto routing_map_tensor = + TensorWrapper(routing_map_buf.untyped_data(), shape, DType::kByte); + auto intermediate_tensor = + TensorWrapper(intermediate_buf.untyped_data(), shape, dtype); + auto grad_probs_tensor = + TensorWrapper(grad_probs_buf.untyped_data(), shape, dtype); + auto grad_logits_tensor = + TensorWrapper(grad_logits_buf->untyped_data(), shape, dtype); + + nvte_fused_topk_with_score_function_backward( + routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, + num_experts, static_cast(topk), static_cast(use_pre_softmax), + static_cast(scaling_factor), static_cast(score_function), + grad_logits_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedTopkWithScoreFunctionBackwardHandler, FusedTopkWithScoreFunctionBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // routing_map + .Arg() // intermediate_output + .Arg() // grad_probs + .Ret() // grad_logits + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("scaling_factor") + .Attr("score_function"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused Score for MoE Aux Loss - Forward +// ============================================================================ + +Error_Type FusedScoreForMoEAuxLossForwardFFI( + cudaStream_t stream, + Buffer_Type logits_buf, // [num_tokens, num_experts] + Result_Type scores_buf, // [num_tokens, num_experts] + Result_Type routing_map_buf, // [num_tokens, num_experts] + Result_Type intermediate_buf, // [num_tokens, num_experts] + int64_t topk, + 
int64_t score_function) { + auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); + auto dims = logits_buf.dimensions(); + auto num_tokens = static_cast(dims[0]); + auto num_experts = static_cast(dims[1]); + + auto shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; + + auto logits_tensor = + TensorWrapper(logits_buf.untyped_data(), shape, dtype); + auto scores_tensor = + TensorWrapper(scores_buf->untyped_data(), shape, dtype); + auto routing_map_tensor = + TensorWrapper(routing_map_buf->untyped_data(), shape, DType::kByte); + auto intermediate_tensor = + TensorWrapper(intermediate_buf->untyped_data(), shape, dtype); + + nvte_fused_score_for_moe_aux_loss_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(score_function), scores_tensor.data(), routing_map_tensor.data(), + intermediate_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedScoreForMoEAuxLossForwardHandler, FusedScoreForMoEAuxLossForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // logits + .Ret() // scores + .Ret() // routing_map + .Ret() // intermediate_output + .Attr("topk") + .Attr("score_function"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused Score for MoE Aux Loss - Backward +// ============================================================================ + +Error_Type FusedScoreForMoEAuxLossBackwardFFI( + cudaStream_t stream, + Buffer_Type intermediate_buf, // [num_tokens, num_experts] + Buffer_Type grad_scores_buf, // [num_tokens, num_experts] + Result_Type grad_logits_buf, // [num_tokens, num_experts] + int64_t topk, + int64_t score_function) { + auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); + auto dims = intermediate_buf.dimensions(); + auto num_tokens = static_cast(dims[0]); + auto num_experts = static_cast(dims[1]); + + auto shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; + + auto intermediate_tensor = + TensorWrapper(intermediate_buf.untyped_data(), shape, dtype); + auto grad_scores_tensor = + TensorWrapper(grad_scores_buf.untyped_data(), shape, dtype); + auto grad_logits_tensor = + TensorWrapper(grad_logits_buf->untyped_data(), shape, dtype); + + nvte_fused_score_for_moe_aux_loss_backward( + intermediate_tensor.data(), grad_scores_tensor.data(), num_tokens, num_experts, + static_cast(topk), static_cast(score_function), grad_logits_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedScoreForMoEAuxLossBackwardHandler, FusedScoreForMoEAuxLossBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // intermediate_output + .Arg() // grad_scores + .Ret() // grad_logits + .Attr("topk") + .Attr("score_function"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused MoE Aux Loss - Forward +// ============================================================================ + +Error_Type FusedMoEAuxLossForwardFFI( + cudaStream_t stream, + Buffer_Type probs_buf, // [num_rows, num_cols] + Buffer_Type tokens_per_expert_buf, // [num_experts] + Result_Type aux_loss_buf, // scalar + Result_Type const_buf, // scalar + int64_t total_num_tokens, + int64_t num_experts, + int64_t topk, + double coeff) { + auto dtype = convert_ffi_datatype_to_te_dtype(probs_buf.element_type()); + auto probs_dims = probs_buf.dimensions(); + auto num_rows = 
static_cast(probs_dims[0]); + auto num_cols = static_cast(probs_dims[1]); + + auto probs_shape = std::vector{static_cast(num_rows), + static_cast(num_cols)}; + auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); + auto tpe_shape = std::vector{static_cast(num_experts)}; + auto scalar_shape = std::vector{1}; + + auto probs_tensor = TensorWrapper(probs_buf.untyped_data(), probs_shape, dtype); + auto tpe_tensor = + TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); + auto const_buf_tensor = + TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); + + nvte_fused_moe_aux_loss_forward( + probs_tensor.data(), tpe_tensor.data(), static_cast(total_num_tokens), + static_cast(num_experts), num_rows, num_cols, static_cast(topk), + static_cast(coeff), aux_loss_tensor.data(), const_buf_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedMoEAuxLossForwardHandler, FusedMoEAuxLossForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // probs + .Arg() // tokens_per_expert + .Ret() // aux_loss + .Ret() // Const_buf + .Attr("total_num_tokens") + .Attr("num_experts") + .Attr("topk") + .Attr("coeff"), + FFI_CudaGraph_Traits); + +// ============================================================================ +// Fused MoE Aux Loss - Backward +// ============================================================================ + +Error_Type FusedMoEAuxLossBackwardFFI( + cudaStream_t stream, + Buffer_Type const_buf_in, // scalar float32 + Buffer_Type tokens_per_expert_buf, // [num_experts] + Buffer_Type grad_aux_loss_buf, // scalar + Result_Type grad_probs_buf, // [num_rows, num_cols] + int64_t num_rows, + int64_t num_cols) { + auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_aux_loss_buf.element_type()); + auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); + + auto scalar_shape = std::vector{1}; + auto tpe_dims = tokens_per_expert_buf.dimensions(); + auto tpe_shape = std::vector{static_cast(tpe_dims[0])}; + auto grad_probs_shape = std::vector{static_cast(num_rows), + static_cast(num_cols)}; + + auto const_buf_tensor = + TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32); + auto tpe_tensor = + TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto grad_aux_loss_tensor = + TensorWrapper(grad_aux_loss_buf.untyped_data(), scalar_shape, grad_dtype); + auto grad_probs_tensor = + TensorWrapper(grad_probs_buf->untyped_data(), grad_probs_shape, grad_dtype); + + nvte_fused_moe_aux_loss_backward( + const_buf_tensor.data(), tpe_tensor.data(), static_cast(num_rows), + static_cast(num_cols), grad_aux_loss_tensor.data(), grad_probs_tensor.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + FusedMoEAuxLossBackwardHandler, FusedMoEAuxLossBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // Const_buf + .Arg() // tokens_per_expert + .Arg() // grad_aux_loss + .Ret() // grad_probs + .Attr("num_rows") + .Attr("num_cols"), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py new file mode 100644 index 0000000000..66c3cf47fd --- /dev/null +++ b/transformer_engine/jax/router.py @@ -0,0 +1,358 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# See LICENSE for license information. + +"""Fused MoE Router API for JAX. + +This module provides high-level fused router operations for Mixture of Experts (MoE) +models with proper automatic differentiation support. These wrap the CUDA kernels in +transformer_engine/common/fused_router/. + +Functions: + fused_topk_with_score_function: + Fused score_function + top-k selection. Supports softmax/sigmoid, + grouped top-k, expert bias, and scaling factor. + + fused_compute_score_for_moe_aux_loss: + Compute clean scores and routing map for the auxiliary load-balancing loss. + + fused_moe_aux_loss: + Compute the MoE auxiliary load-balancing loss scalar. +""" + +from functools import partial +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.cpp_extensions.router import ( + fused_topk_with_score_function_fwd, + fused_topk_with_score_function_bwd, + fused_score_for_moe_aux_loss_fwd, + fused_score_for_moe_aux_loss_bwd, + fused_moe_aux_loss_fwd, + fused_moe_aux_loss_bwd, +) + +__all__ = [ + "fused_topk_with_score_function", + "fused_compute_score_for_moe_aux_loss", + "fused_moe_aux_loss", +] + + +# ============================================================================= +# Fused Top-K with Score Function +# ============================================================================= + + +def fused_topk_with_score_function( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool = False, + num_groups: int = -1, + group_topk: int = -1, + scaling_factor: float = 1.0, + score_function: str = "softmax", + expert_bias: Optional[jnp.ndarray] = None, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Fused top-k with score function router. + + Parameters + ---------- + logits : jnp.ndarray + Logits from the gating GEMM, shape [num_tokens, num_experts]. + topk : int + Number of top experts to select per token. + use_pre_softmax : bool + If True, apply softmax before top-k (only for softmax score function). Else, apply post top-k + num_groups : int + Number of groups for grouped top-k. -1 to disable. + group_topk : int + Top-k at group level. -1 to disable. + scaling_factor : float + Scaling factor applied to output probs. + score_function : str + Score function: "softmax" or "sigmoid". + expert_bias : Optional[jnp.ndarray] + Expert bias, shape [num_experts]. Only used with sigmoid. + + Returns + ------- + probs : jnp.ndarray + Sparse probability tensor, shape [num_tokens, num_experts]. + Non-zero only at selected expert positions. + routing_map : jnp.ndarray + Boolean mask, shape [num_tokens, num_experts]. + True at selected expert positions. + """ + if score_function not in ("softmax", "sigmoid"): + raise ValueError( + f"score_function must be 'softmax' or 'sigmoid', got '{score_function}'" + ) + + if expert_bias is not None and score_function != "sigmoid": + raise ValueError( + "expert_bias is only supported with score_function='sigmoid'. " + f"Got score_function='{score_function}'." 
+ ) + + # Flatten to 2D if shape is [B, S, H] + original_shape = logits.shape + if logits.ndim > 2: + logits = logits.reshape(-1, original_shape[-1]) + + if expert_bias is None: + expert_bias = jnp.empty((0,), dtype=logits.dtype) + + probs, routing_map = _fused_topk_with_score_function( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + ) + + # Restore shape if needed + if len(original_shape) > 2: + probs = probs.reshape(original_shape) + routing_map = routing_map.reshape(original_shape) + + return probs, routing_map + + +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7)) +def _fused_topk_with_score_function( + logits: jnp.ndarray, + expert_bias: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function: str, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + (probs, routing_map), _ = _fused_topk_with_score_function_fwd( + logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, + ) + return probs, routing_map + + +def _fused_topk_with_score_function_fwd( + logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, +): + probs, routing_map, intermediate_output = fused_topk_with_score_function_fwd( + logits, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, + ) + residuals = (routing_map, intermediate_output) + return (probs, routing_map), residuals + + +def _fused_topk_with_score_function_bwd( + topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + residuals, g, +): + routing_map, intermediate_output = residuals + grad_probs, _ = g # routing_map gradient is None (boolean) + + grad_logits = fused_topk_with_score_function_bwd( + routing_map, intermediate_output, grad_probs, + topk, use_pre_softmax, scaling_factor, score_function, + ) + # Return gradients for (logits, expert_bias) + return grad_logits, None + + +_fused_topk_with_score_function.defvjp( + _fused_topk_with_score_function_fwd, + _fused_topk_with_score_function_bwd, +) + + +# ============================================================================= +# Fused Score for MoE Aux Loss +# ============================================================================= + + +def fused_compute_score_for_moe_aux_loss( + logits: jnp.ndarray, + topk: int, + score_function: str = "softmax", +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute scores and routing map for MoE auxiliary loss. + + This uses clean softmax/sigmoid + plain top-k (no group constraints, + no expert bias, no scaling) to produce the scores and routing map + used for the load-balancing auxiliary loss. + + Parameters + ---------- + logits : jnp.ndarray + Logits from the gating GEMM, shape [num_tokens, num_experts]. + topk : int + Number of top experts to select. + score_function : str + Score function: "softmax" or "sigmoid". + + Returns + ------- + routing_map : jnp.ndarray + Boolean mask, shape [num_tokens, num_experts]. + scores : jnp.ndarray + Dense score tensor, shape [num_tokens, num_experts]. 
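+
+    Example
+    -------
+    A minimal sketch; topk=4 is an illustrative value and logits stands for the
+    gating output::
+
+        # logits: [num_tokens, num_experts] output of the gating GEMM
+        routing_map, scores = fused_compute_score_for_moe_aux_loss(
+            logits, topk=4, score_function="softmax"
+        )
+        # routing_map: boolean mask of the selected experts per token
+        # scores: dense scores intended for the auxiliary load-balancing loss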
+ """ + if score_function not in ("softmax", "sigmoid"): + raise ValueError( + f"score_function must be 'softmax' or 'sigmoid', got '{score_function}'" + ) + + original_shape = logits.shape + if logits.ndim > 2: + logits = logits.reshape(-1, original_shape[-1]) + + routing_map, scores = _fused_compute_score_for_moe_aux_loss(logits, topk, score_function) + + if len(original_shape) > 2: + routing_map = routing_map.reshape(original_shape) + scores = scores.reshape(original_shape) + + return routing_map, scores + + +@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +def _fused_compute_score_for_moe_aux_loss( + logits: jnp.ndarray, + topk: int, + score_function: str, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + (routing_map, scores), _ = _fused_compute_score_for_moe_aux_loss_fwd( + logits, topk, score_function, + ) + return routing_map, scores + + +def _fused_compute_score_for_moe_aux_loss_fwd(logits, topk, score_function): + scores, routing_map, intermediate_output = fused_score_for_moe_aux_loss_fwd( + logits, topk, score_function, + ) + residuals = (intermediate_output,) + return (routing_map, scores), residuals + + +def _fused_compute_score_for_moe_aux_loss_bwd(topk, score_function, residuals, g): + (intermediate_output,) = residuals + _, grad_scores = g # routing_map gradient is None (boolean) + + grad_logits = fused_score_for_moe_aux_loss_bwd( + intermediate_output, grad_scores, topk, score_function, + ) + return (grad_logits,) + + +_fused_compute_score_for_moe_aux_loss.defvjp( + _fused_compute_score_for_moe_aux_loss_fwd, + _fused_compute_score_for_moe_aux_loss_bwd, +) + + +# ============================================================================= +# Fused MoE Aux Loss +# ============================================================================= + + +def fused_moe_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +) -> jnp.ndarray: + """ + Compute the MoE auxiliary load-balancing loss. + + loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens_per_expert[i]) + + Parameters + ---------- + probs : jnp.ndarray + Probability/score tensor, shape [num_tokens, num_experts]. + tokens_per_expert : jnp.ndarray + Token counts per expert, shape [num_experts]. Integer tensor. + total_num_tokens : int + Total token count for normalization. + num_experts : int + Number of experts. + topk : int + Top-k value. + coeff : float + Loss coefficient. + + Returns + ------- + aux_loss : jnp.ndarray + Scalar loss value. 
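+
+    Example
+    -------
+    A minimal sketch; topk=4 and coeff=0.01 are illustrative values, and how
+    tokens_per_expert is obtained is left to the caller::
+
+        # scores: [num_tokens, num_experts], tokens_per_expert: [num_experts] int32
+        aux_loss = fused_moe_aux_loss(
+            scores,
+            tokens_per_expert,
+            total_num_tokens=num_tokens,
+            num_experts=num_experts,
+            topk=4,
+            coeff=0.01,
+        )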
+ """ + return _fused_moe_aux_loss( + probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5)) +def _fused_moe_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +) -> jnp.ndarray: + (aux_loss,), _ = _fused_moe_aux_loss_fwd( + probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + ) + # Squeeze from shape (1,) to scalar + return aux_loss.squeeze() + + +def _fused_moe_aux_loss_fwd( + probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, +): + num_rows = probs.shape[0] + num_cols = probs.shape[1] + aux_loss, const_buf = fused_moe_aux_loss_fwd( + probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + ) + residuals = (const_buf, tokens_per_expert, num_rows, num_cols) + return (aux_loss,), residuals + + +def _fused_moe_aux_loss_bwd( + total_num_tokens, num_experts, topk, coeff, residuals, g, +): + const_buf, tokens_per_expert, num_rows, num_cols = residuals + # g is a tuple matching the output of fwd; the squeeze means g is a scalar + (grad_aux_loss,) = g + # Ensure grad_aux_loss has shape (1,) for the C kernel + grad_aux_loss = grad_aux_loss.reshape(1) + + grad_probs = fused_moe_aux_loss_bwd( + const_buf, tokens_per_expert, grad_aux_loss, num_rows, num_cols, + ) + # Return gradients for (probs, tokens_per_expert) + # tokens_per_expert is integer, no gradient + return grad_probs, None + + +_fused_moe_aux_loss.defvjp( + _fused_moe_aux_loss_fwd, + _fused_moe_aux_loss_bwd, +) From 9c0e8846ca09b6ea3ad0acebd38353258143b896 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 26 Feb 2026 10:26:39 -0800 Subject: [PATCH 2/9] tests should pass Signed-off-by: tdophung --- tests/jax/test_distributed_router.py | 140 ++++++++++++------ tests/jax/test_fused_router.py | 21 +-- .../jax/cpp_extensions/router.py | 17 +-- transformer_engine/jax/router.py | 15 +- 4 files changed, 114 insertions(+), 79 deletions(-) diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 1d1d060711..a0ee83873a 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -201,17 +201,57 @@ def test_distributed_topk_gspmd( mesh_resource, score_function, ): - self._impl_test( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - num_tokens=128, - num_experts=32, - topk=4, - score_function=score_function, - use_shardy=False, - ) + """GSPMD test using value_and_grad with explicit shardings. + + GSPMD (non-shardy) requires explicit in/out shardings on jax.jit + to correctly partition custom ops, matching the compare_ops pattern + used by other TE distributed tests (softmax, permutation). 
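+
+        A sketch of that pattern as used below (loss_fn stands in for the
+        target/reference losses defined in the test body)::
+
+            jitted = jax.jit(
+                jax.value_and_grad(loss_fn),
+                in_shardings=(logits_sharding,),
+                out_shardings=(None, logits_sharding),
+            )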
+ """ + num_tokens, num_experts, topk = 128, 32, 4 + jax.config.update("jax_use_shardy_partitioner", False) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + def target_loss(x): + p, _ = fused_topk_with_score_function( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(p) + + def ref_loss(x): + p, _ = reference_topk_softmax_sigmoid( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(p) + + target_vg = jax.jit( + jax.value_and_grad(target_loss), + in_shardings=(logits_sharding,), + out_shardings=(None, logits_sharding), + ) + ref_vg = jax.jit( + jax.value_and_grad(ref_loss), + in_shardings=(logits_sharding,), + out_shardings=(None, logits_sharding), + ) + target_fwd, target_grad = target_vg(logits_sharded) + ref_fwd, ref_grad = ref_vg(logits_sharded) + + assert_allclose(target_fwd, ref_fwd, dtype=jnp.float32) + assert_allclose( + jax.device_get(target_grad), + jax.device_get(ref_grad), + dtype=jnp.float32, + ) class TestDistributedScoreForAuxLoss: @@ -346,17 +386,52 @@ def test_distributed_score_for_aux_loss_gspmd( mesh_resource, score_function, ): - self._impl_test( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - num_tokens=128, - num_experts=32, - topk=4, - score_function=score_function, - use_shardy=False, - ) + """GSPMD test using value_and_grad with explicit shardings.""" + num_tokens, num_experts, topk = 128, 32, 4 + jax.config.update("jax_use_shardy_partitioner", False) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + def target_loss(x): + _, s = fused_compute_score_for_moe_aux_loss( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(s) + + def ref_loss(x): + _, s = reference_compute_scores_for_aux_loss( + x, topk=topk, score_function=score_function, + ) + return jnp.sum(s) + + target_vg = jax.jit( + jax.value_and_grad(target_loss), + in_shardings=(logits_sharding,), + out_shardings=(None, logits_sharding), + ) + ref_vg = jax.jit( + jax.value_and_grad(ref_loss), + in_shardings=(logits_sharding,), + out_shardings=(None, logits_sharding), + ) + target_fwd, target_grad = target_vg(logits_sharded) + ref_fwd, ref_grad = ref_vg(logits_sharded) + + assert_allclose(target_fwd, ref_fwd, dtype=jnp.float32) + assert_allclose( + jax.device_get(target_grad), + jax.device_get(ref_grad), + dtype=jnp.float32, + ) class TestDistributedMoEAuxLoss: @@ -454,7 +529,7 @@ def ref_loss_fn(p): @pytest_parametrize_wrapper( "num_tokens,num_experts,topk", AUX_LOSS_CASES, ) - @pytest.mark.parametrize("use_shardy", [True]) + @pytest.mark.parametrize("use_shardy", [True, False]) def test_distributed_aux_loss( self, device_count, @@ -476,22 +551,3 @@ def test_distributed_aux_loss( topk, use_shardy, ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - def test_distributed_aux_loss_gspmd( - self, - device_count, - mesh_shape, - 
mesh_axes, - mesh_resource, - ): - self._impl_test( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - num_tokens=128, - num_experts=32, - topk=4, - use_shardy=False, - ) diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index 2a4aa4672f..b39b979a96 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -98,8 +98,7 @@ "L2": ALL_DTYPES, } -seed = 42 -key = jax.random.PRNGKey(seed) +SEED = 42 # ============================================================================= @@ -459,15 +458,14 @@ def loss_fused(logits_): "num_tokens,num_experts,topk", AUX_LOSS_CASES, ) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): - global key - key, subkey1 = jax.random.split(key) + key = jax.random.PRNGKey(SEED) offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] probs = probs.reshape(num_tokens, num_experts) - tokens_per_expert = jax.random.randint(subkey1, (num_experts,), 1, 1000).astype(jnp.int32) + tokens_per_expert = jax.random.randint(key, (num_experts,), 1, 1000).astype(jnp.int32) coeff = 0.01 # Forward: reference (jitted) @@ -506,16 +504,3 @@ def loss_fused_fn(probs_): grad_fused = jax.jit(jax.grad(loss_fused_fn))(probs) assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" - - -if __name__ == "__main__": - test_topk_softmax( - dtype=jnp.float32, - num_tokens=128, - num_experts=32, - topk=4, - use_pre_softmax=False, - group_topk=None, - scaling_factor=None, - ) - print("All tests passed!") diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index fbf2a04285..38d456dc28 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -169,7 +169,7 @@ def infer_sharding_from_operands( "collective ops and hurt performance." ) out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - return out_sharding, out_sharding, out_sharding + return [out_sharding, out_sharding, out_sharding] @staticmethod def partition( @@ -205,7 +205,7 @@ def partition( scaling_factor=scaling_factor, score_function=score_function, ) - return mesh, impl, (out_sharding, out_sharding, out_sharding), arg_shardings + return mesh, impl, [out_sharding, out_sharding, out_sharding], arg_shardings @staticmethod def shardy_sharding_rule(*args): @@ -429,7 +429,7 @@ def infer_sharding_from_operands(topk, score_function, mesh, arg_infos, result_i "Forcing XLA to not shard the expert dim." 
) out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - return out_sharding, out_sharding, out_sharding + return [out_sharding, out_sharding, out_sharding] @staticmethod def partition(topk, score_function, mesh, arg_infos, result_infos): @@ -446,7 +446,7 @@ def partition(topk, score_function, mesh, arg_infos, result_infos): impl = partial( FusedScoreForMoEAuxLossFwdPrimitive.impl, topk=topk, score_function=score_function ) - return mesh, impl, (out_sharding, out_sharding, out_sharding), arg_shardings + return mesh, impl, [out_sharding, out_sharding, out_sharding], arg_shardings @staticmethod def shardy_sharding_rule(*args): @@ -620,7 +620,7 @@ def infer_sharding_from_operands( ): del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos replicated = NamedSharding(mesh, PartitionSpec()) - return replicated, replicated + return [replicated, replicated] @staticmethod def partition( @@ -628,7 +628,6 @@ def partition( ): del result_infos, arg_infos replicated = NamedSharding(mesh, PartitionSpec()) - # Global reduction: all inputs must be replicated arg_shardings = (replicated, replicated) impl = partial( FusedMoEAuxLossFwdPrimitive.impl, @@ -637,12 +636,12 @@ def partition( topk=topk, coeff=coeff, ) - return mesh, impl, (replicated, replicated), arg_shardings + return mesh, impl, [replicated, replicated], arg_shardings @staticmethod def shardy_sharding_rule(*args): del args - return "num_tokens num_experts, num_experts -> , " + return "num_tokens num_experts, num_experts -> aux_loss_one, const_buf_one" register_primitive(FusedMoEAuxLossFwdPrimitive) @@ -725,7 +724,7 @@ def partition(num_rows, num_cols, mesh, arg_infos, result_infos): @staticmethod def shardy_sharding_rule(*args): del args - return ", , -> num_tokens num_experts" + return "const_buf_one, num_experts, grad_one -> num_tokens num_experts" register_primitive(FusedMoEAuxLossBwdPrimitive) diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 66c3cf47fd..2813e29257 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -316,11 +316,10 @@ def _fused_moe_aux_loss( topk: int, coeff: float, ) -> jnp.ndarray: - (aux_loss,), _ = _fused_moe_aux_loss_fwd( + aux_loss, _ = _fused_moe_aux_loss_fwd( probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, ) - # Squeeze from shape (1,) to scalar - return aux_loss.squeeze() + return aux_loss def _fused_moe_aux_loss_fwd( @@ -332,23 +331,19 @@ def _fused_moe_aux_loss_fwd( probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, ) residuals = (const_buf, tokens_per_expert, num_rows, num_cols) - return (aux_loss,), residuals + return aux_loss.squeeze(), residuals def _fused_moe_aux_loss_bwd( total_num_tokens, num_experts, topk, coeff, residuals, g, ): const_buf, tokens_per_expert, num_rows, num_cols = residuals - # g is a tuple matching the output of fwd; the squeeze means g is a scalar - (grad_aux_loss,) = g - # Ensure grad_aux_loss has shape (1,) for the C kernel - grad_aux_loss = grad_aux_loss.reshape(1) + # g is a scalar matching the squeezed output of _fwd + grad_aux_loss = g.reshape(1) grad_probs = fused_moe_aux_loss_bwd( const_buf, tokens_per_expert, grad_aux_loss, num_rows, num_cols, ) - # Return gradients for (probs, tokens_per_expert) - # tokens_per_expert is integer, no gradient return grad_probs, None From 1fa388fedf5dc330d35f89cb7824d2298d41d039 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 26 Feb 2026 11:03:49 -0800 Subject: [PATCH 3/9] fix lint 
Signed-off-by: tdophung --- transformer_engine/jax/cpp_extensions/router.py | 3 +-- transformer_engine/jax/router.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 38d456dc28..e68a686ba9 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -5,7 +5,6 @@ import warnings from functools import partial -import jax import jax.numpy as jnp from jax import dtypes, ffi from jax.sharding import PartitionSpec, NamedSharding @@ -53,7 +52,7 @@ def abstract( score_function, ): """Abstract evaluation: describe output shapes and dtypes.""" - del topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function + del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) i_shape = logits_aval.shape assert len(i_shape) == 2, f"logits must be 2D [num_tokens, num_experts], got {i_shape}" diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 2813e29257..9d6a416cd1 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -157,7 +157,7 @@ def _fused_topk_with_score_function_fwd( def _fused_topk_with_score_function_bwd( - topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, # pylint: disable=unused-argument residuals, g, ): routing_map, intermediate_output = residuals @@ -335,7 +335,7 @@ def _fused_moe_aux_loss_fwd( def _fused_moe_aux_loss_bwd( - total_num_tokens, num_experts, topk, coeff, residuals, g, + total_num_tokens, num_experts, topk, coeff, residuals, g, # pylint: disable=unused-argument ): const_buf, tokens_per_expert, num_rows, num_cols = residuals # g is a scalar matching the squeezed output of _fwd From 5f1ea6e94fc7ef07010af4aac20c15b72a0f3972 Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 27 Feb 2026 17:17:29 -0800 Subject: [PATCH 4/9] Address comments, minus the returning scalar request to reduce squeeze op Signed-off-by: tdophung --- tests/jax/test_distributed_router.py | 140 +----- .../jax/cpp_extensions/router.py | 441 ++---------------- transformer_engine/jax/csrc/extensions.h | 2 - .../jax/csrc/extensions/pybind.cpp | 4 - .../jax/csrc/extensions/router.cpp | 221 ++++----- transformer_engine/jax/router.py | 142 +++--- 6 files changed, 190 insertions(+), 760 deletions(-) diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index a0ee83873a..7dc6062158 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -21,8 +21,7 @@ - All inputs and outputs are replicated (partition function forces this) - We verify the op works correctly under a mesh context -These tests exercise: partition, infer_sharding_from_operands, batcher, -and shardy_sharding_rule from the router primitives. +These tests exercise: batcher and shardy_sharding_rule from the router primitives. 
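+
+The token-dimension sharding pattern exercised here, in sketch form (dp_axis is
+the data-parallel mesh axis resolved inside each test)::
+
+    logits_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None))
+    logits_sharded = jax.device_put(logits, logits_sharding)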
""" import pytest @@ -86,9 +85,8 @@ def _impl_test( num_experts, topk, score_function, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) + jax.config.update("jax_use_shardy_partitioner", True) logits = make_logits(num_tokens, num_experts, score_function) @@ -166,7 +164,6 @@ def ref_chunk_loss(x_chunk): "num_tokens,num_experts,topk", TOPK_CASES, ) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) - @pytest.mark.parametrize("use_shardy", [True]) def test_distributed_topk( self, device_count, @@ -177,7 +174,6 @@ def test_distributed_topk( num_experts, topk, score_function, - use_shardy, ): self._impl_test( device_count, @@ -188,70 +184,8 @@ def test_distributed_topk( num_experts, topk, score_function, - use_shardy, ) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) - def test_distributed_topk_gspmd( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - score_function, - ): - """GSPMD test using value_and_grad with explicit shardings. - - GSPMD (non-shardy) requires explicit in/out shardings on jax.jit - to correctly partition custom ops, matching the compare_ops pattern - used by other TE distributed tests (softmax, permutation). - """ - num_tokens, num_experts, topk = 128, 32, 4 - jax.config.update("jax_use_shardy_partitioner", False) - - logits = make_logits(num_tokens, num_experts, score_function) - - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - dp_axis = mesh_resource.dp_resource - sharded_pspec = PartitionSpec(dp_axis, None) - - with mesh: - logits_sharding = NamedSharding(mesh, sharded_pspec) - logits_sharded = jax.device_put(logits, logits_sharding) - - def target_loss(x): - p, _ = fused_topk_with_score_function( - x, topk=topk, score_function=score_function, - ) - return jnp.sum(p) - - def ref_loss(x): - p, _ = reference_topk_softmax_sigmoid( - x, topk=topk, score_function=score_function, - ) - return jnp.sum(p) - - target_vg = jax.jit( - jax.value_and_grad(target_loss), - in_shardings=(logits_sharding,), - out_shardings=(None, logits_sharding), - ) - ref_vg = jax.jit( - jax.value_and_grad(ref_loss), - in_shardings=(logits_sharding,), - out_shardings=(None, logits_sharding), - ) - target_fwd, target_grad = target_vg(logits_sharded) - ref_fwd, ref_grad = ref_vg(logits_sharded) - - assert_allclose(target_fwd, ref_fwd, dtype=jnp.float32) - assert_allclose( - jax.device_get(target_grad), - jax.device_get(ref_grad), - dtype=jnp.float32, - ) class TestDistributedScoreForAuxLoss: @@ -271,9 +205,8 @@ def _impl_test( num_experts, topk, score_function, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) + jax.config.update("jax_use_shardy_partitioner", True) logits = make_logits(num_tokens, num_experts, score_function) @@ -351,7 +284,6 @@ def ref_chunk_loss(x_chunk): "num_tokens,num_experts,topk", TOPK_CASES, ) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) - @pytest.mark.parametrize("use_shardy", [True]) def test_distributed_score_for_aux_loss( self, device_count, @@ -362,7 +294,6 @@ def test_distributed_score_for_aux_loss( num_experts, topk, score_function, - use_shardy, ): self._impl_test( device_count, @@ -373,65 +304,8 @@ def test_distributed_score_for_aux_loss( num_experts, topk, score_function, - use_shardy, ) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", 
generate_configs()) - @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) - def test_distributed_score_for_aux_loss_gspmd( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - score_function, - ): - """GSPMD test using value_and_grad with explicit shardings.""" - num_tokens, num_experts, topk = 128, 32, 4 - jax.config.update("jax_use_shardy_partitioner", False) - - logits = make_logits(num_tokens, num_experts, score_function) - - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - dp_axis = mesh_resource.dp_resource - sharded_pspec = PartitionSpec(dp_axis, None) - - with mesh: - logits_sharding = NamedSharding(mesh, sharded_pspec) - logits_sharded = jax.device_put(logits, logits_sharding) - - def target_loss(x): - _, s = fused_compute_score_for_moe_aux_loss( - x, topk=topk, score_function=score_function, - ) - return jnp.sum(s) - - def ref_loss(x): - _, s = reference_compute_scores_for_aux_loss( - x, topk=topk, score_function=score_function, - ) - return jnp.sum(s) - - target_vg = jax.jit( - jax.value_and_grad(target_loss), - in_shardings=(logits_sharding,), - out_shardings=(None, logits_sharding), - ) - ref_vg = jax.jit( - jax.value_and_grad(ref_loss), - in_shardings=(logits_sharding,), - out_shardings=(None, logits_sharding), - ) - target_fwd, target_grad = target_vg(logits_sharded) - ref_fwd, ref_grad = ref_vg(logits_sharded) - - assert_allclose(target_fwd, ref_fwd, dtype=jnp.float32) - assert_allclose( - jax.device_get(target_grad), - jax.device_get(ref_grad), - dtype=jnp.float32, - ) class TestDistributedMoEAuxLoss: @@ -452,12 +326,11 @@ def _impl_test( num_tokens, num_experts, topk, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) + jax.config.update("jax_use_shardy_partitioner", True) key = jax.random.PRNGKey(42) - key, subkey1, subkey2 = jax.random.split(key, 3) + _, subkey1, _ = jax.random.split(key, 3) offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=jnp.float32) * 1e-4 probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=jnp.float32) * 1e-2 @@ -529,7 +402,6 @@ def ref_loss_fn(p): @pytest_parametrize_wrapper( "num_tokens,num_experts,topk", AUX_LOSS_CASES, ) - @pytest.mark.parametrize("use_shardy", [True, False]) def test_distributed_aux_loss( self, device_count, @@ -539,7 +411,6 @@ def test_distributed_aux_loss( num_tokens, num_experts, topk, - use_shardy, ): self._impl_test( device_count, @@ -549,5 +420,4 @@ def test_distributed_aux_loss( num_tokens, num_experts, topk, - use_shardy, ) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index e68a686ba9..3ebfe3d9a1 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -2,26 +2,26 @@ # # See LICENSE for license information. 
"""JAX/TE custom ops for fused MoE router""" -import warnings +from enum import IntEnum from functools import partial import jax.numpy as jnp from jax import dtypes, ffi -from jax.sharding import PartitionSpec, NamedSharding from .base import BasePrimitive, register_primitive -from .misc import get_padded_spec __all__ = [ + "ScoreFunction", "fused_topk_with_score_function_fwd", "fused_topk_with_score_function_bwd", - "fused_score_for_moe_aux_loss_fwd", - "fused_score_for_moe_aux_loss_bwd", "fused_moe_aux_loss_fwd", "fused_moe_aux_loss_bwd", ] -SCORE_FUNCTION_MAP = {"sigmoid": 0, "softmax": 1} + +class ScoreFunction(IntEnum): + SIGMOID = 0 + SOFTMAX = 1 # =========================================== ================================== @@ -32,11 +32,12 @@ class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): """ Fused Top-K with Score Function Forward Primitive. Computes score_function(logits) -> top-k -> probs, routing_map. + When compute_aux_scores=1, instead computes clean scores for aux loss. """ name = "te_fused_topk_with_score_function_forward_ffi" multiple_results = True - impl_static_args = (2, 3, 4, 5, 6, 7) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function + impl_static_args = (2, 3, 4, 5, 6, 7, 8) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores inner_primitive = None outer_primitive = None @@ -50,12 +51,13 @@ def abstract( group_topk, scaling_factor, score_function, + compute_aux_scores, ): """Abstract evaluation: describe output shapes and dtypes.""" - del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function + del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk + del scaling_factor, score_function, compute_aux_scores i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) i_shape = logits_aval.shape - assert len(i_shape) == 2, f"logits must be 2D [num_tokens, num_experts], got {i_shape}" probs_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) intermediate_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) @@ -73,6 +75,7 @@ def lowering( group_topk, scaling_factor, score_function, + compute_aux_scores, ): return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)( ctx, @@ -84,6 +87,7 @@ def lowering( group_topk=group_topk, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ) @staticmethod @@ -96,6 +100,7 @@ def impl( group_topk, scaling_factor, score_function, + compute_aux_scores, ): assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind( @@ -107,6 +112,7 @@ def impl( group_topk=group_topk, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ) @staticmethod @@ -120,6 +126,7 @@ def batcher( group_topk, scaling_factor, score_function, + compute_aux_scores, ): assert FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is not None logits, expert_bias = batched_args @@ -134,78 +141,11 @@ def batcher( group_topk=group_topk, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ), (logits_bdim, logits_bdim, logits_bdim), ) - @staticmethod - def infer_sharding_from_operands( - topk, - use_pre_softmax, - num_groups, - group_topk, - scaling_factor, - score_function, - mesh, - arg_infos, - result_infos, - ): - del ( - 
topk, - use_pre_softmax, - num_groups, - group_topk, - scaling_factor, - score_function, - result_infos, - ) - logits_spec = get_padded_spec(arg_infos[0]) - if logits_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedTopkWithScoreFunctionFwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim, which might introduce extra " - "collective ops and hurt performance." - ) - out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - return [out_sharding, out_sharding, out_sharding] - - @staticmethod - def partition( - topk, - use_pre_softmax, - num_groups, - group_topk, - scaling_factor, - score_function, - mesh, - arg_infos, - result_infos, - ): - del result_infos - logits_spec = get_padded_spec(arg_infos[0]) - if logits_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedTopkWithScoreFunctionFwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim, which might introduce extra " - "collective ops and hurt performance." - ) - out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - logits_sharding = out_sharding - bias_sharding = NamedSharding(mesh, PartitionSpec(None)) - arg_shardings = (logits_sharding, bias_sharding) - impl = partial( - FusedTopkWithScoreFunctionFwdPrimitive.impl, - topk=topk, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups, - group_topk=group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - ) - return mesh, impl, [out_sharding, out_sharding, out_sharding], arg_shardings - @staticmethod def shardy_sharding_rule(*args): del args @@ -223,11 +163,12 @@ def shardy_sharding_rule(*args): class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): """ Fused Top-K with Score Function Backward Primitive. + When compute_aux_scores=1, runs the score-for-aux-loss backward instead. 
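+    All array operands (routing_map, intermediate_output, grad_probs) and the
+    returned grad_logits share the forward pass's [num_tokens, num_experts] shape.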
""" name = "te_fused_topk_with_score_function_backward_ffi" multiple_results = False - impl_static_args = (3, 4, 5, 6) # topk, use_pre_softmax, scaling_factor, score_function + impl_static_args = (3, 4, 5, 6, 7) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores inner_primitive = None outer_primitive = None @@ -240,8 +181,10 @@ def abstract( use_pre_softmax, scaling_factor, score_function, + compute_aux_scores, ): - del topk, use_pre_softmax, scaling_factor, score_function, routing_map_aval + del topk, use_pre_softmax, scaling_factor, score_function + del compute_aux_scores, routing_map_aval return intermediate_aval.update( shape=intermediate_aval.shape, dtype=dtypes.canonicalize_dtype(grad_probs_aval.dtype), @@ -258,6 +201,7 @@ def lowering( use_pre_softmax, scaling_factor, score_function, + compute_aux_scores, ): return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)( ctx, @@ -268,6 +212,7 @@ def lowering( use_pre_softmax=use_pre_softmax, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ) @staticmethod @@ -279,6 +224,7 @@ def impl( use_pre_softmax, scaling_factor, score_function, + compute_aux_scores, ): assert FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is not None return FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive.bind( @@ -289,6 +235,7 @@ def impl( use_pre_softmax=use_pre_softmax, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ) @staticmethod @@ -300,6 +247,7 @@ def batcher( use_pre_softmax, scaling_factor, score_function, + compute_aux_scores, ): assert FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is not None routing_map, intermediate, grad_probs = batched_args @@ -313,50 +261,11 @@ def batcher( use_pre_softmax=use_pre_softmax, scaling_factor=scaling_factor, score_function=score_function, + compute_aux_scores=compute_aux_scores, ), grad_probs_bdim, ) - @staticmethod - def infer_sharding_from_operands( - topk, use_pre_softmax, scaling_factor, score_function, mesh, arg_infos, result_infos - ): - del topk, use_pre_softmax, scaling_factor, score_function, result_infos - grad_spec = get_padded_spec(arg_infos[2]) - if grad_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedTopkWithScoreFunctionBwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." - ) - return NamedSharding(mesh, PartitionSpec(*grad_spec[:-1], None)) - - @staticmethod - def partition( - topk, use_pre_softmax, scaling_factor, score_function, mesh, arg_infos, result_infos - ): - del result_infos - grad_spec = get_padded_spec(arg_infos[2]) - if grad_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedTopkWithScoreFunctionBwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." 
- ) - out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec[:-1], None)) - arg_shardings = tuple( - NamedSharding(mesh, PartitionSpec(*get_padded_spec(a)[:-1], None)) - for a in arg_infos - ) - impl = partial( - FusedTopkWithScoreFunctionBwdPrimitive.impl, - topk=topk, - use_pre_softmax=use_pre_softmax, - scaling_factor=scaling_factor, - score_function=score_function, - ) - return mesh, impl, out_sharding, arg_shardings - @staticmethod def shardy_sharding_rule(*args): del args @@ -366,186 +275,6 @@ def shardy_sharding_rule(*args): register_primitive(FusedTopkWithScoreFunctionBwdPrimitive) -# ============================================================================= -# Fused Score for MoE Aux Loss - Forward -# ============================================================================= - - -class FusedScoreForMoEAuxLossFwdPrimitive(BasePrimitive): - """ - Fused Score for MoE Aux Loss Forward Primitive. - """ - - name = "te_fused_score_for_moe_aux_loss_forward_ffi" - multiple_results = True - impl_static_args = (1, 2) # topk, score_function - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(logits_aval, topk, score_function): - del topk, score_function - i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) - i_shape = logits_aval.shape - scores_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) - routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) - intermediate_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) - return scores_aval, routing_map_aval, intermediate_aval - - @staticmethod - def lowering(ctx, logits, *, topk, score_function): - return ffi.ffi_lowering(FusedScoreForMoEAuxLossFwdPrimitive.name)( - ctx, logits, topk=topk, score_function=score_function - ) - - @staticmethod - def impl(logits, topk, score_function): - assert FusedScoreForMoEAuxLossFwdPrimitive.inner_primitive is not None - return FusedScoreForMoEAuxLossFwdPrimitive.inner_primitive.bind( - logits, topk=topk, score_function=score_function - ) - - @staticmethod - def batcher(batched_args, batch_dims, *, topk, score_function): - assert FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive is not None - (logits,) = batched_args - (logits_bdim,) = batch_dims - return ( - FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive.bind( - logits, topk=topk, score_function=score_function - ), - (logits_bdim, logits_bdim, logits_bdim), - ) - - @staticmethod - def infer_sharding_from_operands(topk, score_function, mesh, arg_infos, result_infos): - del topk, score_function, result_infos - logits_spec = get_padded_spec(arg_infos[0]) - if logits_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedScoreForMoEAuxLossFwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." - ) - out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - return [out_sharding, out_sharding, out_sharding] - - @staticmethod - def partition(topk, score_function, mesh, arg_infos, result_infos): - del result_infos - logits_spec = get_padded_spec(arg_infos[0]) - if logits_spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedScoreForMoEAuxLossFwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." 
- ) - out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None)) - arg_shardings = (out_sharding,) - impl = partial( - FusedScoreForMoEAuxLossFwdPrimitive.impl, topk=topk, score_function=score_function - ) - return mesh, impl, [out_sharding, out_sharding, out_sharding], arg_shardings - - @staticmethod - def shardy_sharding_rule(*args): - del args - return "num_tokens num_experts -> num_tokens num_experts, num_tokens num_experts, num_tokens num_experts" - - -register_primitive(FusedScoreForMoEAuxLossFwdPrimitive) - - -# ============================================================================= -# Fused Score for MoE Aux Loss - Backward -# ============================================================================= - - -class FusedScoreForMoEAuxLossBwdPrimitive(BasePrimitive): - """ - Fused Score for MoE Aux Loss Backward Primitive. - """ - - name = "te_fused_score_for_moe_aux_loss_backward_ffi" - multiple_results = False - impl_static_args = (2, 3) # topk, score_function - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(intermediate_aval, grad_scores_aval, topk, score_function): - del topk, score_function, intermediate_aval - return grad_scores_aval.update( - shape=grad_scores_aval.shape, - dtype=dtypes.canonicalize_dtype(grad_scores_aval.dtype), - ) - - @staticmethod - def lowering(ctx, intermediate, grad_scores, *, topk, score_function): - return ffi.ffi_lowering(FusedScoreForMoEAuxLossBwdPrimitive.name)( - ctx, intermediate, grad_scores, topk=topk, score_function=score_function - ) - - @staticmethod - def impl(intermediate, grad_scores, topk, score_function): - assert FusedScoreForMoEAuxLossBwdPrimitive.inner_primitive is not None - return FusedScoreForMoEAuxLossBwdPrimitive.inner_primitive.bind( - intermediate, grad_scores, topk=topk, score_function=score_function - ) - - @staticmethod - def batcher(batched_args, batch_dims, *, topk, score_function): - assert FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive is not None - intermediate, grad_scores = batched_args - _, grad_scores_bdim = batch_dims - return ( - FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive.bind( - intermediate, grad_scores, topk=topk, score_function=score_function - ), - grad_scores_bdim, - ) - - @staticmethod - def infer_sharding_from_operands(topk, score_function, mesh, arg_infos, result_infos): - del topk, score_function, result_infos - spec = get_padded_spec(arg_infos[1]) - if spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedScoreForMoEAuxLossBwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." - ) - return NamedSharding(mesh, PartitionSpec(*spec[:-1], None)) - - @staticmethod - def partition(topk, score_function, mesh, arg_infos, result_infos): - del result_infos - spec = get_padded_spec(arg_infos[1]) - if spec[-1] is not None: - warnings.warn( - f"Sharding the expert dimension is not supported in " - f"{FusedScoreForMoEAuxLossBwdPrimitive.name}! " - "Forcing XLA to not shard the expert dim." 
- ) - out_sharding = NamedSharding(mesh, PartitionSpec(*spec[:-1], None)) - arg_shardings = tuple( - NamedSharding(mesh, PartitionSpec(*get_padded_spec(a)[:-1], None)) - for a in arg_infos - ) - impl = partial( - FusedScoreForMoEAuxLossBwdPrimitive.impl, topk=topk, score_function=score_function - ) - return mesh, impl, out_sharding, arg_shardings - - @staticmethod - def shardy_sharding_rule(*args): - del args - return "num_tokens num_experts, num_tokens num_experts -> num_tokens num_experts" - - -register_primitive(FusedScoreForMoEAuxLossBwdPrimitive) - - # ============================================================================= # Fused MoE Aux Loss - Forward # ============================================================================= @@ -613,30 +342,6 @@ def batcher( (probs_bdim, probs_bdim), ) - @staticmethod - def infer_sharding_from_operands( - total_num_tokens, num_experts, topk, coeff, mesh, arg_infos, result_infos - ): - del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos - replicated = NamedSharding(mesh, PartitionSpec()) - return [replicated, replicated] - - @staticmethod - def partition( - total_num_tokens, num_experts, topk, coeff, mesh, arg_infos, result_infos - ): - del result_infos, arg_infos - replicated = NamedSharding(mesh, PartitionSpec()) - arg_shardings = (replicated, replicated) - impl = partial( - FusedMoEAuxLossFwdPrimitive.impl, - total_num_tokens=total_num_tokens, - num_experts=num_experts, - topk=topk, - coeff=coeff, - ) - return mesh, impl, [replicated, replicated], arg_shardings - @staticmethod def shardy_sharding_rule(*args): del args @@ -701,25 +406,6 @@ def batcher(batched_args, batch_dims, *, num_rows, num_cols): grad_bdim, ) - @staticmethod - def infer_sharding_from_operands(num_rows, num_cols, mesh, arg_infos, result_infos): - del num_rows, num_cols, result_infos, arg_infos - # Output is [num_rows, num_cols]; cannot infer token sharding from - # scalar/1D inputs, so replicate by default. - return NamedSharding(mesh, PartitionSpec(None, None)) - - @staticmethod - def partition(num_rows, num_cols, mesh, arg_infos, result_infos): - del result_infos, arg_infos - # All inputs are scalars or 1D vectors — replicate them. - # Output is [num_rows, num_cols] — replicate (no token sharding info - # available from scalar inputs). - replicated = NamedSharding(mesh, PartitionSpec()) - out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) - arg_shardings = (replicated, replicated, replicated) - impl = partial(FusedMoEAuxLossBwdPrimitive.impl, num_rows=num_rows, num_cols=num_cols) - return mesh, impl, out_sharding, arg_shardings - @staticmethod def shardy_sharding_rule(*args): del args @@ -741,12 +427,17 @@ def fused_topk_with_score_function_fwd( num_groups: int, group_topk: int, scaling_factor: float, - score_function: str, + score_function, expert_bias: jnp.ndarray, + compute_aux_scores: bool = False, ): """ Fused top-k with score function forward pass. + When compute_aux_scores=True, runs the clean score-for-aux-loss kernel + instead of the full top-k kernel (expert_bias, use_pre_softmax, num_groups, + group_topk, and scaling_factor are ignored). + Parameters ---------- logits : jnp.ndarray @@ -756,21 +447,22 @@ def fused_topk_with_score_function_fwd( use_pre_softmax : bool If True, apply softmax before top-k. num_groups : int - Number of groups for grouped top-k (-1 to disable). + Number of groups for grouped top-k (1 to disable). group_topk : int - Top-k at group level (-1 to disable). + Top-k at group level (1 to disable). 
scaling_factor : float Scaling factor for output probs. - score_function : str - "softmax" or "sigmoid". + score_function : ScoreFunction + ScoreFunction.SOFTMAX or ScoreFunction.SIGMOID. expert_bias : jnp.ndarray Expert bias (only used with sigmoid). Pass empty array if unused. + compute_aux_scores : bool + If True, compute clean scores for aux loss instead of full top-k. Returns ------- - probs, routing_map, intermediate_output + probs_or_scores, routing_map, intermediate_output """ - score_fn_int = SCORE_FUNCTION_MAP[score_function] return FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( logits, expert_bias, @@ -779,7 +471,8 @@ def fused_topk_with_score_function_fwd( num_groups=int(num_groups), group_topk=int(group_topk), scaling_factor=float(scaling_factor), - score_function=int(score_fn_int), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), ) @@ -790,12 +483,15 @@ def fused_topk_with_score_function_bwd( topk: int, use_pre_softmax: bool, scaling_factor: float, - score_function: str, + score_function, + compute_aux_scores: bool = False, ): """ Fused top-k with score function backward pass. + + When compute_aux_scores=True, routing_map is ignored and the + score-for-aux-loss backward kernel is used instead. """ - score_fn_int = SCORE_FUNCTION_MAP[score_function] return FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( routing_map, intermediate_output, @@ -803,45 +499,8 @@ def fused_topk_with_score_function_bwd( topk=int(topk), use_pre_softmax=int(use_pre_softmax), scaling_factor=float(scaling_factor), - score_function=int(score_fn_int), - ) - - -def fused_score_for_moe_aux_loss_fwd( - logits: jnp.ndarray, - topk: int, - score_function: str, -): - """ - Fused compute scores for MoE aux loss forward pass. - - Returns - ------- - scores, routing_map, intermediate_output - """ - score_fn_int = SCORE_FUNCTION_MAP[score_function] - return FusedScoreForMoEAuxLossFwdPrimitive.outer_primitive.bind( - logits, - topk=int(topk), - score_function=int(score_fn_int), - ) - - -def fused_score_for_moe_aux_loss_bwd( - intermediate_output: jnp.ndarray, - grad_scores: jnp.ndarray, - topk: int, - score_function: str, -): - """ - Fused compute scores for MoE aux loss backward pass. 
- """ - score_fn_int = SCORE_FUNCTION_MAP[score_function] - return FusedScoreForMoEAuxLossBwdPrimitive.outer_primitive.bind( - intermediate_output, - grad_scores, - topk=int(topk), - score_function=int(score_fn_int), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 575bba8a44..e10aae9f83 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -152,8 +152,6 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); // Router XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); -XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedScoreForMoEAuxLossForwardHandler); -XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedScoreForMoEAuxLossBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7120b89ece..c7200eab29 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -86,10 +86,6 @@ pybind11::dict Registrations() { EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler); dict["te_fused_topk_with_score_function_backward_ffi"] = EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler); - dict["te_fused_score_for_moe_aux_loss_forward_ffi"] = - EncapsulateFFI(FusedScoreForMoEAuxLossForwardHandler); - dict["te_fused_score_for_moe_aux_loss_backward_ffi"] = - EncapsulateFFI(FusedScoreForMoEAuxLossBackwardHandler); dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index fb75745ae6..ac3699ccab 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -12,6 +12,22 @@ namespace transformer_engine { namespace jax { +enum class ScoreFunction : int64_t { + kSigmoid = 0, + kSoftmax = 1, +}; + +// Compute num_tokens as the product of all dimensions except the last (num_experts). +// Supports arbitrary-rank inputs (e.g., [B, S, E] or [num_tokens, E]). 
+template +static int compute_num_tokens(const Dims &dims) { + int num_tokens = 1; + for (size_t i = 0; i + 1 < dims.size(); ++i) { + num_tokens *= static_cast(dims[i]); + } + return num_tokens; +} + // ============================================================================ // Fused Top-K with Score Function - Forward // ============================================================================ @@ -20,7 +36,7 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( cudaStream_t stream, Buffer_Type logits_buf, // [num_tokens, num_experts] Buffer_Type expert_bias_buf, // [num_experts] or empty - Result_Type probs_buf, // [num_tokens, num_experts] + Result_Type probs_buf, // [num_tokens, num_experts] (or scores when compute_aux_scores) Result_Type routing_map_buf, // [num_tokens, num_experts] Result_Type intermediate_buf, // [num_tokens, num_experts] int64_t topk, @@ -28,11 +44,12 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( int64_t num_groups, int64_t group_topk, double scaling_factor, - int64_t score_function) { + int64_t score_function, + int64_t compute_aux_scores) { auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); auto dims = logits_buf.dimensions(); - auto num_tokens = static_cast(dims[0]); - auto num_experts = static_cast(dims[1]); + auto num_tokens = compute_num_tokens(dims); + auto num_experts = static_cast(dims[dims.size() - 1]); auto *logits = logits_buf.untyped_data(); auto *expert_bias = expert_bias_buf.untyped_data(); @@ -40,28 +57,34 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( auto *routing_map = routing_map_buf->untyped_data(); auto *intermediate = intermediate_buf->untyped_data(); - auto logits_shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; - auto logits_tensor = TensorWrapper(logits, logits_shape, dtype); - auto probs_tensor = TensorWrapper(probs, logits_shape, dtype); - auto routing_map_tensor = TensorWrapper(routing_map, logits_shape, DType::kByte); - auto intermediate_tensor = TensorWrapper(intermediate, logits_shape, dtype); - - // Expert bias: may be empty (dims will be 0) - auto bias_dims = expert_bias_buf.dimensions(); - auto expert_bias_tensor = - (bias_dims.size() > 0 && bias_dims[0] > 0) - ? TensorWrapper(expert_bias, - std::vector{static_cast(bias_dims[0])}, - convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) - : TensorWrapper(); - - nvte_fused_topk_with_score_function_forward( - logits_tensor.data(), num_tokens, num_experts, static_cast(topk), - static_cast(use_pre_softmax), static_cast(num_groups), - static_cast(group_topk), static_cast(scaling_factor), - static_cast(score_function), expert_bias_tensor.data(), probs_tensor.data(), - routing_map_tensor.data(), intermediate_tensor.data(), stream); + auto flat_shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; + auto logits_tensor = TensorWrapper(logits, flat_shape, dtype); + auto probs_tensor = TensorWrapper(probs, flat_shape, dtype); + auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte); + auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, dtype); + + if (compute_aux_scores) { + nvte_fused_score_for_moe_aux_loss_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(score_function), probs_tensor.data(), routing_map_tensor.data(), + intermediate_tensor.data(), stream); + } else { + auto bias_dims = expert_bias_buf.dimensions(); + auto expert_bias_tensor = + (bias_dims.size() > 0 && bias_dims[0] > 0) + ? 
TensorWrapper(expert_bias, + std::vector{static_cast(bias_dims[0])}, + convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) + : TensorWrapper(); + + nvte_fused_topk_with_score_function_forward( + logits_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(use_pre_softmax), static_cast(num_groups), + static_cast(group_topk), static_cast(scaling_factor), + static_cast(score_function), expert_bias_tensor.data(), probs_tensor.data(), + routing_map_tensor.data(), intermediate_tensor.data(), stream); + } return ffi_with_cuda_error_check(); } @@ -72,7 +95,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ctx() // stream .Arg() // logits .Arg() // expert_bias - .Ret() // probs + .Ret() // probs (or scores) .Ret() // routing_map .Ret() // intermediate_output .Attr("topk") @@ -80,7 +103,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("num_groups") .Attr("group_topk") .Attr("scaling_factor") - .Attr("score_function"), + .Attr("score_function") + .Attr("compute_aux_scores"), FFI_CudaGraph_Traits); // ============================================================================ @@ -89,36 +113,45 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( Error_Type FusedTopkWithScoreFunctionBackwardFFI( cudaStream_t stream, - Buffer_Type routing_map_buf, // [num_tokens, num_experts] + Buffer_Type routing_map_buf, // [num_tokens, num_experts] (unused when compute_aux_scores) Buffer_Type intermediate_buf, // [num_tokens, num_experts] - Buffer_Type grad_probs_buf, // [num_tokens, num_experts] + Buffer_Type grad_probs_buf, // [num_tokens, num_experts] (grad_scores when compute_aux_scores) Result_Type grad_logits_buf, // [num_tokens, num_experts] int64_t topk, int64_t use_pre_softmax, double scaling_factor, - int64_t score_function) { + int64_t score_function, + int64_t compute_aux_scores) { auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); auto dims = intermediate_buf.dimensions(); - auto num_tokens = static_cast(dims[0]); - auto num_experts = static_cast(dims[1]); + auto num_tokens = compute_num_tokens(dims); + auto num_experts = static_cast(dims[dims.size() - 1]); - auto shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; + auto flat_shape = std::vector{static_cast(num_tokens), + static_cast(num_experts)}; - auto routing_map_tensor = - TensorWrapper(routing_map_buf.untyped_data(), shape, DType::kByte); auto intermediate_tensor = - TensorWrapper(intermediate_buf.untyped_data(), shape, dtype); + TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype); auto grad_probs_tensor = - TensorWrapper(grad_probs_buf.untyped_data(), shape, dtype); + TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype); auto grad_logits_tensor = - TensorWrapper(grad_logits_buf->untyped_data(), shape, dtype); - - nvte_fused_topk_with_score_function_backward( - routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, - num_experts, static_cast(topk), static_cast(use_pre_softmax), - static_cast(scaling_factor), static_cast(score_function), - grad_logits_tensor.data(), stream); + TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype); + + if (compute_aux_scores) { + nvte_fused_score_for_moe_aux_loss_backward( + intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, num_experts, + static_cast(topk), static_cast(score_function), grad_logits_tensor.data(), + stream); + } else { + auto routing_map_tensor = + TensorWrapper(routing_map_buf.untyped_data(), flat_shape, DType::kByte); + + 
nvte_fused_topk_with_score_function_backward( + routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), + num_tokens, num_experts, static_cast(topk), static_cast(use_pre_softmax), + static_cast(scaling_factor), static_cast(score_function), + grad_logits_tensor.data(), stream); + } return ffi_with_cuda_error_check(); } @@ -134,100 +167,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("topk") .Attr("use_pre_softmax") .Attr("scaling_factor") - .Attr("score_function"), - FFI_CudaGraph_Traits); - -// ============================================================================ -// Fused Score for MoE Aux Loss - Forward -// ============================================================================ - -Error_Type FusedScoreForMoEAuxLossForwardFFI( - cudaStream_t stream, - Buffer_Type logits_buf, // [num_tokens, num_experts] - Result_Type scores_buf, // [num_tokens, num_experts] - Result_Type routing_map_buf, // [num_tokens, num_experts] - Result_Type intermediate_buf, // [num_tokens, num_experts] - int64_t topk, - int64_t score_function) { - auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); - auto dims = logits_buf.dimensions(); - auto num_tokens = static_cast(dims[0]); - auto num_experts = static_cast(dims[1]); - - auto shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; - - auto logits_tensor = - TensorWrapper(logits_buf.untyped_data(), shape, dtype); - auto scores_tensor = - TensorWrapper(scores_buf->untyped_data(), shape, dtype); - auto routing_map_tensor = - TensorWrapper(routing_map_buf->untyped_data(), shape, DType::kByte); - auto intermediate_tensor = - TensorWrapper(intermediate_buf->untyped_data(), shape, dtype); - - nvte_fused_score_for_moe_aux_loss_forward( - logits_tensor.data(), num_tokens, num_experts, static_cast(topk), - static_cast(score_function), scores_tensor.data(), routing_map_tensor.data(), - intermediate_tensor.data(), stream); - - return ffi_with_cuda_error_check(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedScoreForMoEAuxLossForwardHandler, FusedScoreForMoEAuxLossForwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // logits - .Ret() // scores - .Ret() // routing_map - .Ret() // intermediate_output - .Attr("topk") - .Attr("score_function"), - FFI_CudaGraph_Traits); - -// ============================================================================ -// Fused Score for MoE Aux Loss - Backward -// ============================================================================ - -Error_Type FusedScoreForMoEAuxLossBackwardFFI( - cudaStream_t stream, - Buffer_Type intermediate_buf, // [num_tokens, num_experts] - Buffer_Type grad_scores_buf, // [num_tokens, num_experts] - Result_Type grad_logits_buf, // [num_tokens, num_experts] - int64_t topk, - int64_t score_function) { - auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); - auto dims = intermediate_buf.dimensions(); - auto num_tokens = static_cast(dims[0]); - auto num_experts = static_cast(dims[1]); - - auto shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; - - auto intermediate_tensor = - TensorWrapper(intermediate_buf.untyped_data(), shape, dtype); - auto grad_scores_tensor = - TensorWrapper(grad_scores_buf.untyped_data(), shape, dtype); - auto grad_logits_tensor = - TensorWrapper(grad_logits_buf->untyped_data(), shape, dtype); - - nvte_fused_score_for_moe_aux_loss_backward( - intermediate_tensor.data(), grad_scores_tensor.data(), num_tokens, num_experts, - static_cast(topk), static_cast(score_function), 
grad_logits_tensor.data(), stream); - - return ffi_with_cuda_error_check(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedScoreForMoEAuxLossBackwardHandler, FusedScoreForMoEAuxLossBackwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // intermediate_output - .Arg() // grad_scores - .Ret() // grad_logits - .Attr("topk") - .Attr("score_function"), + .Attr("score_function") + .Attr("compute_aux_scores"), FFI_CudaGraph_Traits); // ============================================================================ diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 9d6a416cd1..cf3a4a4097 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -21,27 +21,40 @@ """ from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import jax import jax.numpy as jnp from transformer_engine.jax.cpp_extensions.router import ( + ScoreFunction, fused_topk_with_score_function_fwd, fused_topk_with_score_function_bwd, - fused_score_for_moe_aux_loss_fwd, - fused_score_for_moe_aux_loss_bwd, fused_moe_aux_loss_fwd, fused_moe_aux_loss_bwd, ) __all__ = [ + "ScoreFunction", "fused_topk_with_score_function", "fused_compute_score_for_moe_aux_loss", "fused_moe_aux_loss", ] +def _validate_score_function(score_function: Union[str, ScoreFunction]) -> ScoreFunction: + """Validate and convert score_function to a ScoreFunction enum.""" + if isinstance(score_function, ScoreFunction): + return score_function + try: + return ScoreFunction[score_function.upper()] + except (KeyError, AttributeError): + raise ValueError( + f"score_function must be 'softmax', 'sigmoid', or a ScoreFunction enum, " + f"got {score_function!r}" + ) from None + + # ============================================================================= # Fused Top-K with Score Function # ============================================================================= @@ -51,10 +64,10 @@ def fused_topk_with_score_function( logits: jnp.ndarray, topk: int, use_pre_softmax: bool = False, - num_groups: int = -1, - group_topk: int = -1, + num_groups: int = 1, + group_topk: int = 1, scaling_factor: float = 1.0, - score_function: str = "softmax", + score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, expert_bias: Optional[jnp.ndarray] = None, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ @@ -69,13 +82,13 @@ def fused_topk_with_score_function( use_pre_softmax : bool If True, apply softmax before top-k (only for softmax score function). Else, apply post top-k num_groups : int - Number of groups for grouped top-k. -1 to disable. + Number of groups for grouped top-k. 1 means no grouping. group_topk : int - Top-k at group level. -1 to disable. + Top-k at group level. 1 means no group-level selection. scaling_factor : float Scaling factor applied to output probs. - score_function : str - Score function: "softmax" or "sigmoid". + score_function : Union[str, ScoreFunction] + Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID. expert_bias : Optional[jnp.ndarray] Expert bias, shape [num_experts]. Only used with sigmoid. @@ -88,22 +101,14 @@ def fused_topk_with_score_function( Boolean mask, shape [num_tokens, num_experts]. True at selected expert positions. 
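    Example
    -------
    A minimal, illustrative sketch (sizes and the RNG key are arbitrary; a CUDA
    device with this extension built is assumed). ``score_function`` may be the
    ``ScoreFunction`` enum or the equivalent string::

        import jax
        import jax.numpy as jnp
        from transformer_engine.jax.router import ScoreFunction, fused_topk_with_score_function

        # [num_tokens, num_experts] gating logits
        logits = jax.random.normal(jax.random.PRNGKey(0), (128, 32))

        probs, routing_map = fused_topk_with_score_function(
            logits,
            topk=4,
            score_function=ScoreFunction.SOFTMAX,  # or simply "softmax"
        )
        # probs: [128, 32], non-zero only at the selected experts
        # routing_map: [128, 32] bool, True at the top-4 experts of each token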
""" - if score_function not in ("softmax", "sigmoid"): - raise ValueError( - f"score_function must be 'softmax' or 'sigmoid', got '{score_function}'" - ) + score_function = _validate_score_function(score_function) - if expert_bias is not None and score_function != "sigmoid": + if expert_bias is not None and score_function != ScoreFunction.SIGMOID: raise ValueError( "expert_bias is only supported with score_function='sigmoid'. " - f"Got score_function='{score_function}'." + f"Got score_function='{score_function.name}'." ) - # Flatten to 2D if shape is [B, S, H] - original_shape = logits.shape - if logits.ndim > 2: - logits = logits.reshape(-1, original_shape[-1]) - if expert_bias is None: expert_bias = jnp.empty((0,), dtype=logits.dtype) @@ -116,17 +121,13 @@ def fused_topk_with_score_function( group_topk, scaling_factor, score_function, + False, # compute_aux_scores ) - # Restore shape if needed - if len(original_shape) > 2: - probs = probs.reshape(original_shape) - routing_map = routing_map.reshape(original_shape) - return probs, routing_map -@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7)) +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8)) def _fused_topk_with_score_function( logits: jnp.ndarray, expert_bias: jnp.ndarray, @@ -135,22 +136,23 @@ def _fused_topk_with_score_function( num_groups: int, group_topk: int, scaling_factor: float, - score_function: str, + score_function: ScoreFunction, + compute_aux_scores: bool, ) -> Tuple[jnp.ndarray, jnp.ndarray]: (probs, routing_map), _ = _fused_topk_with_score_function_fwd( logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, + scaling_factor, score_function, compute_aux_scores, ) return probs, routing_map def _fused_topk_with_score_function_fwd( logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, + scaling_factor, score_function, compute_aux_scores, ): probs, routing_map, intermediate_output = fused_topk_with_score_function_fwd( logits, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, + scaling_factor, score_function, expert_bias, compute_aux_scores, ) residuals = (routing_map, intermediate_output) return (probs, routing_map), residuals @@ -158,7 +160,7 @@ def _fused_topk_with_score_function_fwd( def _fused_topk_with_score_function_bwd( topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, # pylint: disable=unused-argument - residuals, g, + compute_aux_scores, residuals, g, ): routing_map, intermediate_output = residuals grad_probs, _ = g # routing_map gradient is None (boolean) @@ -166,8 +168,9 @@ def _fused_topk_with_score_function_bwd( grad_logits = fused_topk_with_score_function_bwd( routing_map, intermediate_output, grad_probs, topk, use_pre_softmax, scaling_factor, score_function, + compute_aux_scores, ) - # Return gradients for (logits, expert_bias) + # expert_bias gradient is None: bias is not differentiated through this kernel return grad_logits, None @@ -185,7 +188,7 @@ def _fused_topk_with_score_function_bwd( def fused_compute_score_for_moe_aux_loss( logits: jnp.ndarray, topk: int, - score_function: str = "softmax", + score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Compute scores and routing map for MoE auxiliary loss. 
@@ -194,14 +197,17 @@ def fused_compute_score_for_moe_aux_loss( no expert bias, no scaling) to produce the scores and routing map used for the load-balancing auxiliary loss. + Internally delegates to the same primitive as fused_topk_with_score_function + with compute_aux_scores=True, selecting the score-for-aux-loss CUDA kernel. + Parameters ---------- logits : jnp.ndarray Logits from the gating GEMM, shape [num_tokens, num_experts]. topk : int Number of top experts to select. - score_function : str - Score function: "softmax" or "sigmoid". + score_function : Union[str, ScoreFunction] + Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID. Returns ------- @@ -210,60 +216,22 @@ def fused_compute_score_for_moe_aux_loss( scores : jnp.ndarray Dense score tensor, shape [num_tokens, num_experts]. """ - if score_function not in ("softmax", "sigmoid"): - raise ValueError( - f"score_function must be 'softmax' or 'sigmoid', got '{score_function}'" - ) - - original_shape = logits.shape - if logits.ndim > 2: - logits = logits.reshape(-1, original_shape[-1]) - - routing_map, scores = _fused_compute_score_for_moe_aux_loss(logits, topk, score_function) - - if len(original_shape) > 2: - routing_map = routing_map.reshape(original_shape) - scores = scores.reshape(original_shape) - - return routing_map, scores - + score_function = _validate_score_function(score_function) -@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def _fused_compute_score_for_moe_aux_loss( - logits: jnp.ndarray, - topk: int, - score_function: str, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - (routing_map, scores), _ = _fused_compute_score_for_moe_aux_loss_fwd( - logits, topk, score_function, + scores, routing_map = _fused_topk_with_score_function( + logits, + jnp.empty((0,), dtype=logits.dtype), + topk, + False, # use_pre_softmax (unused for aux scores) + 1, # num_groups (unused for aux scores) + 1, # group_topk (unused for aux scores) + 1.0, # scaling_factor (unused for aux scores) + score_function, + True, # compute_aux_scores ) return routing_map, scores -def _fused_compute_score_for_moe_aux_loss_fwd(logits, topk, score_function): - scores, routing_map, intermediate_output = fused_score_for_moe_aux_loss_fwd( - logits, topk, score_function, - ) - residuals = (intermediate_output,) - return (routing_map, scores), residuals - - -def _fused_compute_score_for_moe_aux_loss_bwd(topk, score_function, residuals, g): - (intermediate_output,) = residuals - _, grad_scores = g # routing_map gradient is None (boolean) - - grad_logits = fused_score_for_moe_aux_loss_bwd( - intermediate_output, grad_scores, topk, score_function, - ) - return (grad_logits,) - - -_fused_compute_score_for_moe_aux_loss.defvjp( - _fused_compute_score_for_moe_aux_loss_fwd, - _fused_compute_score_for_moe_aux_loss_bwd, -) - - # ============================================================================= # Fused MoE Aux Loss # ============================================================================= @@ -325,12 +293,10 @@ def _fused_moe_aux_loss( def _fused_moe_aux_loss_fwd( probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, ): - num_rows = probs.shape[0] - num_cols = probs.shape[1] aux_loss, const_buf = fused_moe_aux_loss_fwd( probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, ) - residuals = (const_buf, tokens_per_expert, num_rows, num_cols) + residuals = (const_buf, tokens_per_expert, probs.shape[0], probs.shape[1]) return aux_loss.squeeze(), residuals From db5a1bcf444255f4bf599e6f4e11938e19d598c9 
Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 27 Feb 2026 17:42:02 -0800 Subject: [PATCH 5/9] add notImplemented infer_sharding_from_operands and partition back in to make basePrimitive class happy Signed-off-by: tdophung --- .../jax/cpp_extensions/router.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 3ebfe3d9a1..6d1b2fee38 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -146,6 +146,14 @@ def batcher( (logits_bdim, logits_bdim, logits_bdim), ) + @staticmethod + def infer_sharding_from_operands(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + + @staticmethod + def partition(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + @staticmethod def shardy_sharding_rule(*args): del args @@ -266,6 +274,14 @@ def batcher( grad_probs_bdim, ) + @staticmethod + def infer_sharding_from_operands(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + + @staticmethod + def partition(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + @staticmethod def shardy_sharding_rule(*args): del args @@ -342,6 +358,14 @@ def batcher( (probs_bdim, probs_bdim), ) + @staticmethod + def infer_sharding_from_operands(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + + @staticmethod + def partition(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + @staticmethod def shardy_sharding_rule(*args): del args @@ -406,6 +430,14 @@ def batcher(batched_args, batch_dims, *, num_rows, num_cols): grad_bdim, ) + @staticmethod + def infer_sharding_from_operands(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + + @staticmethod + def partition(*args, **kwargs): + raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + @staticmethod def shardy_sharding_rule(*args): del args From 8a87cafd36b4013ab3ea3dd92d47499d5309620c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Feb 2026 02:01:59 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_router.py | 119 ++++++--- tests/jax/test_fused_router.py | 212 +++++++++------- .../transformer_engine/transformer_engine.h | 8 +- .../jax/cpp_extensions/router.py | 33 ++- .../jax/csrc/extensions/pybind.cpp | 6 +- .../jax/csrc/extensions/router.cpp | 229 ++++++++---------- transformer_engine/jax/router.py | 104 ++++++-- 7 files changed, 418 insertions(+), 293 deletions(-) diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 7dc6062158..7ccc842be5 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -106,17 +106,21 @@ def _impl_test( @jax.jit def target_fwd(x): return fused_topk_with_score_function( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, ) target_probs, target_routing_map = target_fwd(logits_sharded) - logits_shards = 
jnp.reshape( - logits, (num_dp_devices, local_num_tokens, num_experts) + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_topk_softmax_sigmoid( + x, + topk=topk, + score_function=score_function, + ) ) - ref_fwd_fn = jax.jit(lambda x: reference_topk_softmax_sigmoid( - x, topk=topk, score_function=score_function, - )) ref_probs_list = [] ref_routing_list = [] for i in range(num_dp_devices): @@ -128,22 +132,29 @@ def target_fwd(x): ref_routing = jnp.concatenate(ref_routing_list, axis=0) assert_allclose( - jax.device_get(target_probs), ref_probs, dtype=jnp.float32, + jax.device_get(target_probs), + ref_probs, + dtype=jnp.float32, ) assert jnp.array_equal( - jax.device_get(target_routing_map), ref_routing, + jax.device_get(target_routing_map), + ref_routing, ), "Routing map mismatch in distributed fused_topk" # === Backward === def target_loss(x): p, _ = fused_topk_with_score_function( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, ) return jnp.sum(p) def ref_chunk_loss(x_chunk): p, _ = reference_topk_softmax_sigmoid( - x_chunk, topk=topk, score_function=score_function, + x_chunk, + topk=topk, + score_function=score_function, ) return jnp.sum(p) @@ -156,12 +167,15 @@ def ref_chunk_loss(x_chunk): ref_grad = jnp.concatenate(ref_grads, axis=0) assert_allclose( - jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", TOPK_CASES, + "num_tokens,num_experts,topk", + TOPK_CASES, ) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) def test_distributed_topk( @@ -187,7 +201,6 @@ def test_distributed_topk( ) - class TestDistributedScoreForAuxLoss: """Test distributed execution of fused_compute_score_for_moe_aux_loss. 
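As a reading aid for the reformatted distributed tests, a hedged sketch of the intended data-parallel call pattern: logits are sharded on the token dimension while the expert dimension stays replicated (the primitives refuse to shard the expert dim). The "dp" axis name is a placeholder and at least two local GPUs are assumed:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec
    from transformer_engine.jax.router import fused_compute_score_for_moe_aux_loss

    mesh = Mesh(np.asarray(jax.devices()), ("dp",))  # "dp" is a placeholder axis name
    logits = jax.random.normal(jax.random.PRNGKey(0), (2048, 128))

    with mesh:
        # Shard tokens across "dp"; keep the expert dimension replicated.
        logits = jax.device_put(logits, NamedSharding(mesh, PartitionSpec("dp", None)))
        routing_map, scores = jax.jit(
            lambda x: fused_compute_score_for_moe_aux_loss(x, topk=8, score_function="softmax")
        )(logits)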
@@ -226,17 +239,21 @@ def _impl_test( @jax.jit def target_fwd(x): return fused_compute_score_for_moe_aux_loss( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, ) target_routing_map, target_scores = target_fwd(logits_sharded) - logits_shards = jnp.reshape( - logits, (num_dp_devices, local_num_tokens, num_experts) + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_compute_scores_for_aux_loss( + x, + topk=topk, + score_function=score_function, + ) ) - ref_fwd_fn = jax.jit(lambda x: reference_compute_scores_for_aux_loss( - x, topk=topk, score_function=score_function, - )) ref_routing_list = [] ref_scores_list = [] for i in range(num_dp_devices): @@ -248,22 +265,29 @@ def target_fwd(x): ref_scores = jnp.concatenate(ref_scores_list, axis=0) assert_allclose( - jax.device_get(target_scores), ref_scores, dtype=jnp.float32, + jax.device_get(target_scores), + ref_scores, + dtype=jnp.float32, ) assert jnp.array_equal( - jax.device_get(target_routing_map), ref_routing, + jax.device_get(target_routing_map), + ref_routing, ), "Routing map mismatch in distributed score_for_aux_loss" # === Backward === def target_loss(x): _, s = fused_compute_score_for_moe_aux_loss( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, ) return jnp.sum(s) def ref_chunk_loss(x_chunk): _, s = reference_compute_scores_for_aux_loss( - x_chunk, topk=topk, score_function=score_function, + x_chunk, + topk=topk, + score_function=score_function, ) return jnp.sum(s) @@ -276,12 +300,15 @@ def ref_chunk_loss(x_chunk): ref_grad = jnp.concatenate(ref_grads, axis=0) assert_allclose( - jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", TOPK_CASES, + "num_tokens,num_experts,topk", + TOPK_CASES, ) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) def test_distributed_score_for_aux_loss( @@ -307,7 +334,6 @@ def test_distributed_score_for_aux_loss( ) - class TestDistributedMoEAuxLoss: """Test distributed execution of fused_moe_aux_loss. 
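And a minimal sketch of the auxiliary-loss op exercised by this class (values are arbitrary; the keyword names follow the call sites in this patch). The result is a scalar and is differentiable with respect to probs:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.router import fused_moe_aux_loss

    num_tokens, num_experts, topk = 2048, 128, 4
    probs = jax.nn.softmax(
        jax.random.normal(jax.random.PRNGKey(0), (num_tokens, num_experts)), axis=-1
    )
    tokens_per_expert = jnp.full((num_experts,), num_tokens * topk // num_experts, jnp.int32)

    def loss_fn(p):
        return fused_moe_aux_loss(
            p,
            tokens_per_expert,
            total_num_tokens=num_tokens,
            num_experts=num_experts,
            topk=topk,
            coeff=0.01,
        )

    aux_loss = loss_fn(probs)              # scalar load-balancing penalty
    grad_probs = jax.grad(loss_fn)(probs)  # same shape as probs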
@@ -336,9 +362,7 @@ def _impl_test( probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=jnp.float32) * 1e-2 probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] - tokens_per_expert = jax.random.randint( - subkey1, (num_experts,), 1, 1000 - ).astype(jnp.int32) + tokens_per_expert = jax.random.randint(subkey1, (num_experts,), 1, 1000).astype(jnp.int32) coeff = 0.01 devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) @@ -358,7 +382,8 @@ def _impl_test( @jax.jit def target_fwd(p, tpe): return fused_moe_aux_loss( - p, tpe, + p, + tpe, total_num_tokens=num_tokens, num_experts=num_experts, topk=topk, @@ -367,19 +392,29 @@ def target_fwd(p, tpe): target_loss = target_fwd(probs_dev, tpe_dev) - ref_fwd_fn = jax.jit(lambda p: reference_aux_loss( - p, tokens_per_expert, num_tokens, topk, num_experts, coeff, - )) + ref_fwd_fn = jax.jit( + lambda p: reference_aux_loss( + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, + ) + ) ref_loss = ref_fwd_fn(probs) assert_allclose( - jax.device_get(target_loss), ref_loss, dtype=jnp.float32, + jax.device_get(target_loss), + ref_loss, + dtype=jnp.float32, ) # === Backward === def target_loss_fn(p): return fused_moe_aux_loss( - p, tokens_per_expert, + p, + tokens_per_expert, total_num_tokens=num_tokens, num_experts=num_experts, topk=topk, @@ -388,19 +423,27 @@ def target_loss_fn(p): def ref_loss_fn(p): return reference_aux_loss( - p, tokens_per_expert, num_tokens, topk, num_experts, coeff, + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, ) target_grad = jax.jit(jax.grad(target_loss_fn))(probs_dev) ref_grad = jax.jit(jax.grad(ref_loss_fn))(probs) assert_allclose( - jax.device_get(target_grad), ref_grad, dtype=jnp.float32, + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", AUX_LOSS_CASES, + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, ) def test_distributed_aux_loss( self, diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index b39b979a96..f270fa3b80 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -123,12 +123,12 @@ def reference_group_limited_topk( "reference_group_limited_topk requires valid num_groups > 0. " "For plain top-k, use jax.lax.top_k directly." ) - assert group_topk is not None and group_topk > 0, ( - "reference_group_limited_topk requires valid group_topk > 0." - ) - assert num_experts % num_groups == 0, ( - f"num_experts ({num_experts}) must be divisible by num_groups ({num_groups})" - ) + assert ( + group_topk is not None and group_topk > 0 + ), "reference_group_limited_topk requires valid group_topk > 0." 
+ assert ( + num_experts % num_groups == 0 + ), f"num_experts ({num_experts}) must be divisible by num_groups ({num_groups})" group_size = num_experts // num_groups experts_per_group = topk // group_topk @@ -138,14 +138,11 @@ def reference_group_limited_topk( .sum(axis=-1) ) group_idx = jax.lax.top_k(group_scores, k=group_topk)[1] - group_mask = jnp.zeros_like(group_scores).at[ - jnp.arange(num_tokens)[:, None], group_idx - ].set(1) + group_mask = jnp.zeros_like(group_scores).at[jnp.arange(num_tokens)[:, None], group_idx].set(1) - score_mask = ( - group_mask[:, :, None] - * jnp.ones((num_tokens, num_groups, group_size)) - ).reshape(num_tokens, -1) + score_mask = (group_mask[:, :, None] * jnp.ones((num_tokens, num_groups, group_size))).reshape( + num_tokens, -1 + ) masked_scores = jnp.where(score_mask.astype(bool), scores, -jnp.inf) probs, top_indices = jax.lax.top_k(masked_scores, k=topk) @@ -200,19 +197,19 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if scaling_factor: probs = probs * scaling_factor - topk_masked_gates = jnp.zeros_like(logits).at[ - jnp.arange(num_tokens)[:, None], top_indices - ].set(probs) - topk_map = jnp.zeros_like(logits, dtype=jnp.bool_).at[ - jnp.arange(num_tokens)[:, None], top_indices - ].set(True) + topk_masked_gates = ( + jnp.zeros_like(logits).at[jnp.arange(num_tokens)[:, None], top_indices].set(probs) + ) + topk_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) return topk_masked_gates, topk_map -def reference_compute_scores_for_aux_loss( - logits: jnp.ndarray, topk: int, score_function: str -): +def reference_compute_scores_for_aux_loss(logits: jnp.ndarray, topk: int, score_function: str): """Reference implementation for computing routing scores for aux loss.""" if score_function == "softmax": scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) @@ -224,9 +221,11 @@ def reference_compute_scores_for_aux_loss( _, top_indices = jax.lax.top_k(scores, k=topk) num_tokens = logits.shape[0] - routing_map = jnp.zeros_like(logits, dtype=jnp.bool_).at[ - jnp.arange(num_tokens)[:, None], top_indices - ].set(True) + routing_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) return routing_map, scores @@ -297,62 +296,78 @@ def run_topk_comparison( expert_bias = None # Forward: reference (jitted) - ref_fwd_fn = jax.jit(partial( - reference_topk_softmax_sigmoid, - topk=topk, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups, - group_topk=group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=expert_bias, - )) + ref_fwd_fn = jax.jit( + partial( + reference_topk_softmax_sigmoid, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + ) probs_ref, routing_map_ref = ref_fwd_fn(logits) # Forward: fused (jitted) - fused_fwd_fn = jax.jit(partial( - fused_topk_with_score_function, - topk=topk, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups if num_groups else -1, - group_topk=group_topk if group_topk else -1, - scaling_factor=scaling_factor if scaling_factor else 1.0, - score_function=score_function, - expert_bias=expert_bias, - )) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups if num_groups else -1, + group_topk=group_topk if 
group_topk else -1, + scaling_factor=scaling_factor if scaling_factor else 1.0, + score_function=score_function, + expert_bias=expert_bias, + ) + ) probs_fused, routing_map_fused = fused_fwd_fn(logits) - assert jnp.allclose(probs_ref, probs_fused, atol=1e-5, rtol=1e-5), \ - f"Probs mismatch: max diff = {jnp.abs(probs_ref - probs_fused).max()}" + assert jnp.allclose( + probs_ref, probs_fused, atol=1e-5, rtol=1e-5 + ), f"Probs mismatch: max diff = {jnp.abs(probs_ref - probs_fused).max()}" assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" # Backward: reference (jitted) def loss_ref(logits_): p, _ = reference_topk_softmax_sigmoid( - logits_, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, + logits_, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, ) return p.sum() def loss_fused(logits_): p, _ = fused_topk_with_score_function( - logits_, topk, use_pre_softmax, + logits_, + topk, + use_pre_softmax, num_groups if num_groups else -1, group_topk if group_topk else -1, scaling_factor if scaling_factor else 1.0, - score_function, expert_bias, + score_function, + expert_bias, ) return p.sum() grad_ref = jax.jit(jax.grad(loss_ref))(logits) grad_fused = jax.jit(jax.grad(loss_fused))(logits) - assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ - f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", TOPK_CASES, + "num_tokens,num_experts,topk", + TOPK_CASES, ) @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) @@ -377,7 +392,8 @@ def test_topk_sigmoid( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", TOPK_CASES, + "num_tokens,num_experts,topk", + TOPK_CASES, ) @pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) @pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) @@ -407,30 +423,36 @@ def test_topk_softmax( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", SCORE_AUX_LOSS_CASES, + "num_tokens,num_experts,topk", + SCORE_AUX_LOSS_CASES, ) @pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): logits = make_logits(num_tokens, num_experts, score_function, dtype) # Forward: reference (jitted) - ref_fwd_fn = jax.jit(partial( - reference_compute_scores_for_aux_loss, - topk=topk, - score_function=score_function, - )) + ref_fwd_fn = jax.jit( + partial( + reference_compute_scores_for_aux_loss, + topk=topk, + score_function=score_function, + ) + ) routing_map_ref, scores_ref = ref_fwd_fn(logits) # Forward: fused (jitted) - fused_fwd_fn = jax.jit(partial( - fused_compute_score_for_moe_aux_loss, - topk=topk, - score_function=score_function, - )) + fused_fwd_fn = jax.jit( + partial( + fused_compute_score_for_moe_aux_loss, + topk=topk, + score_function=score_function, + ) + ) routing_map_fused, scores_fused = fused_fwd_fn(logits) - assert jnp.allclose(scores_ref, scores_fused, atol=1e-5, rtol=1e-5), \ - f"Scores mismatch: max diff = {jnp.abs(scores_ref - scores_fused).max()}" + assert jnp.allclose( 
+ scores_ref, scores_fused, atol=1e-5, rtol=1e-5 + ), f"Scores mismatch: max diff = {jnp.abs(scores_ref - scores_fused).max()}" assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" # Backward (jitted) @@ -444,8 +466,9 @@ def loss_fused(logits_): grad_ref = jax.jit(jax.grad(loss_ref))(logits) grad_fused = jax.jit(jax.grad(loss_fused))(logits) - assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ - f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" # ============================================================================= @@ -455,7 +478,8 @@ def loss_fused(logits_): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper( - "num_tokens,num_experts,topk", AUX_LOSS_CASES, + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, ) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): key = jax.random.PRNGKey(SEED) @@ -469,29 +493,34 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): coeff = 0.01 # Forward: reference (jitted) - ref_fwd_fn = jax.jit(partial( - reference_aux_loss, - tokens_per_expert=tokens_per_expert, - total_num_tokens=num_tokens, - topk=topk, - num_experts=num_experts, - moe_aux_loss_coeff=coeff, - )) + ref_fwd_fn = jax.jit( + partial( + reference_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + ) aux_loss_ref = ref_fwd_fn(probs) # Forward: fused (jitted) - fused_fwd_fn = jax.jit(partial( - fused_moe_aux_loss, - tokens_per_expert=tokens_per_expert, - total_num_tokens=num_tokens, - num_experts=num_experts, - topk=topk, - coeff=coeff, - )) + fused_fwd_fn = jax.jit( + partial( + fused_moe_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + ) aux_loss_fused = fused_fwd_fn(probs) - assert jnp.allclose(aux_loss_ref, aux_loss_fused, atol=1e-5, rtol=1e-5), \ - f"Aux loss mismatch: ref={aux_loss_ref}, fused={aux_loss_fused}" + assert jnp.allclose( + aux_loss_ref, aux_loss_fused, atol=1e-5, rtol=1e-5 + ), f"Aux loss mismatch: ref={aux_loss_ref}, fused={aux_loss_fused}" # Backward (jitted) def loss_ref_fn(probs_): @@ -502,5 +531,6 @@ def loss_fused_fn(probs_): grad_ref = jax.jit(jax.grad(loss_ref_fn))(probs) grad_fused = jax.jit(jax.grad(loss_fused_fn))(probs) - assert jnp.allclose(grad_ref, grad_fused, atol=1e-5, rtol=1e-5), \ - f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ee7380c89d..e316f8be8c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -615,12 +615,8 @@ class TensorWrapper { * \param[in] scale_inv_dptr Pointer to the inverse of scale value. * \param[in] scaling_mode Tensor data format. 
*/ - TensorWrapper(void *dptr, - const NVTEShape &shape, - const DType dtype, - float *amax_dptr = nullptr, - float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr, + TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, NVTEShape scale_inv_shape = defaultShape, const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { tensor_ = nvte_create_tensor(scaling_mode); diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 6d1b2fee38..bd6ebf9d0d 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -28,6 +28,7 @@ class ScoreFunction(IntEnum): # Fused Top-K with Score Function - Forward # ============================================================================= + class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): """ Fused Top-K with Score Function Forward Primitive. @@ -37,7 +38,15 @@ class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): name = "te_fused_topk_with_score_function_forward_ffi" multiple_results = True - impl_static_args = (2, 3, 4, 5, 6, 7, 8) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores + impl_static_args = ( + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores inner_primitive = None outer_primitive = None @@ -157,7 +166,10 @@ def partition(*args, **kwargs): @staticmethod def shardy_sharding_rule(*args): del args - return "num_tokens num_experts, num_experts -> num_tokens num_experts, num_tokens num_experts, num_tokens num_experts" + return ( + "num_tokens num_experts, num_experts -> num_tokens num_experts, num_tokens num_experts," + " num_tokens num_experts" + ) register_primitive(FusedTopkWithScoreFunctionFwdPrimitive) @@ -176,7 +188,13 @@ class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): name = "te_fused_topk_with_score_function_backward_ffi" multiple_results = False - impl_static_args = (3, 4, 5, 6, 7) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores + impl_static_args = ( + 3, + 4, + 5, + 6, + 7, + ) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores inner_primitive = None outer_primitive = None @@ -285,7 +303,10 @@ def partition(*args, **kwargs): @staticmethod def shardy_sharding_rule(*args): del args - return "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens num_experts" + return ( + "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens" + " num_experts" + ) register_primitive(FusedTopkWithScoreFunctionBwdPrimitive) @@ -340,9 +361,7 @@ def impl(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff): ) @staticmethod - def batcher( - batched_args, batch_dims, *, total_num_tokens, num_experts, topk, coeff - ): + def batcher(batched_args, batch_dims, *, total_num_tokens, num_experts, topk, coeff): assert FusedMoEAuxLossFwdPrimitive.outer_primitive is not None probs, tokens_per_expert = batched_args probs_bdim, _ = batch_dims diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index cbb840c705..06c3a2348f 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -86,10 +86,8 @@ pybind11::dict 
Registrations() { EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler); dict["te_fused_topk_with_score_function_backward_ffi"] = EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler); - dict["te_fused_moe_aux_loss_forward_ffi"] = - EncapsulateFFI(FusedMoEAuxLossForwardHandler); - dict["te_fused_moe_aux_loss_backward_ffi"] = - EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); + dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); dict["te_inspect_ffi"] = pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index ac3699ccab..ce811473c2 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -34,18 +34,13 @@ static int compute_num_tokens(const Dims &dims) { Error_Type FusedTopkWithScoreFunctionForwardFFI( cudaStream_t stream, - Buffer_Type logits_buf, // [num_tokens, num_experts] - Buffer_Type expert_bias_buf, // [num_experts] or empty - Result_Type probs_buf, // [num_tokens, num_experts] (or scores when compute_aux_scores) - Result_Type routing_map_buf, // [num_tokens, num_experts] - Result_Type intermediate_buf, // [num_tokens, num_experts] - int64_t topk, - int64_t use_pre_softmax, - int64_t num_groups, - int64_t group_topk, - double scaling_factor, - int64_t score_function, - int64_t compute_aux_scores) { + Buffer_Type logits_buf, // [num_tokens, num_experts] + Buffer_Type expert_bias_buf, // [num_experts] or empty + Result_Type probs_buf, // [num_tokens, num_experts] (or scores when compute_aux_scores) + Result_Type routing_map_buf, // [num_tokens, num_experts] + Result_Type intermediate_buf, // [num_tokens, num_experts] + int64_t topk, int64_t use_pre_softmax, int64_t num_groups, int64_t group_topk, + double scaling_factor, int64_t score_function, int64_t compute_aux_scores) { auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); auto dims = logits_buf.dimensions(); auto num_tokens = compute_num_tokens(dims); @@ -57,8 +52,8 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( auto *routing_map = routing_map_buf->untyped_data(); auto *intermediate = intermediate_buf->untyped_data(); - auto flat_shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; + auto flat_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; auto logits_tensor = TensorWrapper(logits, flat_shape, dtype); auto probs_tensor = TensorWrapper(probs, flat_shape, dtype); auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte); @@ -73,8 +68,7 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( auto bias_dims = expert_bias_buf.dimensions(); auto expert_bias_tensor = (bias_dims.size() > 0 && bias_dims[0] > 0) - ? TensorWrapper(expert_bias, - std::vector{static_cast(bias_dims[0])}, + ? 
TensorWrapper(expert_bias, std::vector{static_cast(bias_dims[0])}, convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) : TensorWrapper(); @@ -89,23 +83,23 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedTopkWithScoreFunctionForwardHandler, FusedTopkWithScoreFunctionForwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // logits - .Arg() // expert_bias - .Ret() // probs (or scores) - .Ret() // routing_map - .Ret() // intermediate_output - .Attr("topk") - .Attr("use_pre_softmax") - .Attr("num_groups") - .Attr("group_topk") - .Attr("scaling_factor") - .Attr("score_function") - .Attr("compute_aux_scores"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler, + FusedTopkWithScoreFunctionForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // logits + .Arg() // expert_bias + .Ret() // probs (or scores) + .Ret() // routing_map + .Ret() // intermediate_output + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("num_groups") + .Attr("group_topk") + .Attr("scaling_factor") + .Attr("score_function") + .Attr("compute_aux_scores"), + FFI_CudaGraph_Traits); // ============================================================================ // Fused Top-K with Score Function - Backward @@ -113,42 +107,36 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( Error_Type FusedTopkWithScoreFunctionBackwardFFI( cudaStream_t stream, - Buffer_Type routing_map_buf, // [num_tokens, num_experts] (unused when compute_aux_scores) - Buffer_Type intermediate_buf, // [num_tokens, num_experts] - Buffer_Type grad_probs_buf, // [num_tokens, num_experts] (grad_scores when compute_aux_scores) - Result_Type grad_logits_buf, // [num_tokens, num_experts] - int64_t topk, - int64_t use_pre_softmax, - double scaling_factor, - int64_t score_function, + Buffer_Type routing_map_buf, // [num_tokens, num_experts] (unused when compute_aux_scores) + Buffer_Type intermediate_buf, // [num_tokens, num_experts] + Buffer_Type grad_probs_buf, // [num_tokens, num_experts] (grad_scores when compute_aux_scores) + Result_Type grad_logits_buf, // [num_tokens, num_experts] + int64_t topk, int64_t use_pre_softmax, double scaling_factor, int64_t score_function, int64_t compute_aux_scores) { auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); auto dims = intermediate_buf.dimensions(); auto num_tokens = compute_num_tokens(dims); auto num_experts = static_cast(dims[dims.size() - 1]); - auto flat_shape = std::vector{static_cast(num_tokens), - static_cast(num_experts)}; + auto flat_shape = + std::vector{static_cast(num_tokens), static_cast(num_experts)}; - auto intermediate_tensor = - TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype); - auto grad_probs_tensor = - TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype); - auto grad_logits_tensor = - TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype); + auto intermediate_tensor = TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype); + auto grad_probs_tensor = TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype); + auto grad_logits_tensor = TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype); if (compute_aux_scores) { - nvte_fused_score_for_moe_aux_loss_backward( - intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, num_experts, - static_cast(topk), static_cast(score_function), grad_logits_tensor.data(), - stream); + 
nvte_fused_score_for_moe_aux_loss_backward(intermediate_tensor.data(), grad_probs_tensor.data(), + num_tokens, num_experts, static_cast(topk), + static_cast(score_function), + grad_logits_tensor.data(), stream); } else { auto routing_map_tensor = TensorWrapper(routing_map_buf.untyped_data(), flat_shape, DType::kByte); nvte_fused_topk_with_score_function_backward( - routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), - num_tokens, num_experts, static_cast(topk), static_cast(use_pre_softmax), + routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, + num_experts, static_cast(topk), static_cast(use_pre_softmax), static_cast(scaling_factor), static_cast(score_function), grad_logits_tensor.data(), stream); } @@ -156,52 +144,47 @@ Error_Type FusedTopkWithScoreFunctionBackwardFFI( return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedTopkWithScoreFunctionBackwardHandler, FusedTopkWithScoreFunctionBackwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // routing_map - .Arg() // intermediate_output - .Arg() // grad_probs - .Ret() // grad_logits - .Attr("topk") - .Attr("use_pre_softmax") - .Attr("scaling_factor") - .Attr("score_function") - .Attr("compute_aux_scores"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler, + FusedTopkWithScoreFunctionBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // routing_map + .Arg() // intermediate_output + .Arg() // grad_probs + .Ret() // grad_logits + .Attr("topk") + .Attr("use_pre_softmax") + .Attr("scaling_factor") + .Attr("score_function") + .Attr("compute_aux_scores"), + FFI_CudaGraph_Traits); // ============================================================================ // Fused MoE Aux Loss - Forward // ============================================================================ -Error_Type FusedMoEAuxLossForwardFFI( - cudaStream_t stream, - Buffer_Type probs_buf, // [num_rows, num_cols] - Buffer_Type tokens_per_expert_buf, // [num_experts] - Result_Type aux_loss_buf, // scalar - Result_Type const_buf, // scalar - int64_t total_num_tokens, - int64_t num_experts, - int64_t topk, - double coeff) { +Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream, + Buffer_Type probs_buf, // [num_rows, num_cols] + Buffer_Type tokens_per_expert_buf, // [num_experts] + Result_Type aux_loss_buf, // scalar + Result_Type const_buf, // scalar + int64_t total_num_tokens, int64_t num_experts, int64_t topk, + double coeff) { auto dtype = convert_ffi_datatype_to_te_dtype(probs_buf.element_type()); auto probs_dims = probs_buf.dimensions(); auto num_rows = static_cast(probs_dims[0]); auto num_cols = static_cast(probs_dims[1]); - auto probs_shape = std::vector{static_cast(num_rows), - static_cast(num_cols)}; + auto probs_shape = + std::vector{static_cast(num_rows), static_cast(num_cols)}; auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); auto tpe_shape = std::vector{static_cast(num_experts)}; auto scalar_shape = std::vector{1}; auto probs_tensor = TensorWrapper(probs_buf.untyped_data(), probs_shape, dtype); - auto tpe_tensor = - TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); - auto const_buf_tensor = - TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); + auto 
const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); nvte_fused_moe_aux_loss_forward( probs_tensor.data(), tpe_tensor.data(), static_cast(total_num_tokens), @@ -211,68 +194,62 @@ Error_Type FusedMoEAuxLossForwardFFI( return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedMoEAuxLossForwardHandler, FusedMoEAuxLossForwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // probs - .Arg() // tokens_per_expert - .Ret() // aux_loss - .Ret() // Const_buf - .Attr("total_num_tokens") - .Attr("num_experts") - .Attr("topk") - .Attr("coeff"), - FFI_CudaGraph_Traits); +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler, FusedMoEAuxLossForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // probs + .Arg() // tokens_per_expert + .Ret() // aux_loss + .Ret() // Const_buf + .Attr("total_num_tokens") + .Attr("num_experts") + .Attr("topk") + .Attr("coeff"), + FFI_CudaGraph_Traits); // ============================================================================ // Fused MoE Aux Loss - Backward // ============================================================================ -Error_Type FusedMoEAuxLossBackwardFFI( - cudaStream_t stream, - Buffer_Type const_buf_in, // scalar float32 - Buffer_Type tokens_per_expert_buf, // [num_experts] - Buffer_Type grad_aux_loss_buf, // scalar - Result_Type grad_probs_buf, // [num_rows, num_cols] - int64_t num_rows, - int64_t num_cols) { +Error_Type FusedMoEAuxLossBackwardFFI(cudaStream_t stream, + Buffer_Type const_buf_in, // scalar float32 + Buffer_Type tokens_per_expert_buf, // [num_experts] + Buffer_Type grad_aux_loss_buf, // scalar + Result_Type grad_probs_buf, // [num_rows, num_cols] + int64_t num_rows, int64_t num_cols) { auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_aux_loss_buf.element_type()); auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); auto scalar_shape = std::vector{1}; auto tpe_dims = tokens_per_expert_buf.dimensions(); auto tpe_shape = std::vector{static_cast(tpe_dims[0])}; - auto grad_probs_shape = std::vector{static_cast(num_rows), - static_cast(num_cols)}; + auto grad_probs_shape = + std::vector{static_cast(num_rows), static_cast(num_cols)}; - auto const_buf_tensor = - TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32); - auto tpe_tensor = - TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); + auto const_buf_tensor = TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32); + auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); auto grad_aux_loss_tensor = TensorWrapper(grad_aux_loss_buf.untyped_data(), scalar_shape, grad_dtype); auto grad_probs_tensor = TensorWrapper(grad_probs_buf->untyped_data(), grad_probs_shape, grad_dtype); - nvte_fused_moe_aux_loss_backward( - const_buf_tensor.data(), tpe_tensor.data(), static_cast(num_rows), - static_cast(num_cols), grad_aux_loss_tensor.data(), grad_probs_tensor.data(), stream); + nvte_fused_moe_aux_loss_backward(const_buf_tensor.data(), tpe_tensor.data(), + static_cast(num_rows), static_cast(num_cols), + grad_aux_loss_tensor.data(), grad_probs_tensor.data(), stream); return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - FusedMoEAuxLossBackwardHandler, FusedMoEAuxLossBackwardFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // Const_buf - .Arg() // tokens_per_expert - .Arg() // grad_aux_loss - .Ret() // grad_probs - .Attr("num_rows") - .Attr("num_cols"), - FFI_CudaGraph_Traits); 
+XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler, FusedMoEAuxLossBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // Const_buf + .Arg() // tokens_per_expert + .Arg() // grad_aux_loss + .Ret() // grad_probs + .Attr("num_rows") + .Attr("num_cols"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index cf3a4a4097..73cdb3cfcf 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -50,7 +50,7 @@ def _validate_score_function(score_function: Union[str, ScoreFunction]) -> Score return ScoreFunction[score_function.upper()] except (KeyError, AttributeError): raise ValueError( - f"score_function must be 'softmax', 'sigmoid', or a ScoreFunction enum, " + "score_function must be 'softmax', 'sigmoid', or a ScoreFunction enum, " f"got {score_function!r}" ) from None @@ -140,34 +140,67 @@ def _fused_topk_with_score_function( compute_aux_scores: bool, ) -> Tuple[jnp.ndarray, jnp.ndarray]: (probs, routing_map), _ = _fused_topk_with_score_function_fwd( - logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, compute_aux_scores, + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, ) return probs, routing_map def _fused_topk_with_score_function_fwd( - logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, compute_aux_scores, + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, ): probs, routing_map, intermediate_output = fused_topk_with_score_function_fwd( - logits, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, compute_aux_scores, + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + compute_aux_scores, ) residuals = (routing_map, intermediate_output) return (probs, routing_map), residuals def _fused_topk_with_score_function_bwd( - topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, # pylint: disable=unused-argument - compute_aux_scores, residuals, g, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, # pylint: disable=unused-argument + compute_aux_scores, + residuals, + g, ): routing_map, intermediate_output = residuals grad_probs, _ = g # routing_map gradient is None (boolean) grad_logits = fused_topk_with_score_function_bwd( - routing_map, intermediate_output, grad_probs, - topk, use_pre_softmax, scaling_factor, score_function, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, compute_aux_scores, ) # expert_bias gradient is None: bias is not differentiated through this kernel @@ -223,11 +256,11 @@ def fused_compute_score_for_moe_aux_loss( jnp.empty((0,), dtype=logits.dtype), topk, False, # use_pre_softmax (unused for aux scores) - 1, # num_groups (unused for aux scores) - 1, # group_topk (unused for aux scores) - 1.0, # scaling_factor (unused for aux scores) + 1, # num_groups (unused for aux scores) + 1, # group_topk (unused for aux scores) + 1.0, # scaling_factor (unused for aux scores) score_function, - True, # compute_aux_scores + True, # compute_aux_scores ) return routing_map, scores @@ -271,7 +304,12 @@ def fused_moe_aux_loss( Scalar loss 
value. """ return _fused_moe_aux_loss( - probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, ) @@ -285,30 +323,54 @@ def _fused_moe_aux_loss( coeff: float, ) -> jnp.ndarray: aux_loss, _ = _fused_moe_aux_loss_fwd( - probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, ) return aux_loss def _fused_moe_aux_loss_fwd( - probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, ): aux_loss, const_buf = fused_moe_aux_loss_fwd( - probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff, + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, ) residuals = (const_buf, tokens_per_expert, probs.shape[0], probs.shape[1]) return aux_loss.squeeze(), residuals def _fused_moe_aux_loss_bwd( - total_num_tokens, num_experts, topk, coeff, residuals, g, # pylint: disable=unused-argument + total_num_tokens, + num_experts, + topk, + coeff, + residuals, + g, # pylint: disable=unused-argument ): const_buf, tokens_per_expert, num_rows, num_cols = residuals # g is a scalar matching the squeezed output of _fwd grad_aux_loss = g.reshape(1) grad_probs = fused_moe_aux_loss_bwd( - const_buf, tokens_per_expert, grad_aux_loss, num_rows, num_cols, + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows, + num_cols, ) return grad_probs, None From f3557cd23d63c9ee288299a5000617756b97cf49 Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 27 Feb 2026 18:17:25 -0800 Subject: [PATCH 7/9] properly merge the jax top level APIs of score_for_moe_aux_loss with topk_and_score Signed-off-by: tdophung --- tests/jax/test_distributed_router.py | 21 ++- tests/jax/test_fused_router.py | 20 +-- .../jax/cpp_extensions/router.py | 122 +++++++++++++++--- transformer_engine/jax/router.py | 117 +++++++---------- 4 files changed, 170 insertions(+), 110 deletions(-) diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 7ccc842be5..92932d85f1 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -10,7 +10,7 @@ sharded execution on the token dimension should produce identical results to processing each shard independently with the reference implementation. -For fused_topk_with_score_function and fused_compute_score_for_moe_aux_loss: +For fused_topk_with_score_function (including compute_aux_scores mode): - Input logits [num_tokens, num_experts] are sharded on num_tokens (DP axis) - Expert dimension is replicated - Each GPU processes its local tokens independently @@ -36,7 +36,6 @@ from transformer_engine.jax.router import ( fused_topk_with_score_function, - fused_compute_score_for_moe_aux_loss, fused_moe_aux_loss, ) @@ -202,7 +201,7 @@ def test_distributed_topk( class TestDistributedScoreForAuxLoss: - """Test distributed execution of fused_compute_score_for_moe_aux_loss. + """Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True. Same sharding strategy as fused_topk: shard on token dim, replicate experts. Each GPU independently computes scores and routing map for its local tokens. 
@@ -238,13 +237,12 @@ def _impl_test( # === Forward === @jax.jit def target_fwd(x): - return fused_compute_score_for_moe_aux_loss( - x, - topk=topk, - score_function=score_function, + return fused_topk_with_score_function( + x, topk=topk, score_function=score_function, + compute_aux_scores=True, ) - target_routing_map, target_scores = target_fwd(logits_sharded) + target_scores, target_routing_map = target_fwd(logits_sharded) logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) ref_fwd_fn = jax.jit( @@ -276,10 +274,9 @@ def target_fwd(x): # === Backward === def target_loss(x): - _, s = fused_compute_score_for_moe_aux_loss( - x, - topk=topk, - score_function=score_function, + s, _ = fused_topk_with_score_function( + x, topk=topk, score_function=score_function, + compute_aux_scores=True, ) return jnp.sum(s) diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index f270fa3b80..72913b0d90 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -15,7 +15,6 @@ from transformer_engine.jax.router import ( fused_topk_with_score_function, - fused_compute_score_for_moe_aux_loss, fused_moe_aux_loss, ) @@ -441,14 +440,13 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f routing_map_ref, scores_ref = ref_fwd_fn(logits) # Forward: fused (jitted) - fused_fwd_fn = jax.jit( - partial( - fused_compute_score_for_moe_aux_loss, - topk=topk, - score_function=score_function, - ) - ) - routing_map_fused, scores_fused = fused_fwd_fn(logits) + fused_fwd_fn = jax.jit(partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + )) + scores_fused, routing_map_fused = fused_fwd_fn(logits) assert jnp.allclose( scores_ref, scores_fused, atol=1e-5, rtol=1e-5 @@ -461,7 +459,9 @@ def loss_ref(logits_): return s.sum() def loss_fused(logits_): - _, s = fused_compute_score_for_moe_aux_loss(logits_, topk, score_function) + s, _ = fused_topk_with_score_function( + logits_, topk, score_function=score_function, compute_aux_scores=True, + ) return s.sum() grad_ref = jax.jit(jax.grad(loss_ref))(logits) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index bd6ebf9d0d..e886809309 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -7,8 +7,10 @@ import jax.numpy as jnp from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec from .base import BasePrimitive, register_primitive +from ..sharding import get_padded_spec __all__ = [ "ScoreFunction", @@ -156,12 +158,41 @@ def batcher( ) @staticmethod - def infer_sharding_from_operands(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def infer_sharding_from_operands( + topk, use_pre_softmax, num_groups, group_topk, scaling_factor, + score_function, compute_aux_scores, + mesh, arg_infos, result_infos, + ): + del topk, use_pre_softmax, num_groups, group_topk, scaling_factor + del score_function, compute_aux_scores, result_infos + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + return [out_sharding, routing_sharding, intermediate_sharding] @staticmethod - def partition(*args, **kwargs): - raise 
NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def partition( + topk, use_pre_softmax, num_groups, group_topk, scaling_factor, + score_function, compute_aux_scores, + mesh, arg_infos, result_infos, + ): + del result_infos + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + out_shardings = [out_sharding, routing_sharding, intermediate_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(logits, expert_bias): + return FusedTopkWithScoreFunctionFwdPrimitive.impl( + logits, expert_bias, + topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, compute_aux_scores, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings @staticmethod def shardy_sharding_rule(*args): @@ -293,12 +324,33 @@ def batcher( ) @staticmethod - def infer_sharding_from_operands(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def infer_sharding_from_operands( + topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores, + mesh, arg_infos, result_infos, + ): + del topk, use_pre_softmax, scaling_factor, score_function + del compute_aux_scores, result_infos + grad_spec = get_padded_spec(arg_infos[2]) + return NamedSharding(mesh, PartitionSpec(*grad_spec)) @staticmethod - def partition(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def partition( + topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores, + mesh, arg_infos, result_infos, + ): + del result_infos + grad_spec = get_padded_spec(arg_infos[2]) + out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) + + def sharded_impl(routing_map, intermediate, grad_probs): + return FusedTopkWithScoreFunctionBwdPrimitive.impl( + routing_map, intermediate, grad_probs, + topk, use_pre_softmax, scaling_factor, + score_function, compute_aux_scores, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings @staticmethod def shardy_sharding_rule(*args): @@ -378,12 +430,31 @@ def batcher(batched_args, batch_dims, *, total_num_tokens, num_experts, topk, co ) @staticmethod - def infer_sharding_from_operands(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def infer_sharding_from_operands( + total_num_tokens, num_experts, topk, coeff, + mesh, arg_infos, result_infos, + ): + del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos + scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) + return [scalar_sharding, scalar_sharding] @staticmethod - def partition(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def partition( + total_num_tokens, num_experts, topk, coeff, + mesh, arg_infos, result_infos, + ): + del result_infos + scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) + out_shardings = [scalar_sharding, scalar_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(probs, tokens_per_expert): + return FusedMoEAuxLossFwdPrimitive.impl( + probs, tokens_per_expert, + total_num_tokens, num_experts, topk, coeff, + ) + + 
return mesh, sharded_impl, out_shardings, arg_shardings @staticmethod def shardy_sharding_rule(*args): @@ -450,12 +521,31 @@ def batcher(batched_args, batch_dims, *, num_rows, num_cols): ) @staticmethod - def infer_sharding_from_operands(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def infer_sharding_from_operands( + num_rows, num_cols, + mesh, arg_infos, result_infos, + ): + del num_rows, num_cols, arg_infos, result_infos + return NamedSharding(mesh, PartitionSpec(None, None)) @staticmethod - def partition(*args, **kwargs): - raise NotImplementedError("Use shardy sharding rules instead of GSPMD custom partitioning") + def partition( + num_rows, num_cols, + mesh, arg_infos, result_infos, + ): + del result_infos + out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + arg_shardings = ( + arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding, + ) + + def sharded_impl(const_buf, tokens_per_expert, grad_aux_loss): + return FusedMoEAuxLossBwdPrimitive.impl( + const_buf, tokens_per_expert, grad_aux_loss, + num_rows, num_cols, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings @staticmethod def shardy_sharding_rule(*args): diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 73cdb3cfcf..6c35ca5125 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -11,10 +11,9 @@ Functions: fused_topk_with_score_function: Fused score_function + top-k selection. Supports softmax/sigmoid, - grouped top-k, expert bias, and scaling factor. - - fused_compute_score_for_moe_aux_loss: - Compute clean scores and routing map for the auxiliary load-balancing loss. + grouped top-k, expert bias, and scaling factor. When compute_aux_scores=True, + switches to the clean score-for-aux-loss kernel (no bias/groups/scaling, + dense output). fused_moe_aux_loss: Compute the MoE auxiliary load-balancing loss scalar. @@ -37,7 +36,6 @@ __all__ = [ "ScoreFunction", "fused_topk_with_score_function", - "fused_compute_score_for_moe_aux_loss", "fused_moe_aux_loss", ] @@ -69,10 +67,21 @@ def fused_topk_with_score_function( scaling_factor: float = 1.0, score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, expert_bias: Optional[jnp.ndarray] = None, + compute_aux_scores: bool = False, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Fused top-k with score function router. + When compute_aux_scores=False (default), runs the main routing kernel: + score_function(logits) -> [optional bias] -> top-k -> [optional post-softmax] -> scale. + Returns sparse probs (only top-k positions nonzero) and routing_map. + + When compute_aux_scores=True, runs the score-for-aux-loss kernel instead: + score_function(logits) -> top-k (clean, no bias/groups/scaling). + Returns dense scores (all expert positions) and routing_map. + The expert_bias, use_pre_softmax, num_groups, group_topk, and scaling_factor + parameters are ignored in this mode. + Parameters ---------- logits : jnp.ndarray @@ -80,39 +89,55 @@ def fused_topk_with_score_function( topk : int Number of top experts to select per token. use_pre_softmax : bool - If True, apply softmax before top-k (only for softmax score function). Else, apply post top-k + If True, apply softmax before top-k (only for softmax score function). Else, apply post top-k. + Ignored when compute_aux_scores=True. num_groups : int Number of groups for grouped top-k. 1 means no grouping. + Ignored when compute_aux_scores=True. 
group_topk : int Top-k at group level. 1 means no group-level selection. + Ignored when compute_aux_scores=True. scaling_factor : float Scaling factor applied to output probs. + Ignored when compute_aux_scores=True. score_function : Union[str, ScoreFunction] Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID. expert_bias : Optional[jnp.ndarray] Expert bias, shape [num_experts]. Only used with sigmoid. + Ignored when compute_aux_scores=True. + compute_aux_scores : bool + If True, use the clean score-for-aux-loss kernel. Returns dense scores + over all experts instead of sparse probs. Returns ------- - probs : jnp.ndarray - Sparse probability tensor, shape [num_tokens, num_experts]. - Non-zero only at selected expert positions. + probs_or_scores : jnp.ndarray + When compute_aux_scores=False: Sparse probability tensor, shape [num_tokens, num_experts]. + Non-zero only at selected expert positions. + When compute_aux_scores=True: Dense score tensor, shape [num_tokens, num_experts]. + All expert positions contain scores. routing_map : jnp.ndarray Boolean mask, shape [num_tokens, num_experts]. True at selected expert positions. """ score_function = _validate_score_function(score_function) - if expert_bias is not None and score_function != ScoreFunction.SIGMOID: - raise ValueError( - "expert_bias is only supported with score_function='sigmoid'. " - f"Got score_function='{score_function.name}'." - ) - - if expert_bias is None: + if compute_aux_scores: expert_bias = jnp.empty((0,), dtype=logits.dtype) - - probs, routing_map = _fused_topk_with_score_function( + use_pre_softmax = False + num_groups = 1 + group_topk = 1 + scaling_factor = 1.0 + else: + if expert_bias is not None and score_function != ScoreFunction.SIGMOID: + raise ValueError( + "expert_bias is only supported with score_function='sigmoid'. " + f"Got score_function='{score_function.name}'." + ) + if expert_bias is None: + expert_bias = jnp.empty((0,), dtype=logits.dtype) + + probs_or_scores, routing_map = _fused_topk_with_score_function( logits, expert_bias, topk, @@ -121,10 +146,10 @@ def fused_topk_with_score_function( group_topk, scaling_factor, score_function, - False, # compute_aux_scores + compute_aux_scores, ) - return probs, routing_map + return probs_or_scores, routing_map @partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8)) @@ -213,58 +238,6 @@ def _fused_topk_with_score_function_bwd( ) -# ============================================================================= -# Fused Score for MoE Aux Loss -# ============================================================================= - - -def fused_compute_score_for_moe_aux_loss( - logits: jnp.ndarray, - topk: int, - score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Compute scores and routing map for MoE auxiliary loss. - - This uses clean softmax/sigmoid + plain top-k (no group constraints, - no expert bias, no scaling) to produce the scores and routing map - used for the load-balancing auxiliary loss. - - Internally delegates to the same primitive as fused_topk_with_score_function - with compute_aux_scores=True, selecting the score-for-aux-loss CUDA kernel. - - Parameters - ---------- - logits : jnp.ndarray - Logits from the gating GEMM, shape [num_tokens, num_experts]. - topk : int - Number of top experts to select. - score_function : Union[str, ScoreFunction] - Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID. 
- - Returns - ------- - routing_map : jnp.ndarray - Boolean mask, shape [num_tokens, num_experts]. - scores : jnp.ndarray - Dense score tensor, shape [num_tokens, num_experts]. - """ - score_function = _validate_score_function(score_function) - - scores, routing_map = _fused_topk_with_score_function( - logits, - jnp.empty((0,), dtype=logits.dtype), - topk, - False, # use_pre_softmax (unused for aux scores) - 1, # num_groups (unused for aux scores) - 1, # group_topk (unused for aux scores) - 1.0, # scaling_factor (unused for aux scores) - score_function, - True, # compute_aux_scores - ) - return routing_map, scores - - # ============================================================================= # Fused MoE Aux Loss # ============================================================================= From fd9afbe024bbc7e62849e5676d11202e9d0bdfd5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 28 Feb 2026 02:20:51 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_router.py | 8 +- tests/jax/test_fused_router.py | 19 ++- .../jax/cpp_extensions/router.py | 120 +++++++++++++----- 3 files changed, 109 insertions(+), 38 deletions(-) diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py index 92932d85f1..fb13008723 100644 --- a/tests/jax/test_distributed_router.py +++ b/tests/jax/test_distributed_router.py @@ -238,7 +238,9 @@ def _impl_test( @jax.jit def target_fwd(x): return fused_topk_with_score_function( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, compute_aux_scores=True, ) @@ -275,7 +277,9 @@ def target_fwd(x): # === Backward === def target_loss(x): s, _ = fused_topk_with_score_function( - x, topk=topk, score_function=score_function, + x, + topk=topk, + score_function=score_function, compute_aux_scores=True, ) return jnp.sum(s) diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index 72913b0d90..058f9c9060 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -440,12 +440,14 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f routing_map_ref, scores_ref = ref_fwd_fn(logits) # Forward: fused (jitted) - fused_fwd_fn = jax.jit(partial( - fused_topk_with_score_function, - topk=topk, - score_function=score_function, - compute_aux_scores=True, - )) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + ) scores_fused, routing_map_fused = fused_fwd_fn(logits) assert jnp.allclose( @@ -460,7 +462,10 @@ def loss_ref(logits_): def loss_fused(logits_): s, _ = fused_topk_with_score_function( - logits_, topk, score_function=score_function, compute_aux_scores=True, + logits_, + topk, + score_function=score_function, + compute_aux_scores=True, ) return s.sum() diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index e886809309..8bd3911196 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -159,9 +159,16 @@ def batcher( @staticmethod def infer_sharding_from_operands( - topk, use_pre_softmax, num_groups, group_topk, scaling_factor, - score_function, compute_aux_scores, - mesh, arg_infos, result_infos, + topk, + use_pre_softmax, + num_groups, + 
group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, ): del topk, use_pre_softmax, num_groups, group_topk, scaling_factor del score_function, compute_aux_scores, result_infos @@ -173,9 +180,16 @@ def infer_sharding_from_operands( @staticmethod def partition( - topk, use_pre_softmax, num_groups, group_topk, scaling_factor, - score_function, compute_aux_scores, - mesh, arg_infos, result_infos, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, ): del result_infos logits_spec = get_padded_spec(arg_infos[0]) @@ -187,9 +201,15 @@ def partition( def sharded_impl(logits, expert_bias): return FusedTopkWithScoreFunctionFwdPrimitive.impl( - logits, expert_bias, - topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, compute_aux_scores, + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, ) return mesh, sharded_impl, out_shardings, arg_shardings @@ -325,8 +345,14 @@ def batcher( @staticmethod def infer_sharding_from_operands( - topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores, - mesh, arg_infos, result_infos, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, ): del topk, use_pre_softmax, scaling_factor, score_function del compute_aux_scores, result_infos @@ -335,8 +361,14 @@ def infer_sharding_from_operands( @staticmethod def partition( - topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores, - mesh, arg_infos, result_infos, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, ): del result_infos grad_spec = get_padded_spec(arg_infos[2]) @@ -345,9 +377,14 @@ def partition( def sharded_impl(routing_map, intermediate, grad_probs): return FusedTopkWithScoreFunctionBwdPrimitive.impl( - routing_map, intermediate, grad_probs, - topk, use_pre_softmax, scaling_factor, - score_function, compute_aux_scores, + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, ) return mesh, sharded_impl, out_sharding, arg_shardings @@ -431,8 +468,13 @@ def batcher(batched_args, batch_dims, *, total_num_tokens, num_experts, topk, co @staticmethod def infer_sharding_from_operands( - total_num_tokens, num_experts, topk, coeff, - mesh, arg_infos, result_infos, + total_num_tokens, + num_experts, + topk, + coeff, + mesh, + arg_infos, + result_infos, ): del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -440,8 +482,13 @@ def infer_sharding_from_operands( @staticmethod def partition( - total_num_tokens, num_experts, topk, coeff, - mesh, arg_infos, result_infos, + total_num_tokens, + num_experts, + topk, + coeff, + mesh, + arg_infos, + result_infos, ): del result_infos scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -450,8 +497,12 @@ def partition( def sharded_impl(probs, tokens_per_expert): return FusedMoEAuxLossFwdPrimitive.impl( - probs, tokens_per_expert, - total_num_tokens, num_experts, topk, coeff, + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, ) return mesh, sharded_impl, out_shardings, arg_shardings @@ -522,27 +573,38 @@ def batcher(batched_args, batch_dims, *, num_rows, 
num_cols): @staticmethod def infer_sharding_from_operands( - num_rows, num_cols, - mesh, arg_infos, result_infos, + num_rows, + num_cols, + mesh, + arg_infos, + result_infos, ): del num_rows, num_cols, arg_infos, result_infos return NamedSharding(mesh, PartitionSpec(None, None)) @staticmethod def partition( - num_rows, num_cols, - mesh, arg_infos, result_infos, + num_rows, + num_cols, + mesh, + arg_infos, + result_infos, ): del result_infos out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) arg_shardings = ( - arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding, + arg_infos[0].sharding, + arg_infos[1].sharding, + arg_infos[2].sharding, ) def sharded_impl(const_buf, tokens_per_expert, grad_aux_loss): return FusedMoEAuxLossBwdPrimitive.impl( - const_buf, tokens_per_expert, grad_aux_loss, - num_rows, num_cols, + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows, + num_cols, ) return mesh, sharded_impl, out_sharding, arg_shardings From 3fdeeef253e5496ceb6ca6de6053d9f619b0c2ba Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 27 Feb 2026 18:42:46 -0800 Subject: [PATCH 9/9] fix lint + import issues Signed-off-by: tdophung --- transformer_engine/jax/cpp_extensions/router.py | 5 +++-- transformer_engine/jax/router.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 8bd3911196..ad25802841 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -3,14 +3,13 @@ # See LICENSE for license information. """JAX/TE custom ops for fused MoE router""" from enum import IntEnum -from functools import partial import jax.numpy as jnp from jax import dtypes, ffi from jax.sharding import NamedSharding, PartitionSpec from .base import BasePrimitive, register_primitive -from ..sharding import get_padded_spec +from .misc import get_padded_spec __all__ = [ "ScoreFunction", @@ -22,6 +21,8 @@ class ScoreFunction(IntEnum): + """Score function enum for fused MoE router kernels.""" + SIGMOID = 0 SOFTMAX = 1 diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 6c35ca5125..c95287a5fb 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -207,10 +207,10 @@ def _fused_topk_with_score_function_fwd( def _fused_topk_with_score_function_bwd( topk, use_pre_softmax, - num_groups, - group_topk, + num_groups, # pylint: disable=unused-argument + group_topk, # pylint: disable=unused-argument scaling_factor, - score_function, # pylint: disable=unused-argument + score_function, compute_aux_scores, residuals, g, @@ -327,12 +327,12 @@ def _fused_moe_aux_loss_fwd( def _fused_moe_aux_loss_bwd( - total_num_tokens, - num_experts, - topk, - coeff, + total_num_tokens, # pylint: disable=unused-argument + num_experts, # pylint: disable=unused-argument + topk, # pylint: disable=unused-argument + coeff, # pylint: disable=unused-argument residuals, - g, # pylint: disable=unused-argument + g, ): const_buf, tokens_per_expert, num_rows, num_cols = residuals # g is a scalar matching the squeezed output of _fwd