@@ -368,7 +368,7 @@ def test_block_level_pin_first_last_groups_stay_on_device(self):
368368 return
369369
370370 def first_param_device (mod ):
371- p = next (mod .parameters (), None ) # recurse=True by default
371+ p = next (mod .parameters (), None )
372372 self .assertIsNotNone (p , f"No parameters found for module { mod } " )
373373 return p .device
374374
@@ -390,8 +390,8 @@ def get_param_modules_from_exec_order(model):
390390 lazy_hook = root_registry .get_hook ("lazy_prefetch_group_offloading" )
391391 self .assertIsNotNone (lazy_hook , "lazy_prefetch_group_offloading hook was not registered" )
392392
393+ #record execution order with first forward
393394 with torch .no_grad ():
394- #record execution order with first forward
395395 model (self .input )
396396
397397 mods = [m for _ , m in lazy_hook .execution_order ]
@@ -402,14 +402,11 @@ def get_param_modules_from_exec_order(model):
402402
403403 first = param_mods [0 ]
404404 last = param_mods [- 1 ]
405- middle = param_mods [1 :- 1 ] # <- ALL middle layers
406- return first , middle , last
405+ middle_layers = param_mods [1 :- 1 ]
406+ return first , middle_layers , last
407407
408408 accel_type = torch .device (torch_device ).type
409409
410- # -------------------------
411- # No pin: everything on CPU
412- # -------------------------
413410 model_no_pin = self .get_model ()
414411 model_no_pin .enable_group_offload (
415412 torch_device ,
0 commit comments