-
Notifications
You must be signed in to change notification settings - Fork 652
[JAX] CGEMM with Shardy #2714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] CGEMM with Shardy #2714
Changes from all commits
e0e4e2c
e4b02d5
cd008b5
fb23ba0
36ebb2d
0fae7a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from typing import List, Tuple, Sequence, Union, Callable | ||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import jax | ||||||||||||||||||||||||||
| import jax.numpy as jnp | ||||||||||||||||||||||||||
|
|
@@ -275,6 +276,13 @@ def _layernorm_mlp_fwd_rule( | |||||||||||||||||||||||||
| assert not collective_op_set_1.forward.is_reduce_scatter | ||||||||||||||||||||||||||
| assert not collective_op_set_2.forward.is_all_gather | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if collective_op_set_1 != tex.noop_collective_op_set and not dot_2_input_axes: | ||||||||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||||||||
| "Collective GEMM with Shardy propagation may produce an incorrect sharding pattern" | ||||||||||||||||||||||||||
| " for the output. Set `dot_2_input_axes` to apply the correct sharding constraint.", | ||||||||||||||||||||||||||
| UserWarning, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
Comment on lines
+279
to
+284
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Warning condition misses The guard only checks The condition should be broadened to cover either collective op set being active:
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to check for the |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # x should be in shape of (batch..., hidden) | ||||||||||||||||||||||||||
| # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) | ||||||||||||||||||||||||||
| # Kernel_2 should be in shape of (intermediate, hidden_in) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Warning fires even when Shardy is not active
This check runs unconditionally at Python-call time, so users who explicitly disable Shardy (e.g. via
JAX_USE_SHARDY_PARTITIONER=false) or who rely on GSPMD propagation will also see thisUserWarningwhenever they use collective GEMM without settingoutput_axes. The warning text specifically says "Shardy propagation", which makes it confusing in a non-Shardy context.Consider gating the check on whether Shardy is active to keep the signal actionable:
The same pattern applies to the analogous guard added in
layernorm_mlp.py.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shardy propagation is the default in JAX, and GSPMD will be deprecated this month. So this check is not needed.