diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 38f291f5203c..36b09cb692dc 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -17,7 +17,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import safetensors.torch import torch @@ -59,6 +59,8 @@ class GroupOffloadingConfig: num_blocks_per_group: Optional[int] = None offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + block_modules: Optional[List[str]] = None + pin_groups: Optional[Union[str, Callable]] = None class ModuleGroup: @@ -77,7 +79,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, + group_id: Optional[Union[int, str]] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -91,6 +93,7 @@ def __init__( self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage + self.pinned = False self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False @@ -296,6 +299,24 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader is None: self.group.onload_leader = module + if self.group.pinned: + if self.group.onload_leader == module and not self._is_group_on_device(): + self.group.onload_() + + should_onload_next_group = self.next_group is not None and not self.next_group.onload_self + if should_onload_next_group: + self.next_group.onload_() + + should_synchronize = ( + not self.group.onload_self and self.group.stream is not None and not should_onload_next_group + ) + if should_synchronize: + self.group.stream.synchronize() + + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs + # If the current module is the onload_leader of the group, we onload the group if it is supposed # to onload itself. In the case of using prefetching with streams, we onload the next group if # it is not supposed to onload itself. 
@@ -324,10 +345,26 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): return args, kwargs def post_forward(self, module: torch.nn.Module, output): + if self.group.pinned: + return output + if self.group.offload_leader == module: self.group.offload_() return output + def _is_group_on_device(self) -> bool: + tensors = [] + for group_module in self.group.modules: + tensors.extend(list(group_module.parameters())) + tensors.extend(list(group_module.buffers())) + tensors.extend(self.group.parameters) + tensors.extend(self.group.buffers) + + if len(tensors) == 0: + return True + + return all(t.device == self.group.onload_device for t in tensors) + class LazyPrefetchGroupOffloadingHook(ModelHook): r""" @@ -339,9 +376,10 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): _is_stateful = False - def __init__(self): + def __init__(self, pin_groups: Optional[Union[str, Callable]] = None): self.execution_order: List[Tuple[str, torch.nn.Module]] = [] self._layer_execution_tracker_module_names = set() + self.pin_groups = pin_groups def initialize_hook(self, module): def make_execution_order_update_callback(current_name, current_submodule): @@ -423,6 +461,50 @@ def post_forward(self, module, output): group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group group_offloading_hooks[i].next_group.onload_self = False + if self.pin_groups is not None and num_executed > 0: + param_exec_info = [] + for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)): + if hook is None: + continue + if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None: + continue + param_exec_info.append((name, submodule, hook)) + + num_param_modules = len(param_exec_info) + if num_param_modules > 0: + pinned_indices = set() + if isinstance(self.pin_groups, str): + if self.pin_groups == "all": + pinned_indices = set(range(num_param_modules)) + elif self.pin_groups == "first_last": + pinned_indices.add(0) + pinned_indices.add(num_param_modules - 1) + elif callable(self.pin_groups): + for idx, (name, submodule, _) in enumerate(param_exec_info): + should_pin = False + try: + should_pin = bool(self.pin_groups(submodule)) + except TypeError: + try: + should_pin = bool(self.pin_groups(name, submodule)) + except TypeError: + should_pin = bool(self.pin_groups(name, submodule, idx)) + if should_pin: + pinned_indices.add(idx) + + pinned_groups = set() + for idx in pinned_indices: + if idx >= num_param_modules: + continue + group = param_exec_info[idx][2].group + if group not in pinned_groups: + group.pinned = True + pinned_groups.add(group) + + for group in pinned_groups: + if group.offload_device != group.onload_device: + group.onload_() + return output @@ -453,6 +535,8 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, + pin_groups: Optional[Union[str, Callable]] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -510,6 +594,13 @@ def apply_group_offloading( If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. 
+ block_modules (`List[str]`, *optional*): + List of module names that should be treated as blocks for offloading. If provided, only these modules + will be considered for block-level offloading. If not provided, the default block detection logic will be used. + pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`): + Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first + and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that + receives a module (and optionally the module name and index) and returns `True` to pin that group. Example: ```python @@ -549,6 +640,16 @@ def apply_group_offloading( if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") + normalized_pin_groups = pin_groups + if isinstance(pin_groups, str): + normalized_pin_groups = pin_groups.lower() + if normalized_pin_groups not in {"first_last", "all"}: + raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.") + elif pin_groups is not None and not callable(pin_groups): + raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.") + + pin_groups = normalized_pin_groups + _raise_error_if_accelerate_model_or_sequential_hook_present(module) config = GroupOffloadingConfig( @@ -561,6 +662,8 @@ def apply_group_offloading( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + pin_groups=pin_groups, ) _apply_group_offloading(module, config) @@ -576,28 +679,124 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to - the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - """ + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading + is done at the top-level blocks and modules specified in block_modules. + When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified + module, we either offload the entire submodule or recursively apply block offloading to it. + """ if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." 
) config.num_blocks_per_group = 1 - # Create module groups for ModuleList and Sequential blocks + block_modules = set(config.block_modules) if config.block_modules is not None else set() + + # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] matched_module_groups = [] + for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Check if this is an explicitly defined block module + if name in block_modules: + # Apply block offloading to the specified submodule + _apply_block_offloading_to_submodule( + submodule, name, config, modules_with_group_offloading, matched_module_groups + ) + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue + + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + group = ModuleGroup( + modules=current_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # This is an unmatched module unmatched_modules.append((name, submodule)) modules_with_group_offloading.add(name) - continue + # Apply group offloading hooks to the module groups + for i, group in enumerate(matched_module_groups): + for group_module in group.modules: + _apply_group_offloading_hook(group_module, group, config=config) + + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately + # when the forward pass of this module is called. This is because the top-level module is not + # part of any group (as doing so would lead to no VRAM savings). + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + parameters = [param for _, param in parameters] + buffers = [buffer for _, buffer in buffers] + + # Create a group for the remaining unmatched submodules of the top-level + # module so that they are on the correct device when the forward pass is called. 
+    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
+
+def _apply_block_offloading_to_submodule(
+    submodule: torch.nn.Module,
+    name: str,
+    config: GroupOffloadingConfig,
+    modules_with_group_offloading: Set[str],
+    matched_module_groups: List[ModuleGroup],
+) -> None:
+    r"""
+    Apply block offloading to an explicitly defined submodule. This function either:
+    1. offloads the entire submodule as a single group (the simple approach), or
+    2. recursively applies block offloading to the submodule.
+
+    For now, we use the simple approach and offload the entire submodule as a single group.
+    """
+    # Simple approach: offload the entire submodule as a single group.
+    # Since autoencoders are typically small, this is usually acceptable.
+    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # If it's a ModuleList or Sequential, apply the normal block-level logic
         for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
+            current_modules = list(submodule[i : i + config.num_blocks_per_group])
+            if len(current_modules) == 0:
+                continue
+
             group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
@@ -616,42 +815,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
                 modules_with_group_offloading.add(f"{name}.{j}")
-
-    # Apply group offloading hooks to the module groups
-    for i, group in enumerate(matched_module_groups):
-        for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, config=config)
-
-    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
-    # when the forward pass of this module is called. This is because the top-level module is not
-    # part of any group (as doing so would lead to no VRAM savings).
-    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    parameters = [param for _, param in parameters]
-    buffers = [buffer for _, buffer in buffers]
-
-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
-    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
-    unmatched_group = ModuleGroup(
-        modules=unmatched_modules,
-        offload_device=config.offload_device,
-        onload_device=config.onload_device,
-        offload_to_disk_path=config.offload_to_disk_path,
-        offload_leader=module,
-        onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
-        non_blocking=False,
-        stream=None,
-        record_stream=False,
-        onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
-    )
-    if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+        # For other modules, treat the entire submodule as a single group
+        group = ModuleGroup(
+            modules=[submodule],
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=submodule,
+            onload_leader=submodule,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
+            onload_self=True,
+            group_id=name,
+        )
+        matched_module_groups.append(group)
+        modules_with_group_offloading.add(name)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -780,8 +961,8 @@ def _apply_lazy_group_offloading_hook(
     if registry.get_hook(_GROUP_OFFLOADING) is None:
         hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
-
-    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+
+    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups)
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index ffc8778e7aca..4096b7c07609 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index b0b2960aaf18..6b29a6273cd9 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -964,6 +964,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
     # keys toignore when AlignDeviceHook moves inputs/outputs between devices
     # these are shared mutable state modified in-place
     _skip_keys = ["feat_cache", "feat_idx"]
+    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index f06822c741ca..3263be4e046e 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -531,6 +531,7 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
+        pin_groups: Optional[Union[str, Callable]] = None,
     ) -> None:
         r"""
         Activates group offloading for the current
model.
@@ -570,6 +571,7 @@ def enable_group_offload(
                 f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
+        block_modules = getattr(self, "_group_offload_block_modules", None)
         apply_group_offloading(
             module=self,
             onload_device=onload_device,
@@ -581,6 +583,8 @@ def enable_group_offload(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
+            block_modules=block_modules,
+            pin_groups=pin_groups,
         )

     def set_attention_backend(self, backend: str) -> None:
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 392d5fb3feb4..a8084688d498 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1342,6 +1342,7 @@ def enable_group_offload(
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
         exclude_modules: Optional[Union[str, List[str]]] = None,
+        pin_groups: Optional[Union[str, Callable]] = None,
     ) -> None:
         r"""
         Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
@@ -1402,6 +1403,9 @@ def enable_group_offload(
                 This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful
                 when the CPU memory is a bottleneck but may counteract the benefits of using streams.
             exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
+            pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
+                Optionally keep selected groups on the onload device permanently. See `ModelMixin.enable_group_offload`
+                for details.

         Example:

         ```python
@@ -1442,6 +1446,7 @@ def enable_group_offload(
             "record_stream": record_stream,
             "low_cpu_mem_usage": low_cpu_mem_usage,
             "offload_to_disk_path": offload_to_disk_path,
+            "pin_groups": pin_groups,
         }
         for name, component in self.components.items():
             if name not in exclude_modules and isinstance(component, torch.nn.Module):
diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index 96cbecfbf530..58520bef9aa5 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -25,6 +25,8 @@
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions

+from typing import Any, Iterable, List, Optional, Sequence, Union
+
 from ..testing_utils import (
     backend_empty_cache,
     backend_max_memory_allocated,
@@ -147,6 +149,66 @@ def __init__(self):
     def post_forward(self, module, output):
         self.outputs.append(output)
         return output
+
+
+# Test for https://github.com/huggingface/diffusers/pull/12747
+class DummyCallableBySubmodule:
+    """
+    Callable group offloading pinner that pins the first and last DummyBlock.
+    Invoked by the offloading hook as callable(submodule).
+    """
+    def __init__(self, pin_targets: Iterable[torch.nn.Module]) -> None:
+        self.pin_targets = set(pin_targets)
+        self.calls_track = []  # testing only
+
+    def __call__(self, submodule: torch.nn.Module) -> bool:
+        self.calls_track.append(submodule)
+        return self._normalize_module_type(submodule) in self.pin_targets
+
+    def _normalize_module_type(self, obj: Any) -> Optional[torch.nn.Module]:
+        # The pin target might be a single module or a container of modules.
+        # The group-offloading code may pass either:
+        # - a single `torch.nn.Module`, or
+        # - a container (list/tuple) of modules.
+ + # Only return a module when the mapping is unambiguous: + # - if `obj` is a module -> return it + # - if `obj` is a list/tuple containing exactly one module -> return that module + # - otherwise -> return None (won't be considered as a target candidate) + if isinstance(obj, torch.nn.Module): + return obj + if isinstance(obj, (list, tuple)): + mods = [m for m in obj if isinstance(m, torch.nn.Module)] + return mods[0] if len(mods) == 1 else None + return None + +class DummyCallableByNameSubmodule(DummyCallableBySubmodule): + """ + Callable group offloading pinner that pins first and last DummyBlock + Same behaviour with DummyCallableBySubmodule, only with different call signature + called in the program by callable(name, submodule) + """ + def __call__(self, name: str, submodule: torch.nn.Module) -> bool: + self.calls_track.append((name, submodule)) + return self._normalize_module_type(submodule) in self.pin_targets + +class DummyCallableByNameSubmoduleIdx(DummyCallableBySubmodule): + """ + Callable group offloading pinner that pins first and last DummyBlock. + Same behaviour with DummyCallableBySubmodule, only with different call signature + Called in the program by callable(name, submodule, idx) + """ + def __call__(self, name: str, submodule: torch.nn.Module, idx: int) -> bool: + self.calls_track.append((name, submodule, idx)) + return self._normalize_module_type(submodule) in self.pin_targets + +class DummyInvalidCallable(DummyCallableBySubmodule): + """ + Callable group offloading pinner that uses invalid call signature + """ + def __call__(self, name: str, submodule: torch.nn.Module, idx: int, extra: Any) -> bool: + self.calls_track.append((name, submodule, idx, extra)) + return self._normalize_module_type(submodule) in self.pin_targets @require_torch_accelerator @@ -362,3 +424,160 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) + + def test_block_level_offloading_with_pin_groups_stay_on_device(self): + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + def assert_all_modules_on_expected_device(modules: Sequence[torch.nn.Module], + expected_device: Union[torch.device, str], + header_error_msg: str = "") -> None: + def first_param_device(modules: torch.nn.Module) -> torch.device: + p = next(modules.parameters(), None) + self.assertIsNotNone(p, f"No parameters found for module {modules}") + return p.device + + if isinstance(expected_device, torch.device): + expected_device = expected_device.type + + bad = [] + for i, m in enumerate(modules): + dev_type = first_param_device(m).type + if dev_type != expected_device: + bad.append((i, m.__class__.__name__, dev_type)) + self.assertTrue( + len(bad) == 0, + (header_error_msg + "\n" if header_error_msg else "") + + f"Expected all modules on {expected_device}, but found mismatches: {bad}", + ) + + def get_param_modules_from_execution_order(model: DummyModel) -> List[torch.nn.Module]: + model.eval() + root_registry = HookRegistry.check_if_exists_or_initialize(model) + + lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading") + self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered") + + #record execution order with first forward + with torch.no_grad(): + model(self.input) + + mods = [m for _, m in lazy_hook.execution_order] + param_modules = [m for m in mods if next(m.parameters(), None) is not None] + return param_modules + + def 
assert_callables_offloading_tests( + param_modules: Sequence[torch.nn.Module], + callable: Any, + header_error_msg: str = "", + ) -> None: + pinned_modules = [m for m in param_modules if m in callable.pin_targets] + unpinned_modules = [m for m in param_modules if m not in callable.pin_targets] + self.assertTrue(len(callable.calls_track) > 0, f"{header_error_msg}: callable should have been called at least once") + assert_all_modules_on_expected_device(pinned_modules, torch_device, f"{header_error_msg}: pinned blocks should stay on device") + assert_all_modules_on_expected_device(unpinned_modules, "cpu", f"{header_error_msg}: unpinned blocks should be offloaded") + + + default_parameters = { + "onload_device": torch_device, + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + } + model_default_no_pin = self.get_model() + model_default_no_pin.enable_group_offload( + **default_parameters + ) + param_modules = get_param_modules_from_execution_order(model_default_no_pin) + assert_all_modules_on_expected_device(param_modules, + expected_device="cpu", + header_error_msg="default pin_groups: expected ALL modules offloaded to CPU") + + model_pin_all = self.get_model() + model_pin_all.enable_group_offload( + **default_parameters, + pin_groups="all", + ) + param_modules = get_param_modules_from_execution_order(model_pin_all) + assert_all_modules_on_expected_device(param_modules, + expected_device=torch_device, + header_error_msg="pin_groups = all: expected ALL layers on accelerator device") + + + model_pin_first_last = self.get_model() + model_pin_first_last.enable_group_offload( + **default_parameters, + pin_groups="first_last", + ) + param_modules = get_param_modules_from_execution_order(model_pin_first_last) + assert_all_modules_on_expected_device([param_modules[0], param_modules[-1]], + expected_device=torch_device, + header_error_msg="pin_groups = first_last: expected first and last layers on accelerator device") + assert_all_modules_on_expected_device(param_modules[1:-1], + expected_device="cpu", + header_error_msg="pin_groups = first_last: expected ALL middle layers offloaded to CPU") + + + model = self.get_model() + callable_by_submodule = DummyCallableBySubmodule(pin_targets=[model.blocks[0], model.blocks[-1]]) + model.enable_group_offload(**default_parameters, + pin_groups=callable_by_submodule) + param_modules = get_param_modules_from_execution_order(model) + assert_callables_offloading_tests(param_modules, + callable_by_submodule, + header_error_msg="pin_groups with callable(submodule)") + + model = self.get_model() + callable_by_name_submodule = DummyCallableByNameSubmodule(pin_targets=[model.blocks[0], model.blocks[-1]]) + model.enable_group_offload(**default_parameters, + pin_groups=callable_by_name_submodule) + param_modules = get_param_modules_from_execution_order(model) + assert_callables_offloading_tests(param_modules, + callable_by_name_submodule, + header_error_msg="pin_groups with callable(name, submodule)") + + model = self.get_model() + callable_by_name_submodule_idx = DummyCallableByNameSubmoduleIdx(pin_targets=[model.blocks[0], model.blocks[-1]]) + model.enable_group_offload(**default_parameters, + pin_groups=callable_by_name_submodule_idx) + param_modules = get_param_modules_from_execution_order(model) + assert_callables_offloading_tests(param_modules, + callable_by_name_submodule_idx, + header_error_msg="pin_groups with callable(name, submodule, idx)") + + def test_error_raised_if_pin_groups_received_invalid_value(self): + default_parameters 
= { + "onload_device": torch_device, + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + } + model = self.get_model() + with self.assertRaisesRegex(ValueError, + "`pin_groups` must be one of `None`, 'first_last', 'all', or a callable."): + model.enable_group_offload( + **default_parameters, + pin_groups="invalid value", + ) + + def test_error_raised_if_pin_groups_received_invalid_callables(self): + default_parameters = { + "onload_device": torch_device, + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + } + model = self.get_model() + invalid_callable = DummyInvalidCallable(pin_targets=[model.blocks[0], model.blocks[-1]]) + model.enable_group_offload( + **default_parameters, + pin_groups=invalid_callable, + ) + with self.assertRaisesRegex(TypeError, + r"missing\s+\d+\s+required\s+positional\s+argument(s)?:"): + with torch.no_grad(): + model(self.input) + + + + \ No newline at end of file
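
Usage sketch (editorial, not part of the patch): how the new `pin_groups` string values are meant to be used through `ModelMixin.enable_group_offload`, based on the signature and docstring added in this diff. The checkpoint id and device are placeholder assumptions. Since pinning is resolved by the lazy prefetch hook, it takes effect after the first forward pass and requires `use_stream=True` for block-level offloading.

```python
import torch

from diffusers import AutoencoderKL

# Placeholder checkpoint; any ModelMixin that supports group offloading works the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # pin_groups is applied by the lazy prefetch hook
    pin_groups="first_last",  # or "all" to keep every parameter-bearing group resident
)

# After the first forward pass records the execution order, the pinned groups stay on
# the onload device; all other groups are still offloaded to CPU between uses.
```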
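A sketch of the callable form of `pin_groups` (editorial, not part of the patch). Per the dispatch added to `LazyPrefetchGroupOffloadingHook`, the callable is tried as `callable(submodule)`, then `callable(name, submodule)`, then `callable(name, submodule, idx)`, so any of these signatures works; the parameter-count threshold below is an arbitrary illustrative policy, not something the diff prescribes.

```python
import torch

from diffusers import AutoencoderKL


def pin_small_groups(submodule: torch.nn.Module) -> bool:
    # Arbitrary example policy: keep small submodules resident on the onload device
    # and let larger ones be offloaded between uses.
    return sum(p.numel() for p in submodule.parameters()) < 1_000_000


vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    pin_groups=pin_small_groups,
)
```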
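A sketch of `block_modules` passed directly to `apply_group_offloading` (editorial, not part of the patch). The module names mirror the `_group_offload_block_modules` attribute this diff adds to `AutoencoderKL` and `AutoencoderKLWan`; `ModelMixin.enable_group_offload` now picks that attribute up automatically, so the explicit list is only needed when calling the hook function yourself. The checkpoint id is a placeholder.

```python
import torch

from diffusers import AutoencoderKL
from diffusers.hooks.group_offloading import apply_group_offloading

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint

# Each listed child is treated as a block: ModuleList/Sequential children are split as
# usual, while other children (e.g. `encoder`, `decoder`) are offloaded as single groups.
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["quant_conv", "post_quant_conv", "encoder", "decoder"],
)
```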