Skip to content

Commit 28ca15c

Browse files
phu0ngng and claude committed
[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests
GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner. This commit removes all GSPMD-related code paths and tests:

- Drop the infer_sharding_from_operands abstract method from BasePrimitive and remove it from def_partition() registration
- Remove all infer_sharding_from_operands implementations across cpp_extensions: activation, amax, attention, gemm, normalization, quantization, and softmax primitives
- Remove stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
- Drop all use_shardy=False (GSPMD) distributed test paths and the jax.config.update("jax_use_shardy_partitioner", ...) config calls
- Consolidate paired GSPMD/Shardy test functions into single tests and strip _shardy suffixes from test names

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 000b1b1 commit 28ca15c

12 files changed

Lines changed: 0 additions & 827 deletions

tests/jax/test_distributed_fused_attn.py

Lines changed: 0 additions & 134 deletions
Original file line number | Diff line number | Diff line change
@@ -68,9 +68,7 @@ def impl_test_self_attn(
6868
attn_mask_type,
6969
dtype,
7070
softmax_type,
71-
use_shardy,
7271
):
73-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
7472
dropout_prob = 0.0
7573
is_training = True
7674
batch, seqlen, num_head, hidden = data_shape
@@ -178,48 +176,6 @@ def test_self_attn(
178176
attn_mask_type,
179177
dtype,
180178
softmax_type,
181-
use_shardy=False,
182-
)
183-
184-
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
185-
@pytest.mark.parametrize(
186-
"attn_bias_type, bias_shape",
187-
[
188-
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
189-
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
190-
],
191-
)
192-
@pytest.mark.parametrize(
193-
"softmax_type",
194-
[
195-
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
196-
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
197-
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
198-
],
199-
)
200-
def test_self_attn_shardy(
201-
self,
202-
device_count,
203-
mesh_shape,
204-
mesh_axes,
205-
mesh_resource,
206-
attn_bias_type,
207-
bias_shape,
208-
softmax_type,
209-
):
210-
data_shape = (32, 512, 12, 64)
211-
self.impl_test_self_attn(
212-
device_count,
213-
mesh_shape,
214-
mesh_axes,
215-
mesh_resource,
216-
data_shape,
217-
attn_bias_type,
218-
bias_shape,
219-
AttnMaskType.PADDING_MASK,
220-
jnp.bfloat16,
221-
softmax_type,
222-
use_shardy=True,
223179
)
224180

225181

@@ -348,7 +304,6 @@ def impl_test_context_parallel_attn(
348304
qkv_layout,
349305
load_balanced,
350306
cp_strategy,
351-
use_shardy,
352307
use_scan_ring=False,
353308
window_size=None,
354309
stripe_size=None,
@@ -366,8 +321,6 @@ def impl_test_context_parallel_attn(
366321
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
367322
else:
368323
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
369-
370-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
371324
attn_bias_type = AttnBiasType.NO_BIAS
372325
bias_shape = None
373326
dropout_prob = 0.0
@@ -452,49 +405,6 @@ def check_has_backend_for_mask(mask_type):
452405
runner.test_backward()
453406
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
454407

455-
@pytest_parametrize_wrapper(
456-
"device_count,mesh_shape,mesh_axes,mesh_resource",
457-
generate_context_parallel_configs_for_attn(),
458-
)
459-
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
460-
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
461-
@pytest.mark.parametrize(
462-
"qkv_layout, attn_mask_type",
463-
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
464-
)
465-
def test_context_parallel_allgather_attn_shardy(
466-
self,
467-
device_count,
468-
mesh_shape,
469-
mesh_axes,
470-
mesh_resource,
471-
data_shape,
472-
attn_mask_type,
473-
dtype,
474-
qkv_layout,
475-
):
476-
if qkv_layout.is_thd():
477-
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
478-
kv_groups = 8
479-
self.impl_test_context_parallel_attn(
480-
device_count,
481-
mesh_shape,
482-
mesh_axes,
483-
mesh_resource,
484-
data_shape,
485-
kv_groups,
486-
attn_mask_type,
487-
dtype,
488-
qkv_layout,
489-
load_balanced=True,
490-
cp_strategy=CPStrategy.ALL_GATHER,
491-
use_shardy=True,
492-
)
493-
494-
@pytest_parametrize_wrapper(
495-
"device_count,mesh_shape,mesh_axes,mesh_resource",
496-
generate_context_parallel_configs_for_attn(),
497-
)
498408
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
499409
@pytest.mark.parametrize("kv_groups", [1, 8])
500410
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@@ -551,7 +461,6 @@ def test_context_parallel_allgather_striped_attn(
551461
qkv_layout,
552462
load_balanced,
553463
CPStrategy.ALL_GATHER,
554-
use_shardy=False,
555464
window_size=window_size,
556465
stripe_size=stripe_size,
557466
num_segments_per_seq=num_segments_per_seq,
@@ -599,7 +508,6 @@ def test_context_parallel_allgather_attn(
599508
qkv_layout,
600509
load_balanced,
601510
CPStrategy.ALL_GATHER,
602-
use_shardy=False,
603511
)
604512

605513
@pytest_parametrize_wrapper(
@@ -664,53 +572,11 @@ def test_context_parallel_ring_attn(
664572
qkv_layout,
665573
load_balanced,
666574
CPStrategy.RING,
667-
use_shardy=False,
668575
use_scan_ring=use_scan,
669576
window_size=window_size,
670577
stripe_size=stripe_size,
671578
)
672579

673-
@pytest_parametrize_wrapper(
674-
"device_count,mesh_shape,mesh_axes,mesh_resource",
675-
generate_context_parallel_configs_for_attn(),
676-
)
677-
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
678-
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
679-
@pytest.mark.parametrize(
680-
"qkv_layout, attn_mask_type",
681-
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
682-
)
683-
def test_context_parallel_ring_attn_shardy(
684-
self,
685-
device_count,
686-
mesh_shape,
687-
mesh_axes,
688-
mesh_resource,
689-
data_shape,
690-
attn_mask_type,
691-
dtype,
692-
qkv_layout,
693-
):
694-
kv_groups = 8
695-
# Set the stripe size to 1 (ring attention only support stripe_size=1)
696-
stripe_size = 1 if qkv_layout.is_thd() else None
697-
self.impl_test_context_parallel_attn(
698-
device_count,
699-
mesh_shape,
700-
mesh_axes,
701-
mesh_resource,
702-
data_shape,
703-
kv_groups,
704-
attn_mask_type,
705-
dtype,
706-
qkv_layout,
707-
load_balanced=True,
708-
cp_strategy=CPStrategy.RING,
709-
use_shardy=False,
710-
use_scan_ring=True,
711-
stripe_size=stripe_size,
712-
)
713-
714580

715581
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
716582
"L0": [[]],

tests/jax/test_distributed_layernorm.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,6 @@ def generate_collectives_count_ref(
8787
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
8888
@pytest_parametrize_wrapper("shard_weights", [False, True])
8989
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
90-
@pytest_parametrize_wrapper("use_shardy", [False, True])
9190
def test_layernorm(
9291
self,
9392
device_count,
@@ -99,9 +98,7 @@ def test_layernorm(
9998
zero_centered_gamma,
10099
shard_weights,
101100
fp8_recipe,
102-
use_shardy,
103101
):
104-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
105102
epsilon = 1e-6
106103
ln_type = "layernorm"
107104
q_dtype = jnp.float8_e4m3fn
@@ -178,7 +175,6 @@ def ref_func(x, gamma, beta):
178175
@pytest_parametrize_wrapper("dtype", DTYPES)
179176
@pytest_parametrize_wrapper("shard_weights", [False, True])
180177
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
181-
@pytest_parametrize_wrapper("use_shardy", [False, True])
182178
def test_rmsnorm(
183179
self,
184180
device_count,
@@ -189,9 +185,7 @@ def test_rmsnorm(
189185
dtype,
190186
shard_weights,
191187
fp8_recipe,
192-
use_shardy,
193188
):
194-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
195189
epsilon = 1e-6
196190
ln_type = "rmsnorm"
197191
q_dtype = jnp.float8_e4m3fn

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 0 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -192,10 +192,8 @@ def _test_layernorm_mlp_grad(
192192
input_shape,
193193
dtype,
194194
quantization_recipe,
195-
use_shardy,
196195
with_jax_gemm,
197196
):
198-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
199197
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
200198
layernorm_type = "rmsnorm"
201199

@@ -313,36 +311,6 @@ def test_layernorm_mlp_grad(
313311
dtype,
314312
quantization_recipe,
315313
with_jax_gemm,
316-
):
317-
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
318-
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
319-
self._test_layernorm_mlp_grad(
320-
mesh_config,
321-
activation_type,
322-
use_bias,
323-
input_shape,
324-
dtype,
325-
quantization_recipe,
326-
use_shardy=False,
327-
with_jax_gemm=with_jax_gemm,
328-
)
329-
330-
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
331-
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
332-
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
333-
@pytest_parametrize_wrapper("dtype", DTYPES)
334-
@pytest_parametrize_wrapper("use_bias", [True, False])
335-
@pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
336-
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
337-
def test_layernorm_mlp_grad_shardy(
338-
self,
339-
mesh_config,
340-
activation_type,
341-
use_bias,
342-
input_shape,
343-
dtype,
344-
quantization_recipe,
345-
with_jax_gemm,
346314
):
347315
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
348316
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
@@ -353,7 +321,6 @@ def test_layernorm_mlp_grad_shardy(
353321
input_shape,
354322
dtype,
355323
quantization_recipe=quantization_recipe,
356-
use_shardy=True,
357324
with_jax_gemm=with_jax_gemm,
358325
)
359326

@@ -366,10 +333,8 @@ def _test_layernorm_mlp(
366333
dtype,
367334
use_fp8,
368335
quantization_recipe,
369-
use_shardy,
370336
with_jax_gemm,
371337
):
372-
jax.config.update("jax_use_shardy_partitioner", use_shardy)
373338
batch, seqlen, hidden_in = input_shape
374339
layernorm_type = "rmsnorm"
375340

@@ -481,7 +446,6 @@ def test_layernorm_mlp_layer(
481446
dtype,
482447
use_fp8=False,
483448
quantization_recipe=None,
484-
use_shardy=False,
485449
with_jax_gemm=with_jax_gemm,
486450
)
487451

@@ -512,58 +476,5 @@ def test_layernorm_mlp_layer_fp8(
512476
dtype,
513477
use_fp8=True,
514478
quantization_recipe=quantization_recipe,
515-
use_shardy=False,
516-
with_jax_gemm=with_jax_gemm,
517-
)
518-
519-
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
520-
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
521-
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
522-
@pytest_parametrize_wrapper("dtype", DTYPES)
523-
@pytest_parametrize_wrapper("use_bias", [True, False])
524-
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
525-
def test_layernorm_mlp_layer_shardy(
526-
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
527-
):
528-
self._test_layernorm_mlp(
529-
mesh_config,
530-
activation_type,
531-
use_bias,
532-
input_shape,
533-
dtype,
534-
use_fp8=False,
535-
quantization_recipe=None,
536-
use_shardy=True,
537-
with_jax_gemm=with_jax_gemm,
538-
)
539-
540-
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
541-
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
542-
@pytest_parametrize_wrapper("use_bias", [True, False])
543-
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
544-
@pytest_parametrize_wrapper("dtype", DTYPES)
545-
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
546-
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
547-
def test_layernorm_mlp_layer_fp8_shardy(
548-
self,
549-
mesh_config,
550-
activation_type,
551-
use_bias,
552-
input_shape,
553-
dtype,
554-
quantization_recipe,
555-
with_jax_gemm,
556-
):
557-
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
558-
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
559-
self._test_layernorm_mlp(
560-
mesh_config,
561-
activation_type,
562-
use_bias,
563-
input_shape,
564-
dtype,
565-
use_fp8=True,
566-
quantization_recipe=quantization_recipe,
567-
use_shardy=True,
568479
with_jax_gemm=with_jax_gemm,
569480
)

0 commit comments

Comments
 (0)