@@ -33,9 +33,7 @@
 def get_recipe_from_string(recipe_name, fp8_format=Format.HYBRID):
     """Convert recipe name to a recipe object."""
     if recipe_name == "delayed_scaling":
-        return DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
+        return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
     elif recipe_name == "current_scaling":
         return Float8CurrentScaling(fp8_format=fp8_format)
     elif recipe_name == "mx_fp8_block_scaling":
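For context, a recipe object like the ones returned here is typically passed to `te.fp8_autocast`; a minimal usage sketch (the `transformer_engine.pytorch` import and the `model`/`x` placeholders are assumptions, not part of this diff):

    import transformer_engine.pytorch as te

    recipe = get_recipe_from_string("delayed_scaling")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = model(x)  # model composed of TE modules, e.g. te.Linear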
@@ -146,9 +144,7 @@ def test_fused_adam_fp8_master_weights(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     for step in range(NUM_STEPS):
@@ -162,16 +158,16 @@ def test_fused_adam_fp8_master_weights(recipe=None):
     # Verify optimizer states
     for param in model.parameters():
         state = optimizer.state[param]
-        assert state["exp_avg"].dtype == torch.float32, (
-            f"exp_avg dtype {state['exp_avg'].dtype}, expected float32"
-        )
-        assert state["exp_avg_sq"].dtype == torch.float32, (
-            f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
-        )
+        assert (
+            state["exp_avg"].dtype == torch.float32
+        ), f"exp_avg dtype {state['exp_avg'].dtype}, expected float32"
+        assert (
+            state["exp_avg_sq"].dtype == torch.float32
+        ), f"exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
         if "master_param" in state:
-            assert state["master_param"].dtype == torch.float32, (
-                f"master_param dtype {state['master_param'].dtype}, expected float32"
-            )
+            assert (
+                state["master_param"].dtype == torch.float32
+            ), f"master_param dtype {state['master_param'].dtype}, expected float32"
 
     # Verify FP8 params preserved
     qt_count = sum(
@@ -201,9 +197,7 @@ def test_fused_adam_bf16(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     losses = []
@@ -244,9 +238,7 @@ def test_fused_adam_fp8_no_master(recipe=None):
         master_weights=False,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     for step in range(NUM_STEPS):
@@ -291,9 +283,7 @@ def test_fused_adam_bf16_store_param_remainders(recipe=None):
         store_param_remainders=True,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     losses = []
@@ -308,24 +298,24 @@ def test_fused_adam_bf16_store_param_remainders(recipe=None):
 
     # Verify model params are bf16 (required for store_param_remainders)
     for name, param in model.named_parameters():
-        assert param.dtype == torch.bfloat16, (
-            f"{name}: param dtype {param.dtype}, expected bfloat16"
-        )
+        assert (
+            param.dtype == torch.bfloat16
+        ), f"{name}: param dtype {param.dtype}, expected bfloat16"
 
     # Verify optimizer states
     for name, param in model.named_parameters():
         state = optimizer.state[param]
-        assert state["exp_avg"].dtype == torch.float32, (
-            f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32"
-        )
-        assert state["exp_avg_sq"].dtype == torch.float32, (
-            f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
-        )
+        assert (
+            state["exp_avg"].dtype == torch.float32
+        ), f"{name}: exp_avg dtype {state['exp_avg'].dtype}, expected float32"
+        assert (
+            state["exp_avg_sq"].dtype == torch.float32
+        ), f"{name}: exp_avg_sq dtype {state['exp_avg_sq'].dtype}, expected float32"
         # store_param_remainders stores master_param as int16 remainder bits
         if "master_param" in state:
-            assert state["master_param"].dtype == torch.int16, (
-                f"{name}: master_param dtype {state['master_param'].dtype}, expected int16"
-            )
+            assert (
+                state["master_param"].dtype == torch.int16
+            ), f"{name}: master_param dtype {state['master_param'].dtype}, expected int16"
 
     # Verify loss decreased (basic sanity)
     assert losses[-1] < losses[0], f"Loss did not decrease: {losses}"
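A side note on the int16 expectation above: bfloat16 is exactly the top 16 bits of float32, so with `store_param_remainders` the fp32 master weight can be recovered from the bf16 param plus the stored 16 remainder bits. A hypothetical reconstruction sketch (not the actual FusedAdam internals):

    import torch

    def reconstruct_master(param_bf16: torch.Tensor, remainder_i16: torch.Tensor) -> torch.Tensor:
        hi = param_bf16.view(torch.int16).to(torch.int32) << 16  # bf16 bits become the top half
        lo = remainder_i16.to(torch.int32) & 0xFFFF              # remainder fills the bottom half
        return (hi | lo).view(torch.float32)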
@@ -351,9 +341,7 @@ def test_fuse_wgrad_accumulation(recipe=None):
 
     # Allocate main_grad buffers on the DTensor params
     for param in model.parameters():
-        param.main_grad = torch.zeros(
-            param.shape, dtype=torch.float32, device=param.device
-        )
+        param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)
 
     model = _shard_model(model, world_size)
 
@@ -365,9 +353,7 @@ def test_fuse_wgrad_accumulation(recipe=None):
         use_decoupled_grad=True,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # This is currently failing during backward because the local Float8Tensor
@@ -409,9 +395,7 @@ def test_dcp_save_load(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # Train a few steps to populate optimizer state.
@@ -434,9 +418,7 @@ def test_dcp_save_load(recipe=None):
         # the saved and loaded state_dict. It also means we need to load the state_dict back with
         # `strict=False` to avoid an error on missing entries.
         model_state = model.state_dict()
-        model_state = {
-            k: v for k, v in model_state.items() if not k.endswith("_extra_state")
-        }
+        model_state = {k: v for k, v in model_state.items() if not k.endswith("_extra_state")}
     else:
         model_state = model.state_dict()
 
@@ -479,9 +461,9 @@ def test_dcp_save_load(recipe=None):
 
     # Loss after loading should be comparable to loss before save
     # (not a massive spike indicating corrupted state).
-    assert loss_after_load < loss_before_save * 2.0, (
-        f"Loss spiked after checkpoint load: {loss_after_load} vs {loss_before_save}"
-    )
+    assert (
+        loss_after_load < loss_before_save * 2.0
+    ), f"Loss spiked after checkpoint load: {loss_after_load} vs {loss_before_save}"
 
     # Clean up checkpoint.
     import shutil
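For reference, the save/load round trip exercised here follows the usual `torch.distributed.checkpoint` pattern, roughly (the path and dict keys below are placeholders):

    import torch.distributed.checkpoint as dcp

    state = {"model": model_state, "optim": optimizer.state_dict()}
    dcp.save(state, checkpoint_id="/tmp/dcp_ckpt")  # each rank writes its own shards
    dcp.load(state, checkpoint_id="/tmp/dcp_ckpt")  # loads back in place into `state`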
@@ -521,9 +503,7 @@ def test_safetensors_fp32_export(recipe=None):
         master_weight_dtype=torch.float32,
     )
 
-    x = torch.randn(
-        SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device
-    )
+    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
     target = torch.randn_like(x)
 
     # Train a few steps.
@@ -560,9 +540,9 @@ def test_safetensors_fp32_export(recipe=None):
     save_file(fp32_state, save_path)
     loaded = load_file(save_path)
 
-    assert len(loaded) == len(fp32_state), (
-        f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}"
-    )
+    assert len(loaded) == len(
+        fp32_state
+    ), f"Loaded {len(loaded)} tensors, expected {len(fp32_state)}"
     for k, v in loaded.items():
         assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
 
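The `fp32_state` dict saved above presumably collects fp32 master weights in preference to the low-precision params; a hypothetical sketch of such a collection step (names and structure are assumptions, not this test's exact code):

    fp32_state = {}
    for name, param in model.named_parameters():
        master = optimizer.state.get(param, {}).get("master_param")
        src = master if master is not None else param.detach()
        fp32_state[name] = src.float().cpu()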