Skip to content

Commit 3fc1e4b

Browse files
authored
[JAX] Fix for TE GEMM - Always AllGather RHS non-contracting dims with FSDP axis (NVIDIA#2075)
* fix fsdp Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 0e3e270 commit 3fc1e4b

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

  • transformer_engine/jax/cpp_extensions

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -476,21 +476,24 @@ def _parse_operand_output_specs(
476476
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
477477
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
478478

479-
# Non-batched non-contracting dims of RHS needs to be unsharded (i.e. FSDP)
480-
# Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim
481-
# rhs_specs only includes batch-dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs
479+
# Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
480+
# No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
481+
# In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`.
482482
rhs_non_cspecs = tuple(
483483
None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
484484
)
485+
485486
else:
486487
# Otherwise, require contracting dims of both operands to be unsharded
487488
lhs_cspecs = (None,) * len(lhs_cspecs)
488489
rhs_cspecs = (None,) * len(rhs_cspecs)
489490

490-
# Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim
491-
# The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as
492-
# rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be
493-
# overwrite
491+
# Non-contracting dims of RHS always needs to be gathered along the FSDP axis
492+
rhs_non_cspecs = tuple(
493+
None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
494+
)
495+
496+
# Non-contracting dims of LHS to be gathered along the SP axis.
494497
# Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
495498
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
496499
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)

0 commit comments

Comments
 (0)