@@ -59,6 +59,7 @@ class GroupOffloadingConfig:
     num_blocks_per_group: Optional[int] = None
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+    block_modules: Optional[List[str]] = None


 class ModuleGroup:
@@ -77,7 +78,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        group_id: Optional[int] = None,
+        group_id: Optional[Union[int, str]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -453,6 +454,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    block_modules: Optional[List[str]] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -510,6 +512,9 @@ def apply_group_offloading(
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
+        block_modules (`List[str]`, *optional*):
+            List of module names that should be treated as blocks for offloading. If provided, only these modules
+            will be considered for block-level offloading. If not provided, the default block detection logic will
+            be used.

     Example:
     ```python
@@ -561,6 +566,7 @@ def apply_group_offloading(
         record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
+        block_modules=block_modules,
     )
     _apply_group_offloading(module, config)

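To make the new argument concrete, here is a minimal usage sketch (not part of the diff). It assumes the existing `onload_device`, `offload_device`, `offload_type`, and `num_blocks_per_group` parameters of `apply_group_offloading`, and uses `AutoencoderKL` with its `encoder`/`decoder` submodules purely for illustration:

```python
import torch

from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

# Illustrative model; any torch.nn.Module with named child modules works the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    # New in this diff: name the child modules that should be treated as offloading blocks.
    block_modules=["encoder", "decoder"],
)
```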
@@ -576,28 +582,123 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf

 def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
-    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
-    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-    """
+    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
+    defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading
+    is done at the top-level blocks and modules specified in block_modules.

+    When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
+    module, we either offload the entire submodule or recursively apply block offloading to it.
+    """
     if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
             f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
         config.num_blocks_per_group = 1

-    # Create module groups for ModuleList and Sequential blocks
+    block_modules = set(config.block_modules) if config.block_modules is not None else set()
+
+    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
     matched_module_groups = []
+
     for name, submodule in module.named_children():
-        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # Check if this is an explicitly defined block module
+        if name in block_modules:
+            # Apply block offloading to the specified submodule
+            _apply_block_offloading_to_submodule(
+                submodule, name, config, modules_with_group_offloading, matched_module_groups
+            )
+        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Handle ModuleList and Sequential blocks as before
+            for i in range(0, len(submodule), config.num_blocks_per_group):
+                current_modules = list(submodule[i : i + config.num_blocks_per_group])
+                if len(current_modules) == 0:
+                    continue
+
+                group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+                group = ModuleGroup(
+                    modules=current_modules,
+                    offload_device=config.offload_device,
+                    onload_device=config.onload_device,
+                    offload_to_disk_path=config.offload_to_disk_path,
+                    offload_leader=current_modules[-1],
+                    onload_leader=current_modules[0],
+                    non_blocking=config.non_blocking,
+                    stream=config.stream,
+                    record_stream=config.record_stream,
+                    low_cpu_mem_usage=config.low_cpu_mem_usage,
+                    onload_self=True,
+                    group_id=group_id,
+                )
+                matched_module_groups.append(group)
+                for j in range(i, i + len(current_modules)):
+                    modules_with_group_offloading.add(f"{name}.{j}")
+        else:
+            # This is an unmatched module
             unmatched_modules.append((name, submodule))
-            modules_with_group_offloading.add(name)
-            continue

+    # Apply group offloading hooks to the module groups
+    for i, group in enumerate(matched_module_groups):
+        for group_module in group.modules:
+            _apply_group_offloading_hook(group_module, group, config=config)
+
+    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
+    # when the forward pass of this module is called. This is because the top-level module is not
+    # part of any group (as doing so would lead to no VRAM savings).
+    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    parameters = [param for _, param in parameters]
+    buffers = [buffer for _, buffer in buffers]
+
+    # Create a group for the remaining unmatched submodules of the top-level
+    # module so that they are on the correct device when the forward pass is called.
+    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
+
+def _apply_block_offloading_to_submodule(
+    submodule: torch.nn.Module,
+    name: str,
+    config: GroupOffloadingConfig,
+    modules_with_group_offloading: Set[str],
+    matched_module_groups: List[ModuleGroup],
+) -> None:
+    r"""
+    Apply block offloading to an explicitly defined submodule. This function either:
+    1. Offloads the entire submodule as a single group (simple approach), or
+    2. Applies the regular block-level chunking, if the submodule is a ModuleList or Sequential
+
+    For submodules that are not a ModuleList or Sequential, we use the simple approach and offload the entire
+    submodule as a single group.
+    """
+    # Simple approach: offload the entire submodule as a single group.
+    # Since AEs are typically small, this is usually okay.
+    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # If it's a ModuleList or Sequential, apply the normal block-level logic
         for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
+            current_modules = list(submodule[i : i + config.num_blocks_per_group])
+            if len(current_modules) == 0:
+                continue
+
             group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
@@ -616,42 +717,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
                 modules_with_group_offloading.add(f"{name}.{j}")
-
-    # Apply group offloading hooks to the module groups
-    for i, group in enumerate(matched_module_groups):
-        for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, config=config)
-
-    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
-    # when the forward pass of this module is called. This is because the top-level module is not
-    # part of any group (as doing so would lead to no VRAM savings).
-    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    parameters = [param for _, param in parameters]
-    buffers = [buffer for _, buffer in buffers]
-
-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
-    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+        # For other modules, treat the entire submodule as a single group
+        group = ModuleGroup(
+            modules=[submodule],
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=submodule,
+            onload_leader=submodule,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
+            onload_self=True,
+            group_id=name,
+        )
+        matched_module_groups.append(group)
+        modules_with_group_offloading.add(name)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
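For intuition about the grouping above, the following standalone sketch (toy modules and invented names, not library code) mirrors the group-ID assignment of the block-level pass: ModuleList/Sequential children are chunked by `num_blocks_per_group`, explicitly listed non-sequential modules become a single group keyed by their name, and everything else falls into the unmatched top-level group.

```python
import torch


class ToyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Linear(8, 8)
        self.conv_out = torch.nn.Linear(8, 8)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(5)])
        self.decoder = ToyDecoder()        # explicitly listed via block_modules below
        self.norm = torch.nn.LayerNorm(8)  # neither a block nor listed -> unmatched top-level group


def plan_block_groups(model, num_blocks_per_group=2, block_modules=("decoder",)):
    # Mirrors the group_id assignment in the diff above (sketch only).
    group_ids = []

    def chunk(name, seq):
        for i in range(0, len(seq), num_blocks_per_group):
            n = len(seq[i : i + num_blocks_per_group])
            if n:
                group_ids.append(f"{name}_{i}_{i + n - 1}")

    for name, child in model.named_children():
        if name in block_modules:
            if isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
                chunk(name, child)      # explicit sequential blocks still get chunked
            else:
                group_ids.append(name)  # whole submodule offloaded as a single group
        elif isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
            chunk(name, child)          # default block detection
    return group_ids


print(plan_block_groups(ToyModel()))
# ['blocks_0_1', 'blocks_2_3', 'blocks_4_4', 'decoder']
```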