diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 0d812da057..2965896d07 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -131,10 +131,6 @@ def _initialize_distributed(args): ) _distributed_initialized = True - jax.clear_caches() - jax.config.update( - "jax_use_shardy_partitioner", False - ) # CollectiveGEMM does not work with Shardy yet assert jax.local_device_count() == 1, ( f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index d2994723bb..ea119713e3 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard def run_gemm_tests(args, mesh=None): """Execute GEMM tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) @@ -137,8 +135,7 @@ def run_gemm_tests(args, mesh=None): bias_sharded, contracting_dims=((2,), (0,)), collective_op=collective_op, - # CollectiveGEMM output should have a correct sharding without applying sharding constraint - output_sharding=None, + output_sharding=output_sharding, ) assert ( ref_output.sharding == output.sharding diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 61c960a7aa..84cb011da1 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp( def run_layernorm_mlp_grad_tests(args, mesh=None): """Execute Dense Gradient tests.""" print(args) - # Collective GEMM requires Shardy partitioner to be disabled - jax.config.update("jax_use_shardy_partitioner", False) # Initialize distributed with provided arguments _initialize_distributed(args) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a34cb030bf..fbaafdf6d8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1172,9 +1172,16 @@ def shardy_sharding_rule( del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer if not collective_op.is_none: - raise NotImplementedError( - "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off" - " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + warnings.warn( + "CollectiveGEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output.\n To resolve this, apply a sharding constraint on the output" + " using one of the following options:\n" + " - TE `dense` vjp: set `output_axes`.\n" + " - TE `layernorm_mlp` vjp: set `dot_2_input_axes`.\n" + " - TE `transformer_engine.jax.cpp_extensions.gemm`: apply" + " `jax.lax.with_sharding_constraint` on the output.\n" + " - TE via MaxText: no action needed.", + UserWarning, ) prefix = "Gemm_" diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 23d91f7db0..268995281c 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -94,6 +94,13 @@ def dense( if transpose_batch_sequence: warnings.warn("transpose_batch_sequence is not well tested, use with caution!") + if collective_op_set != tex.noop_collective_op_set and not output_axes: + warnings.warn( + "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output. Set `output_axes` to apply the correct sharding constraint.", + UserWarning, + ) + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index a8de32830b..c90d018aee 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -15,6 +15,7 @@ from typing import List, Tuple, Sequence, Union, Callable from functools import partial +import warnings import jax import jax.numpy as jnp @@ -275,6 +276,13 @@ def _layernorm_mlp_fwd_rule( assert not collective_op_set_1.forward.is_reduce_scatter assert not collective_op_set_2.forward.is_all_gather + if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes: + warnings.warn( + "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern" + " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.", + UserWarning, + ) + # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) # Kernel_2 should be in shape of (intermediate, hidden_in)