From e0e4e2c5fc0bb574d0c8d9bb7ba1d933fc1b6bc0 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Thu, 26 Feb 2026 16:17:29 -0800
Subject: [PATCH 1/5] replace Shardy error with warning

Signed-off-by: Phuong Nguyen
---
 transformer_engine/jax/cpp_extensions/gemm.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 71f133bfc4..ba75bcd9c2 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -1164,9 +1164,16 @@ def shardy_sharding_rule(
         del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer
 
         if not collective_op.is_none:
-            raise NotImplementedError(
-                "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
-                " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
+            warnings.warn(
+                "CollectiveGEMM with Shardy propagation may produce an incorrect sharding pattern"
+                " for the output.\n To resolve this, apply a sharding constraint on the output"
+                " using one of the following options:\n"
+                " - TE `dense` vjp: set `output_axes`.\n"
+                " - TE `layernorm_mlp` vjp: set `dot_2_input_axes`.\n"
+                " - TE `transformer_engine.jax.cpp_extensions.gemm`: apply"
+                " `jax.lax.with_sharding_constraint` on the output.\n"
+                " - TE via MaxText: no action needed.",
+                UserWarning,
             )
 
         prefix = "Gemm_"

From e4b02d5d93ff63ea5beaf81516e5651eb4de11cf Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Thu, 26 Feb 2026 16:22:35 -0800
Subject: [PATCH 2/5] cleanup tests

Signed-off-by: Phuong Nguyen
---
 examples/jax/collective_gemm/common.py                  | 4 ----
 examples/jax/collective_gemm/test_gemm.py               | 5 +----
 examples/jax/collective_gemm/test_layernorm_mlp_grad.py | 2 --
 3 files changed, 1 insertion(+), 10 deletions(-)

diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py
index 0d812da057..2965896d07 100644
--- a/examples/jax/collective_gemm/common.py
+++ b/examples/jax/collective_gemm/common.py
@@ -131,10 +131,6 @@ def _initialize_distributed(args):
         )
     _distributed_initialized = True
 
-    jax.clear_caches()
-    jax.config.update(
-        "jax_use_shardy_partitioner", False
-    )  # CollectiveGEMM does not work with Shardy yet
 
     assert jax.local_device_count() == 1, (
         f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py
index d2994723bb..ea119713e3 100644
--- a/examples/jax/collective_gemm/test_gemm.py
+++ b/examples/jax/collective_gemm/test_gemm.py
@@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
 def run_gemm_tests(args, mesh=None):
     """Execute GEMM tests."""
     print(args)
-    # Collective GEMM requires Shardy partitioner to be disabled
-    jax.config.update("jax_use_shardy_partitioner", False)
 
     # Initialize distributed with provided arguments
     _initialize_distributed(args)
@@ -137,8 +135,7 @@
             bias_sharded,
             contracting_dims=((2,), (0,)),
             collective_op=collective_op,
-            # CollectiveGEMM output should have a correct sharding without applying sharding constraint
-            output_sharding=None,
+            output_sharding=output_sharding,
         )
         assert (
             ref_output.sharding == output.sharding
diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py
index 61c960a7aa..84cb011da1 100644
--- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py
+++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py
@@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
 def run_layernorm_mlp_grad_tests(args, mesh=None):
     """Execute Dense Gradient tests."""
     print(args)
-    # Collective GEMM requires Shardy partitioner to be disabled
-    jax.config.update("jax_use_shardy_partitioner", False)
 
     # Initialize distributed with provided arguments
     _initialize_distributed(args)

From cd008b56692b2161082fb5fe3f9d26a23f1b0a19 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Thu, 26 Feb 2026 16:55:44 -0800
Subject: [PATCH 3/5] added warnings to dense vjp and layernorm_mlp vjp

Signed-off-by: Phuong Nguyen
---
 transformer_engine/jax/dense.py         | 7 +++++++
 transformer_engine/jax/layernorm_mlp.py | 7 +++++++
 2 files changed, 14 insertions(+)

diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py
index 23d91f7db0..268995281c 100644
--- a/transformer_engine/jax/dense.py
+++ b/transformer_engine/jax/dense.py
@@ -94,6 +94,13 @@ def dense(
     if transpose_batch_sequence:
         warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
 
+    if collective_op_set != tex.noop_collective_op_set and not output_axes:
+        warnings.warn(
+            "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
+            " for the output. Set `output_axes` to apply the correct sharding constraint.",
+            UserWarning,
+        )
+
     if quantizer_set == noop_quantizer_set:
         input_dtype = x.dtype
         kernel = kernel.astype(input_dtype)
diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py
index a8de32830b..6d636e3462 100644
--- a/transformer_engine/jax/layernorm_mlp.py
+++ b/transformer_engine/jax/layernorm_mlp.py
@@ -275,6 +275,13 @@ def _layernorm_mlp_fwd_rule(
     assert not collective_op_set_1.forward.is_reduce_scatter
     assert not collective_op_set_2.forward.is_all_gather
 
+    if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
+        warnings.warn(
+            "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
+            " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
+            UserWarning,
+        )
+
     # x should be in shape of (batch..., hidden)
     # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
     # Kernel_2 should be in shape of (intermediate, hidden_in)

From 36ebb2dd12060f87e65c59e45c1f64a875761ddd Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 27 Feb 2026 08:51:01 -0800
Subject: [PATCH 4/5] fix lint

Signed-off-by: Phuong Nguyen
---
 transformer_engine/jax/layernorm_mlp.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py
index 6d636e3462..e79a9cb19d 100644
--- a/transformer_engine/jax/layernorm_mlp.py
+++ b/transformer_engine/jax/layernorm_mlp.py
@@ -29,6 +29,7 @@
     noop_quantizer_set,
     TensorUsage,
 )
+import warnings
 
 
 def layernorm_mlp(

From 0fae7a44039653507706242050885b588aa1fc34 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 27 Feb 2026 08:54:56 -0800
Subject: [PATCH 5/5] fix lint

Signed-off-by: Phuong Nguyen
---
 transformer_engine/jax/layernorm_mlp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py
index e79a9cb19d..c90d018aee 100644
--- a/transformer_engine/jax/layernorm_mlp.py
+++ b/transformer_engine/jax/layernorm_mlp.py
@@ -15,6 +15,7 @@
 
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
+import warnings
 
 import jax
 import jax.numpy as jnp
@@ -29,7 +30,6 @@
     noop_quantizer_set,
     TensorUsage,
 )
-import warnings
 
 
 def layernorm_mlp(