
Commit 0103b53

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 16cbc74 commit 0103b53

2 files changed

Lines changed: 38 additions & 57 deletions

File tree

tests/pytorch/distributed/run_fsdp2_fused_adam.py
tests/pytorch/distributed/test_torch_fsdp2.py

tests/pytorch/distributed/run_fsdp2_fused_adam.py

Lines changed: 37 additions & 57 deletions
@@ -33,9 +33,7 @@
 def get_recipe_from_string(recipe_name, fp8_format=Format.HYBRID):
     """Convert recipe name to a recipe object."""
     if recipe_name == "delayed_scaling":
-        return DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
+        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
     elif recipe_name == "current_scaling":
         return Float8CurrentScaling(fp8_format=fp8_format)
     elif recipe_name == "mx_fp8_block_scaling":
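
For context, the recipe objects this helper returns are consumed by TE's fp8_autocast context manager. A minimal usage sketch, with illustrative layer sizes, assuming transformer_engine is installed and a CUDA device is available:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Build the same recipe the test helper returns for "delayed_scaling".
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

layer = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")

# The forward pass runs in FP8 under the recipe; backward happens outside the context.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
y.sum().backward()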
@@ -146,9 +144,7 @@ def test_fused_adam_fp8_master_weights(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     for step in range(NUM_STEPS):
@@ -162,16 +158,16 @@ def test_fused_adam_fp8_master_weights(recipe=None):
     # Verify optimizer states
     for param in model.parameters():
         state = optimizer.state[param]
-        assert state["exp_avg"].dtype == torch.float32, (
-            f"exp_avg dtype {state['exp_avg'].dtype}, expected float32"
-        )
-        assert state["exp_avg_sq"].dtype == torch.float32, (
-            f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
-        )
+        assert (
+            state["exp_avg"].dtype == torch.float32
+        ), f"exp_avg dtype {state['exp_avg'].dtype}, expected float32"
+        assert (
+            state["exp_avg_sq"].dtype == torch.float32
+        ), f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
         if "master_param" in state:
-            assert state["master_param"].dtype == torch.float32, (
-                f"master_param dtype {state['master_param'].dtype}, expected float32"
-            )
+            assert (
+                state["master_param"].dtype == torch.float32
+            ), f"master_param dtype {state['master_param'].dtype}, expected float32"
 
     # Verify FP8 params preserved
     qt_count = sum(
@@ -201,9 +197,7 @@ def test_fused_adam_bf16(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     losses = []
@@ -244,9 +238,7 @@ def test_fused_adam_fp8_no_master(recipe=None):
         master_weights=False,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     for step in range(NUM_STEPS):
@@ -291,9 +283,7 @@ def test_fused_adam_bf16_store_param_remainders(recipe=None):
         store_param_remainders=True,
    )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     losses = []
@@ -308,24 +298,24 @@ def test_fused_adam_bf16_store_param_remainders(recipe=None):
 
     # Verify model params are bf16 (required for store_param_remainders)
     for name, param in model.named_parameters():
-        assert param.dtype == torch.bfloat16, (
-            f"{name}: param dtype {param.dtype}, expected bfloat16"
-        )
+        assert (
+            param.dtype == torch.bfloat16
+        ), f"{name}: param dtype {param.dtype}, expected bfloat16"
 
     # Verify optimizer states
     for name, param in model.named_parameters():
         state = optimizer.state[param]
-        assert state["exp_avg"].dtype == torch.float32, (
-            f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32"
-        )
-        assert state["exp_avg_sq"].dtype == torch.float32, (
-            f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
-        )
+        assert (
+            state["exp_avg"].dtype == torch.float32
+        ), f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32"
+        assert (
+            state["exp_avg_sq"].dtype == torch.float32
+        ), f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
         # store_param_remainders stores master_param as int16 remainder bits
         if "master_param" in state:
-            assert state["master_param"].dtype == torch.int16, (
-                f"{name}: master_param dtype {state['master_param'].dtype}, expected int16"
-            )
+            assert (
+                state["master_param"].dtype == torch.int16
+            ), f"{name}: master_param dtype {state['master_param'].dtype}, expected int16"
 
     # Verify loss decreased (basic sanity)
     assert losses[-1] < losses[0], f"Loss did not decrease: {losses}"
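
The int16 expectation above follows from the bit layout: an fp32 value's upper 16 bits are exactly its truncated-bf16 representation, so a bf16 param plus the lower 16 "remainder" bits reconstructs the fp32 master weight exactly. A standalone sketch of that arithmetic (illustrative only, not TE's actual implementation):

import torch

# fp32 master weights, to be split into a bf16 param plus an int16 remainder.
master = torch.randn(4, dtype=torch.float32)

bits = master.view(torch.int32)
param_bits = (bits >> 16).to(torch.int16)    # upper half: the truncated-bf16 param
remainder = (bits & 0xFFFF).to(torch.int16)  # lower half: kept in optimizer state

# Reassemble the two int16 halves into the exact original fp32 values.
recon = (param_bits.to(torch.int32) << 16) | (remainder.to(torch.int32) & 0xFFFF)
assert torch.equal(recon.view(torch.float32), master)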
@@ -351,9 +341,7 @@ def test_fuse_wgrad_accumulation(recipe=None):
 
     # Allocate main_grad buffers on the DTensor params
     for param in model.parameters():
-        param.main_grad = torch.zeros(
-            param.shape, dtype=torch.float32, device=param.device
-        )
+        param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)
 
     model = _shard_model(model, world_size)
 
@@ -365,9 +353,7 @@ def test_fuse_wgrad_accumulation(recipe=None):
         use_decoupled_grad=True,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # This is currently failing during backward because the local Float8Tensor
@@ -409,9 +395,7 @@ def test_dcp_save_load(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # Train a few steps to populate optimizer state.
@@ -434,9 +418,7 @@ def test_dcp_save_load(recipe=None):
         # the saved and loaded state_dict. It also means we need to load the state_dict back with
         # `strict=False` to avoid an error on missing entries.
         model_state = model.state_dict()
-        model_state = {
-            k: v for k, v in model_state.items() if not k.endswith("_extra_state")
-        }
+        model_state = {k: v for k, v in model_state.items() if not k.endswith("_extra_state")}
     else:
         model_state = model.state_dict()
 
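The filtering pattern in this hunk, sketched end to end. This assumes torch>=2.2 for dcp.save/dcp.load, an initialized process group, and the test's model in scope, so it is a sketch rather than a standalone script:

import torch.distributed.checkpoint as dcp

# Save: drop TE's non-tensor `_extra_state` entries first.
model_state = {k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")}
dcp.save(model_state, checkpoint_id="/tmp/fsdp2_ckpt")  # hypothetical path

# Load back into the same filtered dict, then apply it with strict=False
# so the missing `_extra_state` keys do not raise.
dcp.load(model_state, checkpoint_id="/tmp/fsdp2_ckpt")
model.load_state_dict(model_state, strict=False)
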
@@ -479,9 +461,9 @@ def test_dcp_save_load(recipe=None):
 
     # Loss after loading should be comparable to loss before save
     # (not a massive spike indicating corrupted state).
-    assert loss_after_load < loss_before_save * 2.0, (
-        f"Loss spiked after checkpoint load: {loss_after_load} vs {loss_before_save}"
-    )
+    assert (
+        loss_after_load < loss_before_save * 2.0
+    ), f"Loss spiked after checkpoint load: {loss_after_load} vs {loss_before_save}"
 
     # Clean up checkpoint.
     import shutil
@@ -521,9 +503,7 @@ def test_safetensors_fp32_export(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # Train a few steps.
@@ -560,9 +540,9 @@ def test_safetensors_fp32_export(recipe=None):
     save_file(fp32_state, save_path)
     loaded = load_file(save_path)
 
-    assert len(loaded) == len(fp32_state), (
-        f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}"
-    )
+    assert len(loaded) == len(
+        fp32_state
+    ), f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}"
     for k, v in loaded.items():
         assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
 
tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def _run_fused_adam_test(test_name, recipe="delayed_scaling"):
 
     result = subprocess.run(test_cmd, env=os.environ, check=True)
 
+
 @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
