Commit 4fc12e2

created test for pinning first and last block on device
1 parent: c8656ed

File tree

1 file changed: 84 additions, 0 deletions


tests/hooks/test_group_offloading.py

Lines changed: 84 additions & 0 deletions
@@ -362,3 +362,87 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
         self.assertLess(
             cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
         )
+
+    def test_block_level_pin_first_last_groups_stay_on_device(self):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def first_param_device(mod):
+            p = next(mod.parameters(), None)  # recurse=True by default
+            self.assertIsNotNone(p, f"No parameters found for module {mod}")
+            return p.device
+
+        def assert_all_modules_device(mods, expected_type: str, msg: str = ""):
+            bad = []
+            for i, m in enumerate(mods):
+                dev_type = first_param_device(m).type
+                if dev_type != expected_type:
+                    bad.append((i, m.__class__.__name__, dev_type))
+            self.assertFalse(
+                bad,
+                (msg + "\n" if msg else "")
+                + f"Expected all modules on {expected_type}, but found mismatches: {bad}",
+            )
+
+        def get_param_modules_from_exec_order(model):
+            root_registry = HookRegistry.check_if_exists_or_initialize(model)
+
+            lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading")
+            self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered")
+
+            with torch.no_grad():
+                # record execution order with the first forward pass
+                model(self.input)
+
+            mods = [m for _, m in lazy_hook.execution_order]
+            param_mods = [m for m in mods if next(m.parameters(), None) is not None]
+            self.assertGreaterEqual(
+                len(param_mods), 2, f"Expected >=2 param-bearing modules in execution_order, got {len(param_mods)}"
+            )
+
+            first = param_mods[0]
+            last = param_mods[-1]
+            middle = param_mods[1:-1]  # <- ALL middle layers
+            return first, middle, last
+
+        accel_type = torch.device(torch_device).type
+
+        # -------------------------
+        # No pin: everything on CPU
+        # -------------------------
+        model_no_pin = self.get_model()
+        model_no_pin.enable_group_offload(
+            torch_device,
+            offload_type="block_level",
+            num_blocks_per_group=1,
+            use_stream=True,
+        )
+        model_no_pin.eval()
+        first, middle, last = get_param_modules_from_exec_order(model_no_pin)
+
+        self.assertEqual(first_param_device(first).type, "cpu")
+        self.assertEqual(first_param_device(last).type, "cpu")
+        assert_all_modules_device(middle, "cpu", msg="No-pin: expected ALL middle layers on CPU")
+
+        model_pin = self.get_model()
+        model_pin.enable_group_offload(
+            torch_device,
+            offload_type="block_level",
+            num_blocks_per_group=1,
+            use_stream=True,
+            pin_first_last=True,
+        )
+        model_pin.eval()
+        first, middle, last = get_param_modules_from_exec_order(model_pin)
+
+        self.assertEqual(first_param_device(first).type, accel_type)
+        self.assertEqual(first_param_device(last).type, accel_type)
+        assert_all_modules_device(middle, "cpu", msg="Pin: expected ALL middle layers on CPU")
+
+        # Should still hold after another invocation
+        with torch.no_grad():
+            model_pin(self.input)
+
+        self.assertEqual(first_param_device(first).type, accel_type)
+        self.assertEqual(first_param_device(last).type, accel_type)
+        assert_all_modules_device(middle, "cpu", msg="Pin (2nd forward): expected ALL middle layers on CPU")
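
The test pins down the intended behavior: without `pin_first_last`, every parameter-bearing block is offloaded to CPU between uses; with `pin_first_last=True`, the first and last blocks in the recorded execution order stay on the accelerator, even across a second forward pass, while all middle blocks remain on CPU. For reference, here is a minimal usage sketch of the flag under test. The checkpoint name is a placeholder, and `pin_first_last` is the new argument this commit exercises, so it is assumed to exist only on this branch rather than in released diffusers versions:

    import torch
    from diffusers import AutoModel  # assumes a recent diffusers release with AutoModel

    # Placeholder checkpoint; any diffusers model that supports group offloading works.
    model = AutoModel.from_pretrained(
        "some-org/some-transformer", subfolder="transformer", torch_dtype=torch.bfloat16
    )

    # Offload blocks to CPU one group at a time, prefetching on a side stream,
    # but keep the first and last parameter-bearing blocks resident on the GPU
    # so neither end of the forward pass stalls on a host-to-device copy.
    # NOTE: `pin_first_last` is assumed from this branch; it is not part of the
    # released enable_group_offload() signature.
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,
        pin_first_last=True,
    )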
