diff --git a/tests/jax/test_distributed_router.py b/tests/jax/test_distributed_router.py
new file mode 100644
index 0000000000..fb13008723
--- /dev/null
+++ b/tests/jax/test_distributed_router.py
@@ -0,0 +1,467 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Tests for distributed/sharded execution of fused MoE router primitives.
+
+Testing Strategy:
+=================
+Router operations process each token independently (1 warp per token), so
+sharded execution on the token dimension should produce identical results
+to processing each shard independently with the reference implementation.
+
+For fused_topk_with_score_function (including compute_aux_scores mode):
+- Input logits [num_tokens, num_experts] are sharded on num_tokens (DP axis)
+- Expert dimension is replicated
+- Each GPU processes its local tokens independently
+- We verify the sharded output matches the per-shard reference outputs, concatenated
+
+For fused_moe_aux_loss:
+- This is a global reduction to a scalar
+- All inputs and outputs are replicated (partition function forces this)
+- We verify the op works correctly under a mesh context
+
+These tests exercise the batcher and shardy_sharding_rule hooks of the router
+primitives.
+"""
+
+import pytest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+from distributed_test_base import generate_configs
+from utils import assert_allclose, pytest_parametrize_wrapper
+
+from transformer_engine.jax.router import (
+    fused_topk_with_score_function,
+    fused_moe_aux_loss,
+)
+
+from test_fused_router import (
+    reference_topk_softmax_sigmoid,
+    reference_compute_scores_for_aux_loss,
+    reference_aux_loss,
+    make_logits,
+)
+
+# (num_tokens, num_experts, topk)
+ALL_TOPK_CASES = [
+    (128, 32, 4),
+    (2048, 128, 8),
+]
+TOPK_CASES = {
+    "L0": ALL_TOPK_CASES[0:1],
+    "L2": ALL_TOPK_CASES,
+}
+
+ALL_AUX_LOSS_CASES = [
+    (128, 32, 4),
+    (2048, 128, 4),
+]
+AUX_LOSS_CASES = {
+    "L0": ALL_AUX_LOSS_CASES[0:1],
+    "L2": ALL_AUX_LOSS_CASES,
+}
+
+
+class TestDistributedFusedTopk:
+    """Test distributed execution of fused_topk_with_score_function.
+
+    Shards logits on the token dimension. Each GPU independently runs the
+    fused kernel on its local tokens. We compare against the reference
+    implementation run per-shard and concatenated.
+ """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + jax.config.update("jax_use_shardy_partitioner", True) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + ) + + target_probs, target_routing_map = target_fwd(logits_sharded) + + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_topk_softmax_sigmoid( + x, + topk=topk, + score_function=score_function, + ) + ) + ref_probs_list = [] + ref_routing_list = [] + for i in range(num_dp_devices): + p, rm = ref_fwd_fn(logits_shards[i]) + ref_probs_list.append(p) + ref_routing_list.append(rm) + + ref_probs = jnp.concatenate(ref_probs_list, axis=0) + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + + assert_allclose( + jax.device_get(target_probs), + ref_probs, + dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), + ref_routing, + ), "Routing map mismatch in distributed fused_topk" + + # === Backward === + def target_loss(x): + p, _ = fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + ) + return jnp.sum(p) + + def ref_chunk_loss(x_chunk): + p, _ = reference_topk_softmax_sigmoid( + x_chunk, + topk=topk, + score_function=score_function, + ) + return jnp.sum(p) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_topk( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ) + + +class TestDistributedScoreForAuxLoss: + """Test distributed execution of fused_topk_with_score_function with compute_aux_scores=True. + + Same sharding strategy as fused_topk: shard on token dim, replicate experts. + Each GPU independently computes scores and routing map for its local tokens. 
+ """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + jax.config.update("jax_use_shardy_partitioner", True) + + logits = make_logits(num_tokens, num_experts, score_function) + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + dp_axis = mesh_resource.dp_resource + sharded_pspec = PartitionSpec(dp_axis, None) + num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1 + local_num_tokens = num_tokens // num_dp_devices + + with mesh: + logits_sharding = NamedSharding(mesh, sharded_pspec) + logits_sharded = jax.device_put(logits, logits_sharding) + + # === Forward === + @jax.jit + def target_fwd(x): + return fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + + target_scores, target_routing_map = target_fwd(logits_sharded) + + logits_shards = jnp.reshape(logits, (num_dp_devices, local_num_tokens, num_experts)) + ref_fwd_fn = jax.jit( + lambda x: reference_compute_scores_for_aux_loss( + x, + topk=topk, + score_function=score_function, + ) + ) + ref_routing_list = [] + ref_scores_list = [] + for i in range(num_dp_devices): + rm, s = ref_fwd_fn(logits_shards[i]) + ref_routing_list.append(rm) + ref_scores_list.append(s) + + ref_routing = jnp.concatenate(ref_routing_list, axis=0) + ref_scores = jnp.concatenate(ref_scores_list, axis=0) + + assert_allclose( + jax.device_get(target_scores), + ref_scores, + dtype=jnp.float32, + ) + assert jnp.array_equal( + jax.device_get(target_routing_map), + ref_routing, + ), "Routing map mismatch in distributed score_for_aux_loss" + + # === Backward === + def target_loss(x): + s, _ = fused_topk_with_score_function( + x, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + return jnp.sum(s) + + def ref_chunk_loss(x_chunk): + _, s = reference_compute_scores_for_aux_loss( + x_chunk, + topk=topk, + score_function=score_function, + ) + return jnp.sum(s) + + target_grad = jax.jit(jax.grad(target_loss))(logits_sharded) + + ref_grads = [] + ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss)) + for i in range(num_dp_devices): + ref_grads.append(ref_chunk_grad_fn(logits_shards[i])) + ref_grad = jnp.concatenate(ref_grads, axis=0) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, + ) + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_distributed_score_for_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + score_function, + ) + + +class TestDistributedMoEAuxLoss: + """Test distributed execution of fused_moe_aux_loss. + + Aux loss is a global reduction to a scalar. The partition function forces + all inputs to be replicated. We verify the op produces correct results + under a mesh context with replicated sharding, testing both forward + (scalar loss) and backward (gradient w.r.t. probs). 
+ """ + + def _impl_test( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ): + jax.config.update("jax_use_shardy_partitioner", True) + + key = jax.random.PRNGKey(42) + _, subkey1, _ = jax.random.split(key, 3) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=jnp.float32) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=jnp.float32) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + + tokens_per_expert = jax.random.randint(subkey1, (num_experts,), 1, 1000).astype(jnp.int32) + coeff = 0.01 + + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + replicated_2d_pspec = PartitionSpec(None, None) + replicated_1d_pspec = PartitionSpec(None) + + with mesh: + probs_sharding = NamedSharding(mesh, replicated_2d_pspec) + tpe_sharding = NamedSharding(mesh, replicated_1d_pspec) + + probs_dev = jax.device_put(probs, probs_sharding) + tpe_dev = jax.device_put(tokens_per_expert, tpe_sharding) + + # === Forward === + @jax.jit + def target_fwd(p, tpe): + return fused_moe_aux_loss( + p, + tpe, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + target_loss = target_fwd(probs_dev, tpe_dev) + + ref_fwd_fn = jax.jit( + lambda p: reference_aux_loss( + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, + ) + ) + ref_loss = ref_fwd_fn(probs) + + assert_allclose( + jax.device_get(target_loss), + ref_loss, + dtype=jnp.float32, + ) + + # === Backward === + def target_loss_fn(p): + return fused_moe_aux_loss( + p, + tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + def ref_loss_fn(p): + return reference_aux_loss( + p, + tokens_per_expert, + num_tokens, + topk, + num_experts, + coeff, + ) + + target_grad = jax.jit(jax.grad(target_loss_fn))(probs_dev) + ref_grad = jax.jit(jax.grad(ref_loss_fn))(probs) + + assert_allclose( + jax.device_get(target_grad), + ref_grad, + dtype=jnp.float32, + ) + + @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) + @pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, + ) + def test_distributed_aux_loss( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ): + self._impl_test( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + num_tokens, + num_experts, + topk, + ) diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py new file mode 100644 index 0000000000..058f9c9060 --- /dev/null +++ b/tests/jax/test_fused_router.py @@ -0,0 +1,541 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tests for fused MoE router CUDA kernels (JAX wrappers).""" + +from functools import partial +from typing import Optional + +import jax +import jax.numpy as jnp +import pytest + +from utils import pytest_parametrize_wrapper + +from transformer_engine.jax.router import ( + fused_topk_with_score_function, + fused_moe_aux_loss, +) + +# ============================================================================= +# Test case definitions (L0 = fast smoke, L2 = comprehensive) +# ============================================================================= + +# (num_tokens, num_experts, topk) +ALL_TOPK_CASES = [ + (128, 32, 4), + (2048, 32, 4), + (2048, 128, 8), + (7168, 128, 4), + (7168, 32, 8), +] +TOPK_CASES = { + "L0": ALL_TOPK_CASES[0:2], + "L2": ALL_TOPK_CASES, +} + +ALL_GROUP_TOPK_OPTIONS = [None, 4] +GROUP_TOPK_OPTIONS = { + "L0": [None], + "L2": ALL_GROUP_TOPK_OPTIONS, +} + +ALL_SCALING_FACTOR_OPTIONS = [None, 1.2] +SCALING_FACTOR_OPTIONS = { + "L0": [None], + "L2": ALL_SCALING_FACTOR_OPTIONS, +} + +ALL_ENABLE_BIAS_OPTIONS = [True, False] +ENABLE_BIAS_OPTIONS = { + "L0": [False], + "L2": ALL_ENABLE_BIAS_OPTIONS, +} + +ALL_USE_PRE_SOFTMAX_OPTIONS = [True, False] +USE_PRE_SOFTMAX_OPTIONS = { + "L0": [False], + "L2": ALL_USE_PRE_SOFTMAX_OPTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_SCORE_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 8), + (7168, 128, 8), + (7168, 32, 4), +] +SCORE_AUX_LOSS_CASES = { + "L0": ALL_SCORE_AUX_LOSS_CASES[0:2], + "L2": ALL_SCORE_AUX_LOSS_CASES, +} + +ALL_SCORE_FUNCTIONS = ["softmax", "sigmoid"] +SCORE_FUNCTIONS = { + "L0": ["softmax"], + "L2": ALL_SCORE_FUNCTIONS, +} + +# (num_tokens, num_experts, topk) +ALL_AUX_LOSS_CASES = [ + (128, 32, 4), + (2048, 128, 4), + (2048, 256, 4), + (7168, 128, 4), + (7168, 32, 4), +] +AUX_LOSS_CASES = { + "L0": ALL_AUX_LOSS_CASES[0:2], + "L2": ALL_AUX_LOSS_CASES, +} + +ALL_DTYPES = [jnp.float32] +DTYPES = { + "L0": [jnp.float32], + "L2": ALL_DTYPES, +} + +SEED = 42 + + +# ============================================================================= +# Reference Implementations +# ============================================================================= + + +def reference_group_limited_topk( + scores: jnp.ndarray, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + """Reference implementation for grouped top-k. + + Only valid when num_groups and group_topk are both positive integers. + For plain top-k without grouping, use jax.lax.top_k directly. + """ + assert num_groups is not None and num_groups > 0, ( + "reference_group_limited_topk requires valid num_groups > 0. " + "For plain top-k, use jax.lax.top_k directly." + ) + assert ( + group_topk is not None and group_topk > 0 + ), "reference_group_limited_topk requires valid group_topk > 0." 
+ assert ( + num_experts % num_groups == 0 + ), f"num_experts ({num_experts}) must be divisible by num_groups ({num_groups})" + group_size = num_experts // num_groups + experts_per_group = topk // group_topk + + group_scores = ( + scores.reshape(num_tokens, num_groups, group_size) + .sort(axis=-1)[..., -experts_per_group:] + .sum(axis=-1) + ) + group_idx = jax.lax.top_k(group_scores, k=group_topk)[1] + group_mask = jnp.zeros_like(group_scores).at[jnp.arange(num_tokens)[:, None], group_idx].set(1) + + score_mask = (group_mask[:, :, None] * jnp.ones((num_tokens, num_groups, group_size))).reshape( + num_tokens, -1 + ) + + masked_scores = jnp.where(score_mask.astype(bool), scores, -jnp.inf) + probs, top_indices = jax.lax.top_k(masked_scores, k=topk) + return probs, top_indices + + +def reference_topk_softmax_sigmoid( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[jnp.ndarray] = None, +): + """Reference implementation for topk + softmax/sigmoid.""" + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return reference_group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return jax.lax.top_k(scores, k=topk) + + if score_function == "softmax": + if use_pre_softmax: + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(logits.dtype) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(logits.dtype) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)).astype(logits.dtype) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = jnp.take_along_axis(scores, top_indices, axis=1).astype(logits.dtype) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + topk_masked_gates = ( + jnp.zeros_like(logits).at[jnp.arange(num_tokens)[:, None], top_indices].set(probs) + ) + topk_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) + + return topk_masked_gates, topk_map + + +def reference_compute_scores_for_aux_loss(logits: jnp.ndarray, topk: int, score_function: str): + """Reference implementation for computing routing scores for aux loss.""" + if score_function == "softmax": + scores = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits.astype(jnp.float32)) + scores = scores / (scores.sum(axis=-1, keepdims=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + _, top_indices = jax.lax.top_k(scores, k=topk) + num_tokens = logits.shape[0] + routing_map = ( + jnp.zeros_like(logits, dtype=jnp.bool_) + .at[jnp.arange(num_tokens)[:, None], top_indices] + .set(True) + ) + 
return routing_map, scores + + +def reference_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + topk: int, + num_experts: int, + moe_aux_loss_coeff: float, +): + """Reference implementation for MoE auxiliary loss.""" + aggregated_probs_per_expert = probs.sum(axis=0) + aux_loss = jnp.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens) + ) + return aux_loss + + +# ============================================================================= +# Helper: logits generation +# ============================================================================= + + +def make_logits(num_tokens, num_experts, score_function, dtype=jnp.float32): + """Create deterministic logits for testing.""" + if score_function == "sigmoid": + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + logits = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + logits = logits[None, :].repeat(num_tokens, axis=0) + offset[:, None] + else: + logits = ( + jnp.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + dtype=dtype, + ) + * 1e-4 + ) + logits = logits.reshape(num_tokens, num_experts) + return logits + + +# ============================================================================= +# Test: Fused Top-K with Score Function +# ============================================================================= + + +def run_topk_comparison( + dtype, + num_tokens, + num_experts, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + enable_bias, +): + """Compare fused vs reference top-k implementation, both jitted.""" + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + if enable_bias and score_function == "sigmoid": + expert_bias = jnp.arange(num_experts, dtype=jnp.float32) * 0.1 + expert_bias = jnp.flip(expert_bias) + else: + expert_bias = None + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_topk_softmax_sigmoid, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + probs_ref, routing_map_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups if num_groups else -1, + group_topk=group_topk if group_topk else -1, + scaling_factor=scaling_factor if scaling_factor else 1.0, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + probs_fused, routing_map_fused = fused_fwd_fn(logits) + + assert jnp.allclose( + probs_ref, probs_fused, atol=1e-5, rtol=1e-5 + ), f"Probs mismatch: max diff = {jnp.abs(probs_ref - probs_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward: reference (jitted) + def loss_ref(logits_): + p, _ = reference_topk_softmax_sigmoid( + logits_, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + return p.sum() + + def loss_fused(logits_): + p, _ = fused_topk_with_score_function( + logits_, + topk, + use_pre_softmax, + num_groups if num_groups else -1, + group_topk if group_topk else -1, + scaling_factor if scaling_factor else 1.0, + score_function, + expert_bias, + ) + return p.sum() + + grad_ref = 
jax.jit(jax.grad(loss_ref))(logits) + grad_fused = jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, +) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +@pytest_parametrize_wrapper("enable_bias", ENABLE_BIAS_OPTIONS) +def test_topk_sigmoid( + dtype, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sigmoid", + enable_bias=enable_bias, + ) + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, +) +@pytest_parametrize_wrapper("use_pre_softmax", USE_PRE_SOFTMAX_OPTIONS) +@pytest_parametrize_wrapper("group_topk", GROUP_TOPK_OPTIONS) +@pytest_parametrize_wrapper("scaling_factor", SCALING_FACTOR_OPTIONS) +def test_topk_softmax( + dtype, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor +): + num_groups = 8 if group_topk else None + run_topk_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="softmax", + enable_bias=False, + ) + + +# ============================================================================= +# Test: Fused Score for MoE Aux Loss +# ============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + SCORE_AUX_LOSS_CASES, +) +@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_compute_scores_for_aux_loss, + topk=topk, + score_function=score_function, + ) + ) + routing_map_ref, scores_ref = ref_fwd_fn(logits) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + ) + ) + scores_fused, routing_map_fused = fused_fwd_fn(logits) + + assert jnp.allclose( + scores_ref, scores_fused, atol=1e-5, rtol=1e-5 + ), f"Scores mismatch: max diff = {jnp.abs(scores_ref - scores_fused).max()}" + assert jnp.array_equal(routing_map_ref, routing_map_fused), "Routing map mismatch" + + # Backward (jitted) + def loss_ref(logits_): + _, s = reference_compute_scores_for_aux_loss(logits_, topk, score_function) + return s.sum() + + def loss_fused(logits_): + s, _ = fused_topk_with_score_function( + logits_, + topk, + score_function=score_function, + compute_aux_scores=True, + ) + return s.sum() + + grad_ref = jax.jit(jax.grad(loss_ref))(logits) + grad_fused = jax.jit(jax.grad(loss_fused))(logits) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - 
grad_fused).max()}" + + +# ============================================================================= +# Test: Fused MoE Aux Loss +# ============================================================================= + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + AUX_LOSS_CASES, +) +def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): + key = jax.random.PRNGKey(SEED) + + offset = jnp.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype) * 1e-4 + probs = jnp.arange(-num_experts // 2, num_experts // 2, dtype=dtype) * 1e-2 + probs = probs[None, :].repeat(num_tokens, axis=0) + offset[:, None] + probs = probs.reshape(num_tokens, num_experts) + + tokens_per_expert = jax.random.randint(key, (num_experts,), 1, 1000).astype(jnp.int32) + coeff = 0.01 + + # Forward: reference (jitted) + ref_fwd_fn = jax.jit( + partial( + reference_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + ) + aux_loss_ref = ref_fwd_fn(probs) + + # Forward: fused (jitted) + fused_fwd_fn = jax.jit( + partial( + fused_moe_aux_loss, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + ) + aux_loss_fused = fused_fwd_fn(probs) + + assert jnp.allclose( + aux_loss_ref, aux_loss_fused, atol=1e-5, rtol=1e-5 + ), f"Aux loss mismatch: ref={aux_loss_ref}, fused={aux_loss_fused}" + + # Backward (jitted) + def loss_ref_fn(probs_): + return reference_aux_loss(probs_, tokens_per_expert, num_tokens, topk, num_experts, coeff) + + def loss_fused_fn(probs_): + return fused_moe_aux_loss(probs_, tokens_per_expert, num_tokens, num_experts, topk, coeff) + + grad_ref = jax.jit(jax.grad(loss_ref_fn))(probs) + grad_fused = jax.jit(jax.grad(loss_fused_fn))(probs) + assert jnp.allclose( + grad_ref, grad_fused, atol=1e-5, rtol=1e-5 + ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 6a2f9b7378..d203fcea9d 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -9,3 +9,4 @@ from .quantization import * from .softmax import * from .gemm import * +from .router import * diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py new file mode 100644 index 0000000000..ad25802841 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -0,0 +1,752 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""JAX/TE custom ops for fused MoE router""" +from enum import IntEnum + +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec + +from .base import BasePrimitive, register_primitive +from .misc import get_padded_spec + +__all__ = [ + "ScoreFunction", + "fused_topk_with_score_function_fwd", + "fused_topk_with_score_function_bwd", + "fused_moe_aux_loss_fwd", + "fused_moe_aux_loss_bwd", +] + + +class ScoreFunction(IntEnum): + """Score function enum for fused MoE router kernels.""" + + SIGMOID = 0 + SOFTMAX = 1 + + +# =========================================== ================================== +# Fused Top-K with Score Function - Forward +# ============================================================================= + + +class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Forward Primitive. + Computes score_function(logits) -> top-k -> probs, routing_map. + When compute_aux_scores=1, instead computes clean scores for aux loss. + """ + + name = "te_fused_topk_with_score_function_forward_ffi" + multiple_results = True + impl_static_args = ( + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + logits_aval, + expert_bias_aval, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + """Abstract evaluation: describe output shapes and dtypes.""" + del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk + del scaling_factor, score_function, compute_aux_scores + i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) + i_shape = logits_aval.shape + probs_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) + intermediate_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) + return probs_aval, routing_map_aval, intermediate_aval + + @staticmethod + def lowering( + ctx, + logits, + expert_bias, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)( + ctx, + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def impl( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None + return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind( + logits, + expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ): + assert FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is not None + logits, expert_bias = batched_args + logits_bdim, _ = batch_dims + return ( + FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + 
expert_bias, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ), + (logits_bdim, logits_bdim, logits_bdim), + ) + + @staticmethod + def infer_sharding_from_operands( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del topk, use_pre_softmax, num_groups, group_topk, scaling_factor + del score_function, compute_aux_scores, result_infos + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + return [out_sharding, routing_sharding, intermediate_sharding] + + @staticmethod + def partition( + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del result_infos + logits_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + out_shardings = [out_sharding, routing_sharding, intermediate_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(logits, expert_bias): + return FusedTopkWithScoreFunctionFwdPrimitive.impl( + logits, + expert_bias, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + compute_aux_scores, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return ( + "num_tokens num_experts, num_experts -> num_tokens num_experts, num_tokens num_experts," + " num_tokens num_experts" + ) + + +register_primitive(FusedTopkWithScoreFunctionFwdPrimitive) + + +# ============================================================================= +# Fused Top-K with Score Function - Backward +# ============================================================================= + + +class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): + """ + Fused Top-K with Score Function Backward Primitive. + When compute_aux_scores=1, runs the score-for-aux-loss backward instead. 
+ """ + + name = "te_fused_topk_with_score_function_backward_ffi" + multiple_results = False + impl_static_args = ( + 3, + 4, + 5, + 6, + 7, + ) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + routing_map_aval, + intermediate_aval, + grad_probs_aval, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + del topk, use_pre_softmax, scaling_factor, score_function + del compute_aux_scores, routing_map_aval + return intermediate_aval.update( + shape=intermediate_aval.shape, + dtype=dtypes.canonicalize_dtype(grad_probs_aval.dtype), + ) + + @staticmethod + def lowering( + ctx, + routing_map, + intermediate, + grad_probs, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)( + ctx, + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def impl( + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + assert FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is not None + return FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ): + assert FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is not None + routing_map, intermediate, grad_probs = batched_args + _, _, grad_probs_bdim = batch_dims + return ( + FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + intermediate, + grad_probs, + topk=topk, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=compute_aux_scores, + ), + grad_probs_bdim, + ) + + @staticmethod + def infer_sharding_from_operands( + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del topk, use_pre_softmax, scaling_factor, score_function + del compute_aux_scores, result_infos + grad_spec = get_padded_spec(arg_infos[2]) + return NamedSharding(mesh, PartitionSpec(*grad_spec)) + + @staticmethod + def partition( + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + mesh, + arg_infos, + result_infos, + ): + del result_infos + grad_spec = get_padded_spec(arg_infos[2]) + out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) + + def sharded_impl(routing_map, intermediate, grad_probs): + return FusedTopkWithScoreFunctionBwdPrimitive.impl( + routing_map, + intermediate, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + compute_aux_scores, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return ( + "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens" + " num_experts" + ) + + 
+register_primitive(FusedTopkWithScoreFunctionBwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Forward +# ============================================================================= + + +class FusedMoEAuxLossFwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Forward Primitive. + """ + + name = "te_fused_moe_aux_loss_forward_ffi" + multiple_results = True + impl_static_args = (2, 3, 4, 5) # total_num_tokens, num_experts, topk, coeff + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(probs_aval, tokens_per_expert_aval, total_num_tokens, num_experts, topk, coeff): + del total_num_tokens, num_experts, topk, coeff, tokens_per_expert_aval + i_dtype = dtypes.canonicalize_dtype(probs_aval.dtype) + aux_loss_aval = probs_aval.update(shape=(1,), dtype=i_dtype) + const_buf_aval = probs_aval.update(shape=(1,), dtype=jnp.float32) + return aux_loss_aval, const_buf_aval + + @staticmethod + def lowering(ctx, probs, tokens_per_expert, *, total_num_tokens, num_experts, topk, coeff): + return ffi.ffi_lowering(FusedMoEAuxLossFwdPrimitive.name)( + ctx, + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def impl(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff): + assert FusedMoEAuxLossFwdPrimitive.inner_primitive is not None + return FusedMoEAuxLossFwdPrimitive.inner_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, total_num_tokens, num_experts, topk, coeff): + assert FusedMoEAuxLossFwdPrimitive.outer_primitive is not None + probs, tokens_per_expert = batched_args + probs_bdim, _ = batch_dims + return ( + FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=total_num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ), + (probs_bdim, probs_bdim), + ) + + @staticmethod + def infer_sharding_from_operands( + total_num_tokens, + num_experts, + topk, + coeff, + mesh, + arg_infos, + result_infos, + ): + del total_num_tokens, num_experts, topk, coeff, arg_infos, result_infos + scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) + return [scalar_sharding, scalar_sharding] + + @staticmethod + def partition( + total_num_tokens, + num_experts, + topk, + coeff, + mesh, + arg_infos, + result_infos, + ): + del result_infos + scalar_sharding = NamedSharding(mesh, PartitionSpec(None)) + out_shardings = [scalar_sharding, scalar_sharding] + arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) + + def sharded_impl(probs, tokens_per_expert): + return FusedMoEAuxLossFwdPrimitive.impl( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "num_tokens num_experts, num_experts -> aux_loss_one, const_buf_one" + + +register_primitive(FusedMoEAuxLossFwdPrimitive) + + +# ============================================================================= +# Fused MoE Aux Loss - Backward +# ============================================================================= + + +class FusedMoEAuxLossBwdPrimitive(BasePrimitive): + """ + Fused MoE Aux Loss Backward Primitive. 
+ """ + + name = "te_fused_moe_aux_loss_backward_ffi" + multiple_results = False + impl_static_args = (3, 4) # num_rows, num_cols + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(const_buf_aval, tokens_per_expert_aval, grad_aux_loss_aval, num_rows, num_cols): + del const_buf_aval, tokens_per_expert_aval + out_dtype = dtypes.canonicalize_dtype(grad_aux_loss_aval.dtype) + return grad_aux_loss_aval.update( + shape=(num_rows, num_cols), + dtype=out_dtype, + ) + + @staticmethod + def lowering(ctx, const_buf, tokens_per_expert, grad_aux_loss, *, num_rows, num_cols): + return ffi.ffi_lowering(FusedMoEAuxLossBwdPrimitive.name)( + ctx, + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows=num_rows, + num_cols=num_cols, + ) + + @staticmethod + def impl(const_buf, tokens_per_expert, grad_aux_loss, num_rows, num_cols): + assert FusedMoEAuxLossBwdPrimitive.inner_primitive is not None + return FusedMoEAuxLossBwdPrimitive.inner_primitive.bind( + const_buf, tokens_per_expert, grad_aux_loss, num_rows=num_rows, num_cols=num_cols + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, num_rows, num_cols): + assert FusedMoEAuxLossBwdPrimitive.outer_primitive is not None + const_buf, tokens_per_expert, grad_aux_loss = batched_args + _, _, grad_bdim = batch_dims + return ( + FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, tokens_per_expert, grad_aux_loss, num_rows=num_rows, num_cols=num_cols + ), + grad_bdim, + ) + + @staticmethod + def infer_sharding_from_operands( + num_rows, + num_cols, + mesh, + arg_infos, + result_infos, + ): + del num_rows, num_cols, arg_infos, result_infos + return NamedSharding(mesh, PartitionSpec(None, None)) + + @staticmethod + def partition( + num_rows, + num_cols, + mesh, + arg_infos, + result_infos, + ): + del result_infos + out_sharding = NamedSharding(mesh, PartitionSpec(None, None)) + arg_shardings = ( + arg_infos[0].sharding, + arg_infos[1].sharding, + arg_infos[2].sharding, + ) + + def sharded_impl(const_buf, tokens_per_expert, grad_aux_loss): + return FusedMoEAuxLossBwdPrimitive.impl( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows, + num_cols, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + del args + return "const_buf_one, num_experts, grad_one -> num_tokens num_experts" + + +register_primitive(FusedMoEAuxLossBwdPrimitive) + + +# ============================================================================= +# Public API functions +# ============================================================================= + + +def fused_topk_with_score_function_fwd( + logits: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + scaling_factor: float, + score_function, + expert_bias: jnp.ndarray, + compute_aux_scores: bool = False, +): + """ + Fused top-k with score function forward pass. + + When compute_aux_scores=True, runs the clean score-for-aux-loss kernel + instead of the full top-k kernel (expert_bias, use_pre_softmax, num_groups, + group_topk, and scaling_factor are ignored). + + Parameters + ---------- + logits : jnp.ndarray + [num_tokens, num_experts] logits from gating GEMM. + topk : int + Number of top experts to select. + use_pre_softmax : bool + If True, apply softmax before top-k. + num_groups : int + Number of groups for grouped top-k (1 to disable). + group_topk : int + Top-k at group level (1 to disable). + scaling_factor : float + Scaling factor for output probs. 
+ score_function : ScoreFunction + ScoreFunction.SOFTMAX or ScoreFunction.SIGMOID. + expert_bias : jnp.ndarray + Expert bias (only used with sigmoid). Pass empty array if unused. + compute_aux_scores : bool + If True, compute clean scores for aux loss instead of full top-k. + + Returns + ------- + probs_or_scores, routing_map, intermediate_output + """ + return FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive.bind( + logits, + expert_bias, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + num_groups=int(num_groups), + group_topk=int(group_topk), + scaling_factor=float(scaling_factor), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), + ) + + +def fused_topk_with_score_function_bwd( + routing_map: jnp.ndarray, + intermediate_output: jnp.ndarray, + grad_probs: jnp.ndarray, + topk: int, + use_pre_softmax: bool, + scaling_factor: float, + score_function, + compute_aux_scores: bool = False, +): + """ + Fused top-k with score function backward pass. + + When compute_aux_scores=True, routing_map is ignored and the + score-for-aux-loss backward kernel is used instead. + """ + return FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( + routing_map, + intermediate_output, + grad_probs, + topk=int(topk), + use_pre_softmax=int(use_pre_softmax), + scaling_factor=float(scaling_factor), + score_function=int(score_function), + compute_aux_scores=int(compute_aux_scores), + ) + + +def fused_moe_aux_loss_fwd( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +): + """ + Fused MoE aux loss forward pass. + + Returns + ------- + aux_loss, const_buf + """ + return FusedMoEAuxLossFwdPrimitive.outer_primitive.bind( + probs, + tokens_per_expert, + total_num_tokens=int(total_num_tokens), + num_experts=int(num_experts), + topk=int(topk), + coeff=float(coeff), + ) + + +def fused_moe_aux_loss_bwd( + const_buf: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + grad_aux_loss: jnp.ndarray, + num_rows: int, + num_cols: int, +): + """ + Fused MoE aux loss backward pass. 
+ """ + return FusedMoEAuxLossBwdPrimitive.outer_primitive.bind( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows=int(num_rows), + num_cols=int(num_cols), + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1c0bc52b88..0c459fd9e7 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -152,6 +152,12 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); // CuBLAS helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); +// Router +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 71de897d9b..06c3a2348f 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -81,6 +81,13 @@ pybind11::dict Registrations() { pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + // Router + dict["te_fused_topk_with_score_function_forward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionForwardHandler); + dict["te_fused_topk_with_score_function_backward_ffi"] = + EncapsulateFFI(FusedTopkWithScoreFunctionBackwardHandler); + dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); + dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); dict["te_inspect_ffi"] = pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler)); diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp new file mode 100644 index 0000000000..ce811473c2 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -0,0 +1,255 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +enum class ScoreFunction : int64_t { + kSigmoid = 0, + kSoftmax = 1, +}; + +// Compute num_tokens as the product of all dimensions except the last (num_experts). +// Supports arbitrary-rank inputs (e.g., [B, S, E] or [num_tokens, E]). 
+template <typename Dims>
+static int compute_num_tokens(const Dims &dims) {
+  int num_tokens = 1;
+  for (size_t i = 0; i + 1 < dims.size(); ++i) {
+    num_tokens *= static_cast<int>(dims[i]);
+  }
+  return num_tokens;
+}
+
+// ============================================================================
+// Fused Top-K with Score Function - Forward
+// ============================================================================
+
+Error_Type FusedTopkWithScoreFunctionForwardFFI(
+    cudaStream_t stream,
+    Buffer_Type logits_buf,        // [num_tokens, num_experts]
+    Buffer_Type expert_bias_buf,   // [num_experts] or empty
+    Result_Type probs_buf,         // [num_tokens, num_experts] (or scores when compute_aux_scores)
+    Result_Type routing_map_buf,   // [num_tokens, num_experts]
+    Result_Type intermediate_buf,  // [num_tokens, num_experts]
+    int64_t topk, int64_t use_pre_softmax, int64_t num_groups, int64_t group_topk,
+    double scaling_factor, int64_t score_function, int64_t compute_aux_scores) {
+  auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type());
+  auto dims = logits_buf.dimensions();
+  auto num_tokens = compute_num_tokens(dims);
+  auto num_experts = static_cast<int>(dims[dims.size() - 1]);
+
+  auto *logits = logits_buf.untyped_data();
+  auto *expert_bias = expert_bias_buf.untyped_data();
+  auto *probs = probs_buf->untyped_data();
+  auto *routing_map = routing_map_buf->untyped_data();
+  auto *intermediate = intermediate_buf->untyped_data();
+
+  auto flat_shape =
+      std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};
+  auto logits_tensor = TensorWrapper(logits, flat_shape, dtype);
+  auto probs_tensor = TensorWrapper(probs, flat_shape, dtype);
+  auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte);
+  auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, dtype);
+
+  if (compute_aux_scores) {
+    nvte_fused_score_for_moe_aux_loss_forward(
+        logits_tensor.data(), num_tokens, num_experts, static_cast<int>(topk),
+        static_cast<int>(score_function), probs_tensor.data(), routing_map_tensor.data(),
+        intermediate_tensor.data(), stream);
+  } else {
+    auto bias_dims = expert_bias_buf.dimensions();
+    auto expert_bias_tensor =
+        (bias_dims.size() > 0 && bias_dims[0] > 0)
+            ? TensorWrapper(expert_bias, std::vector<size_t>{static_cast<size_t>(bias_dims[0])},
+                            convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type()))
+            : TensorWrapper();
+
+    nvte_fused_topk_with_score_function_forward(
+        logits_tensor.data(), num_tokens, num_experts, static_cast<int>(topk),
+        static_cast<int>(use_pre_softmax), static_cast<int>(num_groups),
+        static_cast<int>(group_topk), static_cast<float>(scaling_factor),
+        static_cast<int>(score_function), expert_bias_tensor.data(), probs_tensor.data(),
+        routing_map_tensor.data(), intermediate_tensor.data(), stream);
+  }
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler,
+                              FusedTopkWithScoreFunctionForwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // logits
+                                  .Arg<Buffer_Type>()      // expert_bias
+                                  .Ret<Buffer_Type>()      // probs (or scores)
+                                  .Ret<Buffer_Type>()      // routing_map
+                                  .Ret<Buffer_Type>()      // intermediate_output
+                                  .Attr<int64_t>("topk")
+                                  .Attr<int64_t>("use_pre_softmax")
+                                  .Attr<int64_t>("num_groups")
+                                  .Attr<int64_t>("group_topk")
+                                  .Attr<double>("scaling_factor")
+                                  .Attr<int64_t>("score_function")
+                                  .Attr<int64_t>("compute_aux_scores"),
+                              FFI_CudaGraph_Traits);
+
+// ============================================================================
+// Fused Top-K with Score Function - Backward
+// ============================================================================
+
+Error_Type FusedTopkWithScoreFunctionBackwardFFI(
+    cudaStream_t stream,
+    Buffer_Type routing_map_buf,   // [num_tokens, num_experts] (unused when compute_aux_scores)
+    Buffer_Type intermediate_buf,  // [num_tokens, num_experts]
+    Buffer_Type grad_probs_buf,    // [num_tokens, num_experts] (grad_scores when compute_aux_scores)
+    Result_Type grad_logits_buf,   // [num_tokens, num_experts]
+    int64_t topk, int64_t use_pre_softmax, double scaling_factor, int64_t score_function,
+    int64_t compute_aux_scores) {
+  auto dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type());
+  auto dims = intermediate_buf.dimensions();
+  auto num_tokens = compute_num_tokens(dims);
+  auto num_experts = static_cast<int>(dims[dims.size() - 1]);
+
+  auto flat_shape =
+      std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};
+
+  auto intermediate_tensor = TensorWrapper(intermediate_buf.untyped_data(), flat_shape, dtype);
+  auto grad_probs_tensor = TensorWrapper(grad_probs_buf.untyped_data(), flat_shape, dtype);
+  auto grad_logits_tensor = TensorWrapper(grad_logits_buf->untyped_data(), flat_shape, dtype);
+
+  if (compute_aux_scores) {
+    nvte_fused_score_for_moe_aux_loss_backward(intermediate_tensor.data(), grad_probs_tensor.data(),
+                                               num_tokens, num_experts, static_cast<int>(topk),
+                                               static_cast<int>(score_function),
+                                               grad_logits_tensor.data(), stream);
+  } else {
+    auto routing_map_tensor =
+        TensorWrapper(routing_map_buf.untyped_data(), flat_shape, DType::kByte);
+
+    nvte_fused_topk_with_score_function_backward(
+        routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens,
+        num_experts, static_cast<int>(topk), static_cast<int>(use_pre_softmax),
+        static_cast<float>(scaling_factor), static_cast<int>(score_function),
+        grad_logits_tensor.data(), stream);
+  }
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler,
+                              FusedTopkWithScoreFunctionBackwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // routing_map
+                                  .Arg<Buffer_Type>()      // intermediate_output
+                                  .Arg<Buffer_Type>()      // grad_probs
+                                  .Ret<Buffer_Type>()      // grad_logits
+                                  .Attr<int64_t>("topk")
+                                  .Attr<int64_t>("use_pre_softmax")
+                                  .Attr<double>("scaling_factor")
+                                  .Attr<int64_t>("score_function")
+                                  .Attr<int64_t>("compute_aux_scores"),
+                              FFI_CudaGraph_Traits);
+
+// ============================================================================
+// Fused MoE Aux Loss - Forward
+// ============================================================================
+
+Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream,
+                                     Buffer_Type probs_buf,              // [num_rows, num_cols]
+                                     Buffer_Type tokens_per_expert_buf,  // [num_experts]
+                                     Result_Type aux_loss_buf,           // scalar
+                                     Result_Type const_buf,              // scalar
+                                     int64_t total_num_tokens, int64_t num_experts, int64_t topk,
+                                     double coeff) {
+  auto dtype = convert_ffi_datatype_to_te_dtype(probs_buf.element_type());
+  auto probs_dims = probs_buf.dimensions();
+  auto num_rows = static_cast<int>(probs_dims[0]);
+  auto num_cols = static_cast<int>(probs_dims[1]);
+
+  auto probs_shape =
+      std::vector<size_t>{static_cast<size_t>(num_rows), static_cast<size_t>(num_cols)};
+  auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type());
+  auto tpe_shape = std::vector<size_t>{static_cast<size_t>(num_experts)};
+  auto scalar_shape = std::vector<size_t>{1};
+
+  auto probs_tensor = TensorWrapper(probs_buf.untyped_data(), probs_shape, dtype);
+  auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype);
+  auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype);
+  auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32);
+
+  nvte_fused_moe_aux_loss_forward(
+      probs_tensor.data(), tpe_tensor.data(), static_cast<int>(total_num_tokens),
+      static_cast<int>(num_experts), num_rows, num_cols, static_cast<int>(topk),
+      static_cast<float>(coeff), aux_loss_tensor.data(), const_buf_tensor.data(), stream);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler, FusedMoEAuxLossForwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // probs
+                                  .Arg<Buffer_Type>()      // tokens_per_expert
+                                  .Ret<Buffer_Type>()      // aux_loss
+                                  .Ret<Buffer_Type>()      // const_buf
+                                  .Attr<int64_t>("total_num_tokens")
+                                  .Attr<int64_t>("num_experts")
+                                  .Attr<int64_t>("topk")
+                                  .Attr<double>("coeff"),
+                              FFI_CudaGraph_Traits);
+
+// ============================================================================
+// Fused MoE Aux Loss - Backward
+// ============================================================================
+
+Error_Type FusedMoEAuxLossBackwardFFI(cudaStream_t stream,
+                                      Buffer_Type const_buf_in,           // scalar float32
+                                      Buffer_Type tokens_per_expert_buf,  // [num_experts]
+                                      Buffer_Type grad_aux_loss_buf,      // scalar
+                                      Result_Type grad_probs_buf,         // [num_rows, num_cols]
+                                      int64_t num_rows, int64_t num_cols) {
+  auto grad_dtype = convert_ffi_datatype_to_te_dtype(grad_aux_loss_buf.element_type());
+  auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type());
+
+  auto scalar_shape = std::vector<size_t>{1};
+  auto tpe_dims = tokens_per_expert_buf.dimensions();
+  auto tpe_shape = std::vector<size_t>{static_cast<size_t>(tpe_dims[0])};
+  auto grad_probs_shape =
+      std::vector<size_t>{static_cast<size_t>(num_rows), static_cast<size_t>(num_cols)};
+
+  auto const_buf_tensor = TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32);
+  auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype);
+  auto grad_aux_loss_tensor =
+      TensorWrapper(grad_aux_loss_buf.untyped_data(), scalar_shape, grad_dtype);
+  auto grad_probs_tensor =
+      TensorWrapper(grad_probs_buf->untyped_data(), grad_probs_shape, grad_dtype);
+
+  nvte_fused_moe_aux_loss_backward(const_buf_tensor.data(), tpe_tensor.data(),
+                                   static_cast<int>(num_rows), static_cast<int>(num_cols),
+                                   grad_aux_loss_tensor.data(), grad_probs_tensor.data(), stream);
+
+  return ffi_with_cuda_error_check();
+}
+
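+// Gradient structure, following the JAX-side reference implementation: with
+//   aux_loss = sum_e(agg_probs[e] * tokens_per_expert[e]) * C, where
+//   C = num_experts * coeff / (topk * total_num_tokens^2),
+// every row of grad_probs is grad_aux_loss * C * tokens_per_expert[e], so the
+// backward pass never needs probs itself; const_buf presumably carries the
+// folded scalar written by the forward kernel.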
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler, FusedMoEAuxLossBackwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // const_buf
+                                  .Arg<Buffer_Type>()      // tokens_per_expert
+                                  .Arg<Buffer_Type>()      // grad_aux_loss
+                                  .Ret<Buffer_Type>()      // grad_probs
+                                  .Attr<int64_t>("num_rows")
+                                  .Attr<int64_t>("num_cols"),
+                              FFI_CudaGraph_Traits);
+
+}  // namespace jax
+}  // namespace transformer_engine
diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py
new file mode 100644
index 0000000000..c95287a5fb
--- /dev/null
+++ b/transformer_engine/jax/router.py
@@ -0,0 +1,354 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Fused MoE Router API for JAX.
+
+This module provides high-level fused router operations for Mixture of Experts (MoE)
+models with proper automatic differentiation support. These wrap the CUDA kernels in
+transformer_engine/common/fused_router/.
+
+Functions:
+    fused_topk_with_score_function:
+        Fused score_function + top-k selection. Supports softmax/sigmoid,
+        grouped top-k, expert bias, and scaling factor. When compute_aux_scores=True,
+        switches to the clean score-for-aux-loss kernel (no bias/groups/scaling,
+        dense output).
+
+    fused_moe_aux_loss:
+        Compute the MoE auxiliary load-balancing loss scalar.
+"""
+
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+
+from transformer_engine.jax.cpp_extensions.router import (
+    ScoreFunction,
+    fused_topk_with_score_function_fwd,
+    fused_topk_with_score_function_bwd,
+    fused_moe_aux_loss_fwd,
+    fused_moe_aux_loss_bwd,
+)
+
+__all__ = [
+    "ScoreFunction",
+    "fused_topk_with_score_function",
+    "fused_moe_aux_loss",
+]
+
+
+def _validate_score_function(score_function: Union[str, ScoreFunction]) -> ScoreFunction:
+    """Validate and convert score_function to a ScoreFunction enum."""
+    if isinstance(score_function, ScoreFunction):
+        return score_function
+    try:
+        return ScoreFunction[score_function.upper()]
+    except (KeyError, AttributeError):
+        raise ValueError(
+            "score_function must be 'softmax', 'sigmoid', or a ScoreFunction enum, "
+            f"got {score_function!r}"
+        ) from None
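+
+# Illustrative sketch (editorial comment, not executed): both the string and
+# the enum spelling resolve to the same member, while unknown names raise
+# ValueError. Assumes ScoreFunction exposes SOFTMAX/SIGMOID, as imported above.
+#
+#     _validate_score_function("sigmoid")              # -> ScoreFunction.SIGMOID
+#     _validate_score_function(ScoreFunction.SIGMOID)  # -> ScoreFunction.SIGMOID
+#     _validate_score_function("relu")                 # -> raises ValueError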
+
+
+# =============================================================================
+# Fused Top-K with Score Function
+# =============================================================================
+
+
+def fused_topk_with_score_function(
+    logits: jnp.ndarray,
+    topk: int,
+    use_pre_softmax: bool = False,
+    num_groups: int = 1,
+    group_topk: int = 1,
+    scaling_factor: float = 1.0,
+    score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX,
+    expert_bias: Optional[jnp.ndarray] = None,
+    compute_aux_scores: bool = False,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    """
+    Fused top-k with score function router.
+
+    When compute_aux_scores=False (default), runs the main routing kernel:
+    score_function(logits) -> [optional bias] -> top-k -> [optional post-softmax] -> scale.
+    Returns sparse probs (only top-k positions nonzero) and routing_map.
+
+    When compute_aux_scores=True, runs the score-for-aux-loss kernel instead:
+    score_function(logits) -> top-k (clean, no bias/groups/scaling).
+    Returns dense scores (all expert positions) and routing_map.
+    The expert_bias, use_pre_softmax, num_groups, group_topk, and scaling_factor
+    parameters are ignored in this mode.
+
+    Parameters
+    ----------
+    logits : jnp.ndarray
+        Logits from the gating GEMM, shape [num_tokens, num_experts].
+    topk : int
+        Number of top experts to select per token.
+    use_pre_softmax : bool
+        If True, apply softmax before top-k (softmax score function only);
+        otherwise softmax is applied after top-k.
+        Ignored when compute_aux_scores=True.
+    num_groups : int
+        Number of groups for grouped top-k. 1 means no grouping.
+        Ignored when compute_aux_scores=True.
+    group_topk : int
+        Top-k at group level. 1 means no group-level selection.
+        Ignored when compute_aux_scores=True.
+    scaling_factor : float
+        Scaling factor applied to output probs.
+        Ignored when compute_aux_scores=True.
+    score_function : Union[str, ScoreFunction]
+        Score function: "softmax" / "sigmoid" or ScoreFunction.SOFTMAX / ScoreFunction.SIGMOID.
+    expert_bias : Optional[jnp.ndarray]
+        Expert bias, shape [num_experts]. Only used with sigmoid.
+        Ignored when compute_aux_scores=True.
+    compute_aux_scores : bool
+        If True, use the clean score-for-aux-loss kernel. Returns dense scores
+        over all experts instead of sparse probs.
+
+    Returns
+    -------
+    probs_or_scores : jnp.ndarray
+        When compute_aux_scores=False: sparse probability tensor, shape
+        [num_tokens, num_experts], nonzero only at selected expert positions.
+        When compute_aux_scores=True: dense score tensor, shape
+        [num_tokens, num_experts], with scores at all expert positions.
+    routing_map : jnp.ndarray
+        Boolean mask, shape [num_tokens, num_experts].
+        True at selected expert positions.
+    """
+    score_function = _validate_score_function(score_function)
+
+    if compute_aux_scores:
+        expert_bias = jnp.empty((0,), dtype=logits.dtype)
+        use_pre_softmax = False
+        num_groups = 1
+        group_topk = 1
+        scaling_factor = 1.0
+    else:
+        if expert_bias is not None and score_function != ScoreFunction.SIGMOID:
+            raise ValueError(
+                "expert_bias is only supported with score_function='sigmoid'. "
+                f"Got score_function='{score_function.name}'."
+            )
+        if expert_bias is None:
+            expert_bias = jnp.empty((0,), dtype=logits.dtype)
+
+    probs_or_scores, routing_map = _fused_topk_with_score_function(
+        logits,
+        expert_bias,
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
+        compute_aux_scores,
+    )
+
+    return probs_or_scores, routing_map
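+
+# Illustrative usage sketch (editorial comment, not executed). The shapes and
+# values are hypothetical stand-ins for the gating GEMM output; the two calls
+# show the two modes described in the docstring above.
+#
+#     import jax.numpy as jnp
+#     from transformer_engine.jax.router import fused_topk_with_score_function
+#
+#     logits = jnp.ones((8, 4), dtype=jnp.float32)  # [num_tokens, num_experts]
+#
+#     # Sparse routing probs: nonzero only at each token's top-2 experts.
+#     probs, routing_map = fused_topk_with_score_function(
+#         logits, topk=2, score_function="softmax"
+#     )
+#
+#     # Dense scores for the aux loss; bias/groups/scaling are ignored here.
+#     scores, routing_map = fused_topk_with_score_function(
+#         logits, topk=2, score_function="softmax", compute_aux_scores=True
+#     )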
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8))
+def _fused_topk_with_score_function(
+    logits: jnp.ndarray,
+    expert_bias: jnp.ndarray,
+    topk: int,
+    use_pre_softmax: bool,
+    num_groups: int,
+    group_topk: int,
+    scaling_factor: float,
+    score_function: ScoreFunction,
+    compute_aux_scores: bool,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    (probs, routing_map), _ = _fused_topk_with_score_function_fwd(
+        logits,
+        expert_bias,
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
+        compute_aux_scores,
+    )
+    return probs, routing_map
+
+
+def _fused_topk_with_score_function_fwd(
+    logits,
+    expert_bias,
+    topk,
+    use_pre_softmax,
+    num_groups,
+    group_topk,
+    scaling_factor,
+    score_function,
+    compute_aux_scores,
+):
+    probs, routing_map, intermediate_output = fused_topk_with_score_function_fwd(
+        logits,
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
+        expert_bias,
+        compute_aux_scores,
+    )
+    residuals = (routing_map, intermediate_output)
+    return (probs, routing_map), residuals
+
+
+def _fused_topk_with_score_function_bwd(
+    topk,
+    use_pre_softmax,
+    num_groups,  # pylint: disable=unused-argument
+    group_topk,  # pylint: disable=unused-argument
+    scaling_factor,
+    score_function,
+    compute_aux_scores,
+    residuals,
+    g,
+):
+    routing_map, intermediate_output = residuals
+    grad_probs, _ = g  # routing_map gradient is None (boolean output)
+
+    grad_logits = fused_topk_with_score_function_bwd(
+        routing_map,
+        intermediate_output,
+        grad_probs,
+        topk,
+        use_pre_softmax,
+        scaling_factor,
+        score_function,
+        compute_aux_scores,
+    )
+    # expert_bias gradient is None: the bias is not differentiated through this kernel
+    return grad_logits, None
+
+
+_fused_topk_with_score_function.defvjp(
+    _fused_topk_with_score_function_fwd,
+    _fused_topk_with_score_function_bwd,
+)
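+
+# Illustrative gradient sketch (editorial comment, not executed). Because the
+# op registers a custom VJP, jax.grad differentiates through the fused kernel;
+# only logits receives a gradient, since expert_bias and the boolean
+# routing_map are not differentiated (see the backward rule above).
+#
+#     import jax
+#     import jax.numpy as jnp
+#
+#     logits = jnp.ones((8, 4), dtype=jnp.float32)
+#
+#     def router_loss(x):
+#         probs, _ = fused_topk_with_score_function(x, topk=2, score_function="softmax")
+#         return jnp.sum(probs)
+#
+#     grad_logits = jax.grad(router_loss)(logits)  # same shape as logits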
+ """ + return _fused_moe_aux_loss( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5)) +def _fused_moe_aux_loss( + probs: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + total_num_tokens: int, + num_experts: int, + topk: int, + coeff: float, +) -> jnp.ndarray: + aux_loss, _ = _fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, + ) + return aux_loss + + +def _fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, +): + aux_loss, const_buf = fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + topk, + coeff, + ) + residuals = (const_buf, tokens_per_expert, probs.shape[0], probs.shape[1]) + return aux_loss.squeeze(), residuals + + +def _fused_moe_aux_loss_bwd( + total_num_tokens, # pylint: disable=unused-argument + num_experts, # pylint: disable=unused-argument + topk, # pylint: disable=unused-argument + coeff, # pylint: disable=unused-argument + residuals, + g, +): + const_buf, tokens_per_expert, num_rows, num_cols = residuals + # g is a scalar matching the squeezed output of _fwd + grad_aux_loss = g.reshape(1) + + grad_probs = fused_moe_aux_loss_bwd( + const_buf, + tokens_per_expert, + grad_aux_loss, + num_rows, + num_cols, + ) + return grad_probs, None + + +_fused_moe_aux_loss.defvjp( + _fused_moe_aux_loss_fwd, + _fused_moe_aux_loss_bwd, +)