
Commit 9ccc0c3

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 35d4d0c · commit 9ccc0c3

3 files changed: 7 additions & 11 deletions


tests/pytorch/distributed/run_fsdp2_fused_adam.py

Lines changed: 1 addition & 7 deletions
@@ -122,14 +122,9 @@ def test_fused_adam_fp8_master_weights(recipe=None):
     model = _build_model(fp8_init=True, recipe=recipe)
 
     # Verify FP8 params created
-    qt_count = sum(
-        1
-        for _, p in model.named_parameters()
-        if isinstance(p, QuantizedTensor)
-    )
+    qt_count = sum(1 for _, p in model.named_parameters() if isinstance(p, QuantizedTensor))
     assert qt_count > 0, "No QuantizedTensor local tensors before training"
 
-
     model = _shard_model(model, world_size)
 
     # Verify params are DTensors
@@ -144,7 +139,6 @@ def test_fused_adam_fp8_master_weights(recipe=None):
     )
     assert qt_count > 0, "No QuantizedTensor local tensors after sharding"
 
-
     optimizer = te.optimizers.FusedAdam(
         model.parameters(),
         lr=1e-3,
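
For context, the collapsed one-liner above counts quantized parameters by iterating named_parameters() with a generator expression. A minimal standalone sketch of the same counting pattern, with a plain dtype check standing in for the isinstance(p, QuantizedTensor) test (QuantizedTensor comes from Transformer Engine and is assumed context here):

import torch
import torch.nn as nn

# Count parameters matching a predicate via a generator expression over
# named_parameters(), mirroring the one-liner in the diff above. The dtype
# check is an illustrative stand-in for isinstance(p, QuantizedTensor).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
fp32_count = sum(1 for _, p in model.named_parameters() if p.dtype == torch.float32)
assert fp32_count > 0, "No float32 parameters found"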

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 5 additions & 1 deletion
@@ -328,7 +328,11 @@ def _train(args):
     target = torch.randn(out_shape, device=device)
 
     # NVFP4BlockScaling requires bfloat16 inputs in both the forward and backward passes.
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16) if args.recipe == "NVFP4BlockScaling" else nullcontext():
+    with (
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+        if args.recipe == "NVFP4BlockScaling"
+        else nullcontext()
+    ):
         with te.autocast(enabled=True, recipe=fp8_recipe):
             output = model(input_data)
             loss = F.mse_loss(output, target)
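
The reformatted block keeps the original behavior: bfloat16 autocast is entered only when the recipe is NVFP4BlockScaling, and a no-op context is used otherwise. A self-contained sketch of that conditional context-manager pattern, assuming a CUDA device; the tensor shapes and the recipe string are illustrative and not taken from the test script:

from contextlib import nullcontext

import torch
import torch.nn.functional as F

recipe = "NVFP4BlockScaling"  # illustrative stand-in for args.recipe
device = "cuda"  # this sketch assumes a CUDA device is available
x = torch.randn(4, 8, device=device)
w = torch.randn(8, 8, device=device)

# Pick the context manager once, then enter it. This is the same conditional
# expression the diff wraps directly in a parenthesized `with`.
autocast_ctx = (
    torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    if recipe == "NVFP4BlockScaling"
    else nullcontext()
)
with autocast_ctx:
    y = F.linear(x, w)
    loss = F.mse_loss(y, torch.zeros_like(y))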

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
 def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
 
     if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
-        pytest.xfail(
-            f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing."
-        )
+        pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.")
 
     _run_test(fp8_init, sharding_dims, fp_recipe, layer_type)
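
For reference, pytest.xfail(reason) called imperatively inside a test body raises internally, so nothing after the call runs and the test is reported as XFAIL instead of FAILED. A minimal sketch with illustrative values standing in for the parametrized arguments above:

import pytest

def test_known_failing_configuration():
    # Illustrative stand-ins for the fp_recipe / fp8_init parameters.
    fp_recipe, fp8_init = "Float8BlockScaling", True
    if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
        # Execution stops here; the test shows up as XFAIL in the report.
        pytest.xfail(f"{fp_recipe} + fp8_init: currently failing")
    raise AssertionError("unreachable for this configuration")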
