Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def _initialize_distributed(args):
)

_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet

assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
Expand Down
5 changes: 1 addition & 4 deletions examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down Expand Up @@ -137,8 +135,7 @@ def run_gemm_tests(args, mesh=None):
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should have a correct sharding without applying sharding constraint
output_sharding=None,
output_sharding=output_sharding,
)
assert (
ref_output.sharding == output.sharding
Expand Down
2 changes: 0 additions & 2 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down
13 changes: 10 additions & 3 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,9 +1172,16 @@ def shardy_sharding_rule(
del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer

if not collective_op.is_none:
raise NotImplementedError(
"CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
" Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
warnings.warn(
"CollectiveGEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output.\n To resolve this, apply a sharding constraint on the output"
" using one of the following options:\n"
" - TE `dense` vjp: set `output_axes`.\n"
" - TE `layernorm_mlp` vjp: set `dot_2_input_axes`.\n"
" - TE `transformer_engine.jax.cpp_extensions.gemm`: apply"
" `jax.lax.with_sharding_constraint` on the output.\n"
" - TE via MaxText: no action needed.",
UserWarning,
)

prefix = "Gemm_"
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/jax/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ def dense(
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")

if collective_op_set != tex.noop_collective_op_set and not output_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `output_axes` to apply the correct sharding constraint.",
UserWarning,
)
Comment on lines +97 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning fires even when Shardy is not active

This check runs unconditionally at Python-call time, so users who explicitly disable Shardy (e.g. via JAX_USE_SHARDY_PARTITIONER=false) or who rely on GSPMD propagation will also see this UserWarning whenever they use collective GEMM without setting output_axes. The warning text specifically says "Shardy propagation", which makes it confusing in a non-Shardy context.

Consider gating the check on whether Shardy is active to keep the signal actionable:

if (
    collective_op_set != tex.noop_collective_op_set
    and not output_axes
    and jax.config.jax_use_shardy_partitioner
):
    warnings.warn(...)

The same pattern applies to the analogous guard added in layernorm_mlp.py.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shardy propagation is the default in JAX, and GSPMD will be deprecated this month. So this check is not needed.


if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/jax/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
import warnings

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -275,6 +276,13 @@ def _layernorm_mlp_fwd_rule(
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather

if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
UserWarning,
)
Comment on lines +279 to +284
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning condition misses collective_op_set_2

The guard only checks collective_op_set_1, so users who configure a collective op only on the second GEMM (e.g. a reduce-scatter-only row-parallel layer, where collective_op_set_1 remains noop_collective_op_set) will never see this warning, even though dot_2_input_axes is equally needed to constrain the input to dot_2 (lines 341–345 and 363) and the gradient in _layernorm_mlp_bwd_rule (line 504).

The condition should be broadened to cover either collective op set being active:

Suggested change
if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
UserWarning,
)
if (collective_op_set_1 != tex.noop_collective_op_set or collective_op_set_2 != tex.noop_collective_op_set) and not dot_2_input_axes:
warnings.warn(
"Collective GEMM with Shardy propagation may produce an incorrect sharding pattern"
" for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.",
UserWarning,
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to check for the `collective_op_set_2` here as the `dot_2` output sharding (for RS) will work correctly without an additional sharding constraint.


# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
Expand Down
Loading