Skip to content
4 changes: 0 additions & 4 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def _initialize_distributed(args):
)

_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet

assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
Expand Down
5 changes: 1 addition & 4 deletions examples/jax/collective_gemm/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_shard
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down Expand Up @@ -137,8 +135,7 @@ def run_gemm_tests(args, mesh=None):
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should have a correct sharding without applying sharding constraint
output_sharding=None,
output_sharding=output_sharding,
)
assert (
ref_output.sharding == output.sharding
Expand Down
2 changes: 0 additions & 2 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def _value_and_grad_layernorm_mlp(
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)

# Initialize distributed with provided arguments
_initialize_distributed(args)
Expand Down
4 changes: 0 additions & 4 deletions examples/jax/encoder/run_test_multiprocessing_encoder.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ TEST_CASES=(
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_nvfp4"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
"test_te_nvfp4_shardy"
)

: ${TE_PATH:=/opt/transformerengine}
Expand Down
68 changes: 0 additions & 68 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

Expand Down Expand Up @@ -474,9 +473,6 @@ def encoder_parser(args):
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand Down Expand Up @@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Run the BF16 encoder with ``enable_shardy`` set and check the result bounds."""
    self.args.enable_shardy = True
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.36
    assert metrics[1] > 0.84

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
    """Run the DelayedScaling FP8 recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.362
    assert metrics[1] > 0.84

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
    """Run DelayedScaling FP8 with ``enable_sp`` and ``enable_shardy`` set; check the bounds."""
    self.args.enable_shardy = True
    self.args.enable_sp = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.362
    assert metrics[1] > 0.84

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_shardy(self):
    """Run the MXFP8BlockScaling recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "MXFP8BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.36
    assert metrics[1] > 0.84

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
    """Run the NVFP4BlockScaling recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "NVFP4BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.40
    assert metrics[1] > 0.82

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp_shardy(self):
    """Run MXFP8BlockScaling with ``enable_sp`` and ``enable_shardy`` set; check the bounds."""
    self.args.enable_shardy = True
    self.args.enable_sp = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "MXFP8BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.36
    assert metrics[1] > 0.84

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp_shardy(self):
    """Run NVFP4BlockScaling with ``enable_sp`` and ``enable_shardy`` set; check the bounds."""
    self.args.enable_shardy = True
    self.args.enable_sp = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "NVFP4BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.40
    assert metrics[1] > 0.82


# Script entry point: build the argument namespace with encoder_parser(None) and run the
# training/evaluation loop once. Only executes when the module is run directly.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
47 changes: 0 additions & 47 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

num_gpu = jax.local_device_count()
Expand Down Expand Up @@ -438,9 +437,6 @@ def encoder_parser(args):
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand Down Expand Up @@ -494,49 +490,6 @@ def test_te_nvfp4(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Run the BF16 encoder with ``enable_shardy`` set and check the result bounds."""
    self.args.enable_shardy = True
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.51
    assert metrics[1] > 0.75

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
    """Run the DelayedScaling FP8 recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "DelayedScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.51
    assert metrics[1] > 0.75

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
    """Run the Float8CurrentScaling recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "Float8CurrentScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.51
    assert metrics[1] > 0.749

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_shardy(self):
    """Run the MXFP8BlockScaling recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "MXFP8BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.51
    assert metrics[1] > 0.75

@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
    """Run the NVFP4BlockScaling recipe with ``enable_shardy`` set and check the bounds."""
    self.args.enable_shardy = True
    self.args.use_fp8 = True
    self.args.fp8_recipe = "NVFP4BlockScaling"
    metrics = train_and_evaluate(self.args)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.52
    assert metrics[1] > 0.74


# Script entry point: build the argument namespace with encoder_parser(None) and run the
# training/evaluation loop once. Only executes when the module is run directly.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
45 changes: 1 addition & 44 deletions examples/jax/encoder/test_multiprocessing_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
if args.process_id == 0:
nltk.download("punkt_tab")

Expand Down Expand Up @@ -605,9 +604,6 @@ def encoder_parser(args):
default=0,
help="the ID number of the current process (default: 0)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)

return parser.parse_args(args)

Expand All @@ -616,7 +612,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
def exec(self, use_fp8, fp8_recipe):
"""Run 5 epochs for testing"""
args = encoder_parser(["--epochs", "5"])

Expand All @@ -632,7 +628,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
args.num_process = num_gpu
args.process_id = self.process_id
args.fp8_recipe = fp8_recipe
args.enable_shardy = enable_shardy

return train_and_evaluate(args)

Expand Down Expand Up @@ -674,44 +669,6 @@ def test_te_nvfp4(self):
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.787

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
    """Run ``exec`` without FP8 and with ``enable_shardy=True``; check the result bounds."""
    metrics = self.exec(False, None, enable_shardy=True)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.43
    assert metrics[1] > 0.80

@unittest.skipIf(
    not is_fp8_supported(),
    "Device compute capability 9.0+ is required for DelayedScaling FP8",
)
def test_te_delayed_scaling_fp8_shardy(self):
    """Run ``exec`` with the DelayedScaling recipe and ``enable_shardy=True``; check bounds."""
    metrics = self.exec(True, "DelayedScaling", enable_shardy=True)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.43
    assert metrics[1] > 0.80

@unittest.skipIf(
    not is_fp8_supported(),
    "Device compute capability 9.0+ is required for CurrentScaling FP8",
)
def test_te_current_scaling_fp8_shardy(self):
    """Run ``exec`` with Float8CurrentScaling and ``enable_shardy=True``; check bounds."""
    metrics = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.432
    assert metrics[1] > 0.80

@unittest.skipIf(
    not is_mxfp8_supported(),
    "Device compute capability 10.0+ is required for MXFP8",
)
def test_te_mxfp8_shardy(self):
    """Run ``exec`` with MXFP8BlockScaling and ``enable_shardy=True``; check bounds."""
    metrics = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.43
    assert metrics[1] > 0.80

@unittest.skipIf(
    not is_nvfp4_supported(),
    "Device compute capability 10.0+ is required for NVFP4",
)
def test_te_nvfp4_shardy(self):
    """Run ``exec`` with NVFP4BlockScaling and ``enable_shardy=True``; check bounds."""
    metrics = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
    # NOTE(review): index 0 is presumably the loss and index 1 the accuracy — confirm.
    assert metrics[0] < 0.451
    assert metrics[1] > 0.787


# Script entry point: build the argument namespace with encoder_parser(None) and run the
# training/evaluation loop once. Only executes when the module is run directly.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
Loading
Loading