@@ -362,3 +362,87 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
        self.assertLess(
            cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
        )

    def test_block_level_pin_first_last_groups_stay_on_device(self):
        if torch.device(torch_device).type not in ["cuda", "xpu"]:
            self.skipTest("Group offload pinning requires a CUDA or XPU device.")

        def first_param_device(mod):
            p = next(mod.parameters(), None)  # recurse=True by default
            self.assertIsNotNone(p, f"No parameters found for module {mod}")
            return p.device

        def assert_all_modules_device(mods, expected_type: str, msg: str = ""):
            bad = []
            for i, m in enumerate(mods):
                dev_type = first_param_device(m).type
                if dev_type != expected_type:
                    bad.append((i, m.__class__.__name__, dev_type))
            self.assertFalse(
                bad,
                (msg + "\n" if msg else "")
                + f"Expected all modules on {expected_type}, but found mismatches: {bad}",
            )

        def get_param_modules_from_exec_order(model):
            root_registry = HookRegistry.check_if_exists_or_initialize(model)

            lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading")
            self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered")

            with torch.no_grad():
                # Record the execution order with the first forward pass.
                model(self.input)

            mods = [m for _, m in lazy_hook.execution_order]
            param_mods = [m for m in mods if next(m.parameters(), None) is not None]
            self.assertGreaterEqual(
                len(param_mods), 2, f"Expected >=2 param-bearing modules in execution_order, got {len(param_mods)}"
            )

            first = param_mods[0]
            last = param_mods[-1]
            middle = param_mods[1:-1]  # all middle layers
            return first, middle, last

        accel_type = torch.device(torch_device).type

        # -------------------------
        # No pin: everything on CPU
        # -------------------------
        model_no_pin = self.get_model()
        model_no_pin.enable_group_offload(
            torch_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        model_no_pin.eval()
        first, middle, last = get_param_modules_from_exec_order(model_no_pin)

        self.assertEqual(first_param_device(first).type, "cpu")
        self.assertEqual(first_param_device(last).type, "cpu")
        assert_all_modules_device(middle, "cpu", msg="No-pin: expected ALL middle layers on CPU")

        # ----------------------------------------------
        # Pin: first and last groups stay on the device
        # ----------------------------------------------
        model_pin = self.get_model()
        model_pin.enable_group_offload(
            torch_device,
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
            pin_first_last=True,
        )
        model_pin.eval()
        first, middle, last = get_param_modules_from_exec_order(model_pin)

        self.assertEqual(first_param_device(first).type, accel_type)
        self.assertEqual(first_param_device(last).type, accel_type)
        assert_all_modules_device(middle, "cpu", msg="Pin: expected ALL middle layers on CPU")

        # Should still hold after another invocation.
        with torch.no_grad():
            model_pin(self.input)

        self.assertEqual(first_param_device(first).type, accel_type)
        self.assertEqual(first_param_device(last).type, accel_type)
        assert_all_modules_device(middle, "cpu", msg="Pin (2nd forward): expected ALL middle layers on CPU")
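
For reference, a minimal usage sketch of the option this test exercises, assuming a diffusers model that exposes enable_group_offload; the model class and checkpoint path below are illustrative placeholders, and the keyword arguments mirror the ones used in the test:

import torch
from diffusers import UNet2DConditionModel

# Hypothetical checkpoint; substitute any model that supports group offloading.
model = UNet2DConditionModel.from_pretrained("some/checkpoint", subfolder="unet")
model.enable_group_offload(
    torch.device("cuda"),        # onload (accelerator) device
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,             # overlap transfers with compute via streams
    pin_first_last=True,         # keep first and last block groups on the accelerator
)
model.eval()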