Commit 93e6d31

fix comments in tests for cleaner code
1 parent 4fc12e2


tests/hooks/test_group_offloading.py

Lines changed: 4 additions & 7 deletions
@@ -368,7 +368,7 @@ def test_block_level_pin_first_last_groups_stay_on_device(self):
             return

         def first_param_device(mod):
-            p = next(mod.parameters(), None)  # recurse=True by default
+            p = next(mod.parameters(), None)
             self.assertIsNotNone(p, f"No parameters found for module {mod}")
             return p.device

@@ -390,8 +390,8 @@ def get_param_modules_from_exec_order(model):
             lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading")
             self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered")

+            #record execution order with first forward
             with torch.no_grad():
-                #record execution order with first forward
                 model(self.input)

             mods = [m for _, m in lazy_hook.execution_order]
@@ -402,14 +402,11 @@ def get_param_modules_from_exec_order(model):

             first = param_mods[0]
             last = param_mods[-1]
-            middle = param_mods[1:-1]  # <- ALL middle layers
-            return first, middle, last
+            middle_layers = param_mods[1:-1]
+            return first, middle_layers, last

         accel_type = torch.device(torch_device).type

-        # -------------------------
-        # No pin: everything on CPU
-        # -------------------------
         model_no_pin = self.get_model()
         model_no_pin.enable_group_offload(
             torch_device,
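For context, here is a minimal sketch of how the refactored helpers could fit together in this test: run one forward pass so the lazy_prefetch_group_offloading hook records its execution order, split the parameterized modules into first/middle/last groups, and check which device each group sits on. The function name check_pinned_boundary_groups, its argument names, and the exact CPU assertion for the middle layers are assumptions based on the diff and the test's name, not the file's actual code.

import torch

def check_pinned_boundary_groups(model, sample_input, lazy_hook, torch_device):
    # Run one forward pass so the lazy-prefetch hook can record execution order.
    with torch.no_grad():
        model(sample_input)

    # execution_order is assumed to hold (name, module) pairs, as in the diff above.
    mods = [m for _, m in lazy_hook.execution_order]
    param_mods = [m for m in mods if next(m.parameters(), None) is not None]
    first, middle_layers, last = param_mods[0], param_mods[1:-1], param_mods[-1]

    def first_param_device(mod):
        return next(mod.parameters()).device

    accel_type = torch.device(torch_device).type
    # Pinned boundary groups are expected to stay on the accelerator ...
    assert first_param_device(first).type == accel_type
    assert first_param_device(last).type == accel_type
    # ... while the middle groups are assumed to be offloaded to CPU between calls.
    for mod in middle_layers:
        assert first_param_device(mod).type == "cpu"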
