
Commit 3455019

Support explicit block modules in group offloading

1 parent 93e6d31

4 files changed: +131 -44 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 127 additions & 44 deletions
@@ -59,6 +59,7 @@ class GroupOffloadingConfig:
     num_blocks_per_group: Optional[int] = None
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+    block_modules: Optional[List[str]] = None
 
 
 class ModuleGroup:
@@ -77,7 +78,7 @@ def __init__(
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        group_id: Optional[int] = None,
+        group_id: Optional[Union[int, str]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -453,6 +454,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
+    block_modules: Optional[List[str]] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -510,6 +512,9 @@ def apply_group_offloading(
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
+        block_modules (`List[str]`, *optional*):
+            List of module names that should be treated as blocks for offloading. If provided, only these modules
+            will be considered for block-level offloading. If not provided, the default block detection logic will be used.
 
     Example:
     ```python
@@ -561,6 +566,7 @@
         record_stream=record_stream,
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
+        block_modules=block_modules,
     )
     _apply_group_offloading(module, config)

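For orientation (not part of this commit's diff), a minimal usage sketch of the new parameter; it assumes a CUDA device and reuses the block-module names this commit adds to `AutoencoderKL` below:

```python
import torch

from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL()  # randomly initialized; a from_pretrained() checkpoint works the same way

apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    # New in this commit: only these direct children are treated as offload blocks;
    # remaining children land in the top-level "unmatched" group.
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
)
```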
@@ -576,28 +582,123 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
 
 def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
-    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
-    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-    """
+    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
+    defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading
+    is done at the top-level blocks and modules specified in block_modules.
 
+    When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
+    module, we either offload the entire submodule or recursively apply block offloading to it.
+    """
     if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
             f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
         config.num_blocks_per_group = 1
 
-    # Create module groups for ModuleList and Sequential blocks
+    block_modules = set(config.block_modules) if config.block_modules is not None else set()
+
+    # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
     modules_with_group_offloading = set()
     unmatched_modules = []
     matched_module_groups = []
+
     for name, submodule in module.named_children():
-        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # Check if this is an explicitly defined block module
+        if name in block_modules:
+            # Apply block offloading to the specified submodule
+            _apply_block_offloading_to_submodule(
+                submodule, name, config, modules_with_group_offloading, matched_module_groups
+            )
+        elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+            # Handle ModuleList and Sequential blocks as before
+            for i in range(0, len(submodule), config.num_blocks_per_group):
+                current_modules = list(submodule[i : i + config.num_blocks_per_group])
+                if len(current_modules) == 0:
+                    continue
+
+                group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+                group = ModuleGroup(
+                    modules=current_modules,
+                    offload_device=config.offload_device,
+                    onload_device=config.onload_device,
+                    offload_to_disk_path=config.offload_to_disk_path,
+                    offload_leader=current_modules[-1],
+                    onload_leader=current_modules[0],
+                    non_blocking=config.non_blocking,
+                    stream=config.stream,
+                    record_stream=config.record_stream,
+                    low_cpu_mem_usage=config.low_cpu_mem_usage,
+                    onload_self=True,
+                    group_id=group_id,
+                )
+                matched_module_groups.append(group)
+                for j in range(i, i + len(current_modules)):
+                    modules_with_group_offloading.add(f"{name}.{j}")
+        else:
+            # This is an unmatched module
             unmatched_modules.append((name, submodule))
-            modules_with_group_offloading.add(name)
-            continue
 
+    # Apply group offloading hooks to the module groups
+    for i, group in enumerate(matched_module_groups):
+        for group_module in group.modules:
+            _apply_group_offloading_hook(group_module, group, config=config)
+
+    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
+    # when the forward pass of this module is called. This is because the top-level module is not
+    # part of any group (as doing so would lead to no VRAM savings).
+    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    parameters = [param for _, param in parameters]
+    buffers = [buffer for _, buffer in buffers]
+
+    # Create a group for the remaining unmatched submodules of the top-level
+    # module so that they are on the correct device when the forward pass is called.
+    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
+
+def _apply_block_offloading_to_submodule(
+    submodule: torch.nn.Module,
+    name: str,
+    config: GroupOffloadingConfig,
+    modules_with_group_offloading: Set[str],
+    matched_module_groups: List[ModuleGroup],
+) -> None:
+    r"""
+    Apply block offloading to an explicitly defined submodule. This function either:
+    1. Offloads the entire submodule as a single group (simple approach)
+    2. Recursively applies block offloading to the submodule
+
+    For now, we use the simple approach: offload the entire submodule as a single group.
+    """
+    # Simple approach: offload the entire submodule as a single group
+    # Since AEs are typically small, this is usually okay
+    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # If it's a ModuleList or Sequential, apply the normal block-level logic
         for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
+            current_modules = list(submodule[i : i + config.num_blocks_per_group])
+            if len(current_modules) == 0:
+                continue
+
             group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
@@ -616,42 +717,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
                 modules_with_group_offloading.add(f"{name}.{j}")
-
-    # Apply group offloading hooks to the module groups
-    for i, group in enumerate(matched_module_groups):
-        for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, config=config)
-
-    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
-    # when the forward pass of this module is called. This is because the top-level module is not
-    # part of any group (as doing so would lead to no VRAM savings).
-    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    parameters = [param for _, param in parameters]
-    buffers = [buffer for _, buffer in buffers]
-
-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
-    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+        # For other modules, treat the entire submodule as a single group
+        group = ModuleGroup(
+            modules=[submodule],
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=submodule,
+            onload_leader=submodule,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
+            onload_self=True,
+            group_id=name,
+        )
+        matched_module_groups.append(group)
+        modules_with_group_offloading.add(name)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
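To make the new dispatch concrete: `_apply_group_offloading_block_level` now sorts each direct child of the module into one of three branches. The toy model below is invented for illustration (it is not from this commit) and only mirrors the branching logic:

```python
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)  # explicitly listed block module
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
        self.norm_out = torch.nn.LayerNorm(8)  # neither listed nor a ModuleList/Sequential


block_modules = {"encoder"}  # stands in for config.block_modules

for name, child in TinyModel().named_children():
    if name in block_modules:
        print(f"{name}: explicit block -> _apply_block_offloading_to_submodule")
    elif isinstance(child, (torch.nn.ModuleList, torch.nn.Sequential)):
        print(f"{name}: split into groups of num_blocks_per_group")
    else:
        print(f"{name}: unmatched -> onloaded/offloaded with the top-level group")
# encoder: explicit block -> _apply_block_offloading_to_submodule
# blocks: split into groups of num_blocks_per_group
# norm_out: unmatched -> onloaded/offloaded with the top-level group
```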

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
 
     @register_to_config
     def __init__(
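With the class attribute in place, `enable_group_offload` picks the list up automatically via the modeling_utils.py change below, so callers do not pass `block_modules` themselves. A minimal sketch, assuming a CUDA device; the checkpoint name is just an example:

```python
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# No block_modules argument here: enable_group_offload reads the class-level
# _group_offload_block_modules attribute added in this commit.
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)
```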

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 1 addition & 0 deletions
@@ -964,6 +964,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     # keys to ignore when AlignDeviceHook moves inputs/outputs between devices
     # these are shared mutable state modified in-place
     _skip_keys = ["feat_cache", "feat_idx"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
 
     @register_to_config
     def __init__(

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 0 deletions
@@ -570,6 +570,7 @@ def enable_group_offload(
                 f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
+        block_modules = getattr(self, "_group_offload_block_modules", None)
         apply_group_offloading(
             module=self,
             onload_device=onload_device,
@@ -581,6 +582,7 @@
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
+            block_modules=block_modules,
         )
 
     def set_attention_backend(self, backend: str) -> None:
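Because the lookup uses `getattr` with a `None` default, models that do not define `_group_offload_block_modules` keep the previous behavior: `block_modules=None` reaches `_apply_group_offloading_block_level`, which treats it as an empty set and falls back to plain ModuleList/Sequential detection. A quick illustration with invented class names:

```python
class WithBlockModules:
    _group_offload_block_modules = ["encoder", "decoder"]


class WithoutBlockModules:
    pass


print(getattr(WithBlockModules(), "_group_offload_block_modules", None))
# ['encoder', 'decoder'] -> only these children become explicit offload blocks
print(getattr(WithoutBlockModules(), "_group_offload_block_modules", None))
# None -> default ModuleList/Sequential block detection applies
```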
