@@ -476,21 +476,24 @@ def _parse_operand_output_specs(
476476 lhs_cspecs = tuple (s if s == reduce_spec else None for s in lhs_cspecs )
477477 rhs_cspecs = tuple (s if s == reduce_spec else None for s in rhs_cspecs )
478478
479- # Non-batched non- contracting dims of RHS needs to be unsharded ( i.e. FSDP)
480- # Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim
481- # rhs_specs only includes batch- dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs
479+ # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
480+ # No batch-dim check needed as ` rhs_non_cspecs` never contains batch-dim.
481+ # In ` rhs_specs`, the batch dim appears only in Wgrad GEMM under ` rhs_cspecs`.
482482 rhs_non_cspecs = tuple (
483483 None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
484484 )
485+
485486 else :
486487 # Otherwise, require contracting dims of both operands to be unsharded
487488 lhs_cspecs = (None ,) * len (lhs_cspecs )
488489 rhs_cspecs = (None ,) * len (rhs_cspecs )
489490
490- # Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim
491- # The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as
492- # rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be
493- # overwrite
491+ # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
492+ rhs_non_cspecs = tuple (
493+ None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
494+ )
495+
496+ # Non-contracting dims of LHS to be gathered along the SP axis.
494497 # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
495498 # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
496499 lhs_non_cspecs = tuple (None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs )
0 commit comments