diff --git a/auto_round/auto_scheme/delta_loss.py b/auto_round/auto_scheme/delta_loss.py index 58b2bae3c..8d95828f1 100644 --- a/auto_round/auto_scheme/delta_loss.py +++ b/auto_round/auto_scheme/delta_loss.py @@ -13,12 +13,17 @@ # limitations under the License. import copy +import gc +import os +import shutil +import tempfile from dataclasses import asdict from functools import wraps -from typing import Iterable, Union +from typing import Any, Iterable, Optional, Union import torch from accelerate import dispatch_model +from safetensors.torch import load_file, save_file from tqdm import tqdm from auto_round.auto_scheme.gen_auto_scheme import AutoScheme @@ -55,11 +60,264 @@ set_non_auto_device_map, to_device, ) +from auto_round.utils.device import MemoryMonitor from auto_round.wrapper import WrapperLinear __all__ = ["gen_layer_config"] +def _group_layers_by_block(quant_layer_names, block_names): + """Group quantization layer names by their containing block.""" + groups = {bn: [] for bn in block_names} + non_block = [] + for name in quant_layer_names: + matched = False + for bn in block_names: + if name.startswith(bn + "."): + groups[bn].append(name) + matched = True + break + if not matched: + non_block.append(name) + return groups, non_block + + +# ============================================================================ +# CPU RAM Offload Management for AutoScheme +# ============================================================================ + + +class AutoSchemeOffloadContext: + """Manages disk offload state for AutoScheme to reduce CPU RAM usage. + + Maintains two separate on-disk stores: + - *original* weights (unwrapped, saved once at the start) + - *wrapped* weights (saved per-scheme iteration for forward/backward) + + This allows block weights to stay on disk between scheme iterations, + keeping only one block in RAM at a time. + """ + + def __init__(self, low_cpu_mem_usage: bool = False): + self.low_cpu_mem_usage = low_cpu_mem_usage + # Wrapped-state storage (changes every scheme iteration) + self._offload_tempdir: Optional[str] = None + self._offloaded_blocks: dict[str, dict] = {} + # Original-state storage (saved once, read many) + self._original_dir: Optional[str] = None + self._original_blocks: dict[str, dict] = {} + + # ------------------------------------------------------------------ + # Directory helpers + # ------------------------------------------------------------------ + def init_offload_dir(self) -> Optional[str]: + """Initialize the temporary directory for wrapped-state offload.""" + if not self.low_cpu_mem_usage: + return None + if self._offload_tempdir is None: + self._offload_tempdir = tempfile.mkdtemp(prefix="autoscheme_offload_") + logger.info(f"AutoScheme CPU offload directory: {self._offload_tempdir}") + return self._offload_tempdir + + def _init_original_dir(self) -> str: + """Initialize the temporary directory for original-weight storage.""" + if self._original_dir is None: + self._original_dir = tempfile.mkdtemp(prefix="autoscheme_original_") + return self._original_dir + + # ------------------------------------------------------------------ + # Original (unwrapped) weight management — saved once at start + # ------------------------------------------------------------------ + def save_original_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Save original (unwrapped) block weights to disk. 
Skips if already saved.""" + if not self.low_cpu_mem_usage: + return + if block_name in self._original_blocks: + return + orig_dir = self._init_original_dir() + safe_name = block_name.replace(".", "_") + save_path = os.path.join(orig_dir, f"{safe_name}.safetensors") + try: + state_dict = {k: v.cpu().contiguous() for k, v in block.state_dict().items()} + save_file(state_dict, save_path) + self._original_blocks[block_name] = {"save_path": save_path} + del state_dict + except Exception as e: + logger.warning(f"Failed to save original block {block_name}: {e}") + + def _load_state_into_block(self, save_path: str, block: torch.nn.Module) -> None: + """Low-level helper: load a safetensors file into *block*.""" + state_dict = load_file(save_path, device="cpu") + for name, param in state_dict.items(): + parts = name.split(".") + target = block + try: + for part in parts[:-1]: + target = getattr(target, part) + except AttributeError: + continue # key belongs to a different module tree (e.g. wrapper vs orig) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + del state_dict + + def load_original_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Load original (unwrapped) weights from disk into *block*.""" + if not self.low_cpu_mem_usage: + return + metadata = self._original_blocks.get(block_name) + if not metadata: + return + save_path = metadata["save_path"] + if not os.path.exists(save_path): + logger.warning(f"Original weights not found: {save_path}") + return + try: + self._load_state_into_block(save_path, block) + except Exception as e: + logger.warning(f"Failed to load original block {block_name}: {e}") + + def save_and_clear_all_original_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: + """Save all original block weights to disk and clear them from RAM.""" + if not self.low_cpu_mem_usage: + return + logger.info("AutoScheme: saving original block weights to disk...") + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.save_original_block_weights(block_name, block) + for submodule in block.modules(): + _clear_module_weights(submodule) + gc.collect() + clear_memory() + logger.info("AutoScheme: original weights saved and cleared") + + def load_all_original_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: + """Load all original block weights back into RAM.""" + if not self.low_cpu_mem_usage: + return + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.load_original_block_weights(block_name, block) + + # ------------------------------------------------------------------ + # Wrapped-state management — re-saved each scheme iteration + # ------------------------------------------------------------------ + def offload_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Offload a block's (possibly wrapped) weights to disk.""" + if not self.low_cpu_mem_usage: + return + offload_dir = self.init_offload_dir() + if offload_dir is None: + return + + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.safetensors") + + already_saved = block_name in self._offloaded_blocks + + try: + if not already_saved: + state_dict = {k: v.cpu().contiguous() 
for k, v in block.state_dict().items()} + save_file(state_dict, save_path) + self._offloaded_blocks[block_name] = {"save_path": save_path} + del state_dict + + for submodule in block.modules(): + _clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to offload block {block_name}: {e}") + + def load_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Load wrapped block weights from disk back into memory.""" + if not self.low_cpu_mem_usage: + return + metadata = self._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + self._load_state_into_block(save_path, block) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + def offload_all_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: + """Offload all block weights to disk.""" + if not self.low_cpu_mem_usage: + return + logger.info("AutoScheme: offloading all block weights to disk for RAM optimization...") + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.offload_block_weights(block_name, block) + gc.collect() + clear_memory() + logger.info("AutoScheme: block weights offload complete") + + def reset_scheme_state(self) -> None: + """Clear wrapped-state tracking so the next scheme iteration re-saves.""" + self._offloaded_blocks = {} + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + def cleanup(self) -> None: + """Clean up all temporary directories.""" + for d in (self._offload_tempdir, self._original_dir): + if d and os.path.isdir(d): + try: + shutil.rmtree(d) + except Exception as e: + logger.warning(f"Failed to cleanup dir {d}: {e}") + self._offload_tempdir = None + self._offloaded_blocks = {} + self._original_dir = None + self._original_blocks = {} + + +def _clear_module_weights(module: torch.nn.Module) -> None: + """Clear module's weight and bias to free CPU RAM. + + Note: Skips WrapperLayer modules since their weight/bias are properties + that delegate to orig_layer. Clearing the actual orig_layer is sufficient. + Caches weight.numel() as ``_cached_weight_numel`` before clearing so that + ``compute_layer_bits`` can still compute correct results with empty tensors. + """ + if module is None: + return + # Skip WrapperLayer - its weight is a property, assigning to it would create + # an instance attribute shadowing the property. We clear orig_layer directly instead. 
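+    #
+    # Rough usage sketch (illustrative only, not executed here): after clearing,
+    # the weight tensor is empty but its original size stays recoverable through
+    # the cached attributes, which compute_layer_bits falls back to:
+    #   lin = torch.nn.Linear(4096, 4096)              # ~64 MB of fp32 weights
+    #   _clear_module_weights(lin)
+    #   assert lin.weight.numel() == 0                  # RAM released
+    #   assert lin._cached_weight_numel == 4096 * 4096  # still usable for bit accounting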
+ if hasattr(module, "orig_layer"): + return + with torch.no_grad(): + if hasattr(module, "weight") and module.weight is not None: + # Cache numel / shape before replacing with empty tensor + if module.weight.numel() > 0: + module._cached_weight_numel = module.weight.numel() + module._cached_weight_shape = tuple(module.weight.shape) + if isinstance(module.weight, torch.nn.Parameter): + module.weight = torch.nn.Parameter( + torch.empty(0, dtype=module.weight.dtype, device="cpu"), requires_grad=module.weight.requires_grad + ) + else: + module.weight = torch.empty(0, dtype=module.weight.dtype, device="cpu") + if hasattr(module, "bias") and module.bias is not None: + if isinstance(module.bias, torch.nn.Parameter): + module.bias = torch.nn.Parameter( + torch.empty(0, dtype=module.bias.dtype, device="cpu"), requires_grad=module.bias.requires_grad + ) + else: + module.bias = torch.empty(0, dtype=module.bias.dtype, device="cpu") + + class AutoSchemeWrapperLinear(WrapperLinear): def __init__( self, @@ -322,7 +580,13 @@ def __init__(self, message): last_grad_input = None -def prepare_model_low_gpu(model, block_inputs: dict = None, pbar=None, major_device="cpu"): +def prepare_model_low_gpu( + model, + block_inputs: dict = None, + pbar=None, + major_device="cpu", + offload_context: Optional[AutoSchemeOffloadContext] = None, +): block_inputs.clear() for n, m in model.named_modules(): if hasattr(m, "grad_mode"): @@ -335,6 +599,10 @@ def wrap_forward(module, module_name): @wraps(original_forward) def new_forward(*args, **kwargs): + # Load block weights from disk if using low_cpu_mem_usage + if offload_context is not None: + offload_context.load_block_weights(module_name, module) + move_module_to_tuning_device(module, major_device=major_device) # Call the original forward @@ -353,6 +621,10 @@ def new_forward(*args, **kwargs): module.to("cpu") # torch.cuda.empty_cache() + # Offload block weights back to disk if using low_cpu_mem_usage + if offload_context is not None: + offload_context.offload_block_weights(module_name, module) + # Enable gradients for the output of the last block if module.tmp_name == block_names[-1]: if isinstance(result, torch.Tensor): @@ -375,7 +647,13 @@ def new_forward(*args, **kwargs): module.forward = wrap_forward(module, block_name) -def model_forward_low_gpu(model, dataloader, major_device="cuda", pbar=None): +def model_forward_low_gpu( + model, + dataloader, + major_device="cuda", + pbar=None, + offload_context: Optional[AutoSchemeOffloadContext] = None, +): block_inputs = {} block_names = get_block_names(model)[0] @@ -392,7 +670,9 @@ def backward_pre_hook(module, grad_input): raise MyCustomError("Interrupt backward pass") for data in dataloader: - prepare_model_low_gpu(model, block_inputs, major_device=major_device, pbar=pbar) + prepare_model_low_gpu( + model, block_inputs, major_device=major_device, pbar=pbar, offload_context=offload_context + ) # Register backward hook on the last block last_block = get_module(model, block_names[-1]) @@ -427,6 +707,11 @@ def backward_pre_hook(module, grad_input): # Move the block module to GPU block_module = get_module(model, block_name) + + # Load block weights from disk if offloaded (for backward pass) + if offload_context is not None: + offload_context.load_block_weights(block_name, block_module) + for n, m in block_module.named_modules(): if hasattr(m, "grad_mode"): m.grad_mode = True @@ -468,6 +753,11 @@ def backward_pre_hook(module, grad_input): block_module.to("cpu") # clear_memory() + + # Offload block weights to disk if 
low_cpu_mem_usage is enabled (after backward) + if offload_context is not None: + offload_context.offload_block_weights(block_name, block_module) + pbar.update(1) @@ -485,9 +775,11 @@ def get_score_for_scheme( need_weight_grad=False, enable_torch_compile=False, low_gpu_mem_usage=True, + low_cpu_mem_usage=False, major_device="cpu", batch_size=1, disable_opt_rtn=True, + offload_context: Optional[AutoSchemeOffloadContext] = None, ): scores_dict = {} # Key=name,Val=[quant_total_bits, loss] for n, m in model.named_modules(): @@ -505,14 +797,14 @@ def get_score_for_scheme( has_imatrix = True break - for name in quant_layer_names: + def wrap_layer(name: str) -> None: if name in fixed_layer_scheme.keys(): - continue + return m = get_module(model, name) if not check_to_quantized(m): layer_bits, _ = compute_layer_bits(m, ignore_scale_zp_bits) scores_dict[name] = [layer_bits, 0.0] - continue + return if m.act_bits > 8 and m.super_bits is not None: m.scale_dtype = torch.float32 # TODO set this via API elif m.act_bits > 8: @@ -547,11 +839,43 @@ def get_score_for_scheme( disable_opt_rtn=disable_opt_rtn, ) set_module(model, name, new_m) + + # ------------------------------------------------------------------ + # Wrapping + forward/backward + # ------------------------------------------------------------------ if low_gpu_mem_usage: dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset, seed=42, bs=batch_size, nsamples=nsamples) - model_forward_low_gpu(model, dataloader, major_device=major_device, pbar=pbar) + if offload_context is not None and low_cpu_mem_usage: + # Block-by-block wrapping: load original weights -> wrap -> offload wrapped state + block_names = get_block_names(model)[0] + layer_groups, non_block_layers = _group_layers_by_block(quant_layer_names, block_names) + offload_context.reset_scheme_state() + + for block_name in block_names: + block = get_module(model, block_name) + offload_context.load_original_block_weights(block_name, block) + for name in layer_groups.get(block_name, []): + wrap_layer(name) + offload_context.offload_block_weights(block_name, block) + + # Wrap layers that live outside of blocks (e.g. 
lm_head) + for name in non_block_layers: + wrap_layer(name) + + gc.collect() + clear_memory() + else: + for name in quant_layer_names: + wrap_layer(name) + + model_forward_low_gpu(model, dataloader, major_device=major_device, pbar=pbar, offload_context=offload_context) + + # NOTE: do NOT load all blocks back — scores are read block-by-block below else: + for name in quant_layer_names: + wrap_layer(name) + dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset, seed=42, bs=batch_size, nsamples=nsamples) for data in dataloader: data = to_device(data, model.device) @@ -565,19 +889,72 @@ def get_score_for_scheme( for n, m in model.named_parameters(): m.grad = None - scores_dict = {} - for n, m in model.named_modules(): - if hasattr(m, "mix_score"): - if m.orig_layer.act_bits <= 8: - if m.act_cnt == 0: + # ------------------------------------------------------------------ + # Score reading + unwrapping + # ------------------------------------------------------------------ + if offload_context is not None and low_cpu_mem_usage: + # Block-by-block: load wrapped state -> read scores -> unwrap -> clear + scores_dict = {} + for block_name in block_names: + block = get_module(model, block_name) + offload_context.load_block_weights(block_name, block) + + # Read scores from wrapper attributes in this block + for n, m in block.named_modules(): + full_name = f"{block_name}.{n}" if n else block_name + if hasattr(m, "mix_score"): + if m.orig_layer.act_bits <= 8: + if m.act_cnt == 0: + logger.warning_once( + f"layer {full_name} max abs activation is 0, " + "please use more data to improve the accuracy" + ) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) + scores_dict[full_name] = [layer_bits, m.mix_score] + + # Unwrap layers in this block + unwrap_pairs = [] + for n, m in block.named_modules(): + full_name = f"{block_name}.{n}" if n else block_name + if hasattr(m, "orig_layer"): + unwrap_pairs.append((full_name, m.orig_layer)) + for full_name, orig_layer in unwrap_pairs: + set_module(model, full_name, orig_layer) + + # Clear weights so this block no longer occupies RAM + block = get_module(model, block_name) + for submodule in block.modules(): + _clear_module_weights(submodule) + + # Handle non-block layers + for n, m in model.named_modules(): + if hasattr(m, "mix_score") and n not in scores_dict: + if m.orig_layer.act_bits <= 8 and m.act_cnt == 0: logger.warning_once( - "layer{n} max abs activation is 0, please use more data to improve the accuracy" + f"layer {n} max abs activation is 0, " "please use more data to improve the accuracy" ) - layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) - scores_dict[n] = [layer_bits, m.mix_score] - for n, m in model.named_modules(): - if hasattr(m, "orig_layer"): - set_module(model, n, m.orig_layer) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) + scores_dict[n] = [layer_bits, m.mix_score] + for n, m in model.named_modules(): + if hasattr(m, "orig_layer"): + set_module(model, n, m.orig_layer) + + gc.collect() + clear_memory() + else: + scores_dict = {} + for n, m in model.named_modules(): + if hasattr(m, "mix_score"): + if m.orig_layer.act_bits <= 8: + if m.act_cnt == 0: + logger.warning_once( + f"layer {n} max abs activation is 0, " "please use more data to improve the accuracy" + ) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) + scores_dict[n] = [layer_bits, m.mix_score] + 
for n, m in model.named_modules(): + if hasattr(m, "orig_layer"): + set_module(model, n, m.orig_layer) return scores_dict @@ -593,44 +970,72 @@ def choose_bits_per_layer_with_path(layers: dict, P: int): (layer_names, scheme) for each layer, or (None, None) if no feasible solution exists. """ - # dp: total_params -> (accumulated_loss, chosen_path) - # The path explicitly stores the selected options. - dp: dict[int, tuple[float, list]] = {0: (0.0, [])} - - for layer_name, opts in layers.items(): - new_dp: dict[int, tuple[float, list]] = {} - for cur_params, (cur_loss, cur_path) in dp.items(): - for opt in opts: + # dp: total_params -> accumulated_loss + # Use backtracking pointers instead of storing full paths to avoid + # O(layers) list copies per state transition. + dp: dict[int, float] = {0: 0.0} + + layer_list = list(layers.items()) + total_layers = len(layer_list) + logger.info(f"Starting DP for {total_layers} layers, budget P={P}") + + # history[layer_idx][params] = (prev_params, opt_idx) for path reconstruction + history: list[dict[int, tuple[int, int]]] = [] + + pbar = tqdm(range(total_layers), desc="DP bit allocation", leave=True) + for idx in pbar: + layer_name, opts = layer_list[idx] + pbar.set_postfix(dp_states=len(dp)) + + new_dp: dict[int, float] = {} + new_bt: dict[int, tuple[int, int]] = {} + + for cur_params, cur_loss in dp.items(): + for opt_idx, opt in enumerate(opts): scheme, bits_cost, loss_cost, layer_names = opt np_total = cur_params + bits_cost if np_total > P: continue new_loss = cur_loss + loss_cost - new_path = cur_path + [(layer_names, scheme)] - # Keep the path with smaller loss for the same parameter budget - if np_total not in new_dp or new_loss < new_dp[np_total][0]: - new_dp[np_total] = (new_loss, new_path) + # Keep the option with smaller loss for the same parameter budget + if np_total not in new_dp or new_loss < new_dp[np_total]: + new_dp[np_total] = new_loss + new_bt[np_total] = (cur_params, opt_idx) if not new_dp: return None, None # Pareto pruning: remove dominated (params, loss) states - items = sorted(new_dp.items(), key=lambda x: x[0]) # (params, (loss, path)) - pruned: dict[int, tuple[float, list]] = {} + items = sorted(new_dp.items(), key=lambda x: x[0]) # sort by params + pruned_dp: dict[int, float] = {} + pruned_bt: dict[int, tuple[int, int]] = {} best_loss_so_far = float("inf") - for params_val, (loss_val, path_val) in items: + for params_val, loss_val in items: if loss_val < best_loss_so_far: - pruned[params_val] = (loss_val, path_val) + pruned_dp[params_val] = loss_val + pruned_bt[params_val] = new_bt[params_val] best_loss_so_far = loss_val - dp = pruned + dp = pruned_dp + history.append(pruned_bt) # Select the solution with the minimum loss - best_params = min(dp.keys(), key=lambda k: dp[k][0]) - best_loss, best_path = dp[best_params] - return best_loss, best_path + best_params = min(dp.keys(), key=lambda k: dp[k]) + best_loss = dp[best_params] + + # Backtrack to reconstruct the chosen path + path = [] + cur_params = best_params + for layer_idx in range(total_layers - 1, -1, -1): + prev_params, opt_idx = history[layer_idx][cur_params] + scheme, bits_cost, loss_cost, layer_names = layer_list[layer_idx][1][opt_idx] + path.append((layer_names, scheme)) + cur_params = prev_params + path.reverse() + + return best_loss, path def move_module_to_tuning_device(module, major_device="cpu"): @@ -658,6 +1063,18 @@ def _gen_layer_config( major_device="cpu", device_list=None, ): + # Initialize memory tracking for AutoScheme + memory_monitor = 
MemoryMonitor() + memory_monitor.reset() + memory_monitor.update_cpu() + + # Create offload context for CPU RAM optimization + # Note: low_cpu_mem_usage only works when low_gpu_mem_usage is also enabled, + # because disk offloading requires layer-by-layer processing + offload_context = None + if auto_scheme.low_cpu_mem_usage and auto_scheme.low_gpu_mem_usage: + offload_context = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + target_bits = auto_scheme.avg_bits model.eval() @@ -762,6 +1179,11 @@ def check_bf16_scheme(scheme): cal_imatrix(model, dataloader) logger.info("finish calculating imatrix") + # Offload all original block weights to disk before the scheme loop + # so that only one block needs to be in RAM at a time during scoring. + if offload_context is not None: + offload_context.save_and_clear_all_original_blocks(model, block_name) + pbar = tqdm(total=pbar_cnt, desc="Generating AutoScheme") for index, scheme in enumerate(schemes): apply_quant_scheme( @@ -789,10 +1211,15 @@ def check_bf16_scheme(scheme): need_weight_grad=need_weight_grad, enable_torch_compile=enable_torch_compile, low_gpu_mem_usage=auto_scheme.low_gpu_mem_usage, + low_cpu_mem_usage=auto_scheme.low_cpu_mem_usage, major_device=major_device, batch_size=batch_size, disable_opt_rtn=auto_scheme.disable_opt_rtn, + offload_context=offload_context, ) + # Track peak RAM after each scheme scoring + memory_monitor.update_cpu() + new_scores = {} for share_layer in shared_layers: param_bits = 0 @@ -817,10 +1244,17 @@ def check_bf16_scheme(scheme): options_scores.append(options_total_loss) clear_memory(device_list=device_list) + # Restore original weights from disk for final bit-budget computations + if offload_context is not None: + offload_context.load_all_original_blocks(model, block_name) + total_params = 0 for n, m in model.named_modules(): if n in quant_layer_names + embedding_layers_names: - total_params += m.weight.numel() + n_param = m.weight.numel() + if n_param == 0 and hasattr(m, "_cached_weight_numel"): + n_param = m._cached_weight_numel + total_params += n_param target_params_cnt = int(total_params * target_bits) sorted_indices = sorted(range(len(options_scores)), key=lambda i: options_scores[i]) @@ -901,7 +1335,18 @@ def check_bf16_scheme(scheme): m.grad = None global last_grad_input last_grad_input = None + + # Cleanup offload context + if offload_context is not None: + offload_context.cleanup() + clear_memory(device_list=device_list) + + # Log AutoScheme memory usage + memory_monitor.update_cpu() + low_cpu_str = "enabled" if auto_scheme.low_cpu_mem_usage else "disabled" + memory_monitor.log_summary(f"AutoScheme complete (low_cpu_mem_usage={low_cpu_str})") + pbar.close() return layer_config diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py index 8572a920a..d9d65c645 100644 --- a/auto_round/auto_scheme/gen_auto_scheme.py +++ b/auto_round/auto_scheme/gen_auto_scheme.py @@ -40,6 +40,7 @@ class AutoScheme: enable_torch_compile: Optional[bool] = None disable_opt_rtn: bool = True low_gpu_mem_usage: bool = True + low_cpu_mem_usage: bool = False # Enable disk offload for CPU RAM optimization def __post_init__(self): if isinstance(self.options, str): diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py index 3c19acc42..57ac3c96a 100644 --- a/auto_round/auto_scheme/utils.py +++ b/auto_round/auto_scheme/utils.py @@ -105,7 +105,10 @@ def compute_avg_bits_for_scheme( # continue if not hasattr(module, "weight"): continue - total_params += 
module.weight.numel() + n_param = module.weight.numel() + if n_param == 0 and hasattr(module, "_cached_weight_numel"): + n_param = module._cached_weight_numel + total_params += n_param layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) total_quantized_bits += layer_bits avg_bits = float(total_quantized_bits) / total_params @@ -133,7 +136,10 @@ def compute_avg_bits_for_model(model: torch.nn.Module, ignore_scale_zp_bits: boo continue if not hasattr(module, "weight"): continue - total_params += module.weight.numel() + n_param = module.weight.numel() + if n_param == 0 and hasattr(module, "_cached_weight_numel"): + n_param = module._cached_weight_numel + total_params += n_param layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) total_quantized_bits += layer_bits @@ -157,6 +163,9 @@ def compute_layer_bits( """ weight = layer.weight n_param = weight.numel() + # Use cached numel when weight has been cleared to an empty tensor (low_cpu_mem_usage offload) + if n_param == 0 and hasattr(layer, "_cached_weight_numel"): + n_param = layer._cached_weight_numel weight_bits = getattr(layer, "bits", 16) group_size = getattr(layer, "group_size", 128) data_type = getattr(layer, "data_type", "int") diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 974a055cd..da79de56e 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -14,6 +14,7 @@ import copy import os +import re import sys import time import traceback @@ -94,6 +95,7 @@ is_moe_model, is_quantized_input_module, llm_load_model, + load_module_weights, memory_monitor, mv_module_from_gpu, set_amax_for_all_moe_layers, @@ -104,12 +106,25 @@ ) from auto_round.utils.device import ( clear_memory_if_reached_threshold, + estimate_inputs_size_gb, + estimate_model_size_gb, + estimate_tensor_size_gb, get_major_device, parse_available_devices, set_auto_device_map_for_block_with_tuning, set_non_auto_device_map, ) from auto_round.utils.distributed import setup_ddp_if_needed_ +from auto_round.utils.model import ( + cleanup_cpu_offload_dir, + discard_offloaded_block, + estimate_block_size_gb, + init_cpu_offload_dir, + load_offloaded_block_weights, + offload_block_weights, + restore_offloaded_blocks, + stream_offload_blocks, +) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block SERIALIZATION_KEYS = ( @@ -356,6 +371,10 @@ def __init__( self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES self.scale_dtype = convert_dtype_str2torch(scale_dtype) self.low_cpu_mem_usage = low_cpu_mem_usage + self.cpu_stream_offload_blocks = kwargs.pop("cpu_stream_offload_blocks", False) + self.cpu_stream_loss = kwargs.pop("cpu_stream_loss", False) + self._cpu_offload_tempdir = None + self._offloaded_blocks = {} if kwargs: logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. 
Please check them.") @@ -952,7 +971,6 @@ def quantize_and_save( output_dir, format=self.formats, inplace=inplace, return_folders=True, **kwargs ) memory_monitor.log_summary() - return model, folders def _get_save_folder_name(self, format: OutputFormat) -> str: @@ -1534,6 +1552,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) if not self.is_immediate_saving: # some modules may have been flushed and set to meta, so we could not move to gpu mv_module_from_gpu(block) + if self.low_cpu_mem_usage: + offload_block_weights(self, block_name, block) if block_name == block_names[-1]: clear_memory(input_ids, device_list=self.device_list) else: @@ -1700,6 +1720,9 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) + if self.low_cpu_mem_usage: + self._offloaded_blocks = {} + def _should_disable_inplace_due_to_layers_outside_block() -> bool: return self.has_qlayer_outside_block and (self.iters != 0 or (self.iters == 0 and not self.disable_opt_rtn)) @@ -1737,9 +1760,14 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: ) else: logger.info("start to cache block inputs") + all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) is_quantized_embedding = self._quantize_embedding_layer() clear_memory(device_list=self.device_list) + # Log memory breakdown for calibration inputs + if self.low_cpu_mem_usage: + inputs_size_gb = estimate_inputs_size_gb(all_inputs) + logger.info(f"[Memory] calibration inputs size: {inputs_size_gb:.2f} GB") all_q_inputs = None if is_quantized_embedding: all_inputs = copy.deepcopy(self.inputs) @@ -1752,6 +1780,12 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: accelerate.hooks.remove_hook_from_submodules(self.model) # self.model.hf_device_map has not been changed logger.info("caching done") + # Log memory breakdown for model weights + if self.low_cpu_mem_usage: + model_size_gb = estimate_model_size_gb(self.model) + logger.info(f"[Memory] model weights size: {model_size_gb:.2f} GB") + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + stream_offload_blocks(self, all_blocks) if len(all_blocks) > 1: pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) else: @@ -1797,6 +1831,10 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: if self.is_immediate_saving: shard_writer(self, is_finalize=True) + if self.low_cpu_mem_usage: + restore_offloaded_blocks(self) + cleanup_cpu_offload_dir(self) + end_time = time.time() cost_time = end_time - start_time logger.info(f"quantization tuning time {cost_time}") @@ -1877,6 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: memory_monitor.update() memory_monitor.log_summary() return + q_layer_inputs = None enable_quanted_input = self.enable_quanted_input has_gguf = False @@ -2706,6 +2745,29 @@ def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) -> current_output = torch.cat(current_output, dim=self.batch_dim) return current_output + def _get_current_output_stream( + self, + block: torch.nn.Module, + input_ids: list[torch.Tensor], + input_others: dict, + indices: list[int], + device: str, + cache_device: str = "cpu", + ) -> torch.Tensor: + current_input_ids, current_input_others = self._sampling_inputs( + input_ids, + input_others, + indices, + 
seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.shared_cache_keys, + ) + with torch.no_grad(): + output = self.block_forward( + block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device + ) + return output.to(cache_device) + def _get_current_q_output( self, block: torch.nn.Module, @@ -2841,24 +2903,55 @@ def _quantize_block( hook = AlignDevicesHook(m.tuning_device, io_same_device=True) add_hook_to_module(m, hook, True) + stream_loss = self.cpu_stream_loss and self.nblocks == 1 + if self.cpu_stream_loss and self.nblocks != 1: + logger.warning("cpu_stream_loss only supports nblocks=1; falling back to cached outputs.") + stream_loss = False + if q_input is None: hook_handles = self._register_act_max_hook(block) - output = self._get_block_outputs( - block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device - ) + if stream_loss: + output = None + self._get_block_outputs( + block, + input_ids, + input_others, + self.batch_size * self.infer_bs_coeff, + device, + self.cache_device, + save_output=False, + ) + else: + output = self._get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device + ) + + # Log output cache size for first block + if ( + self.low_cpu_mem_usage + and self.cpu_stream_offload_blocks + and not hasattr(self, "_logged_output_size") + ): + output_size = estimate_tensor_size_gb(output) + logger.info(f"[Memory] block output cache size: {output_size:.2f} GB") + self._logged_output_size = True for handle in hook_handles: handle.remove() else: - output = self._get_block_outputs( - block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device - ) + if stream_loss: + output = None + # Skip pre-computation in stream_loss mode - targets will be computed on-the-fly with frozen_block + else: + output = self._get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device + ) hook_handles = self._register_act_max_hook(block) if hook_handles: self._get_block_outputs( block, - q_input, + q_input if q_input is not None else input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, @@ -2876,6 +2969,13 @@ def _quantize_block( clear_memory(device_list=self.device_list) input_ids = q_input + frozen_block = None + if stream_loss: + frozen_block = copy.deepcopy(block).to(device) + frozen_block.eval() + for p in frozen_block.parameters(): + p.requires_grad_(False) + quantized_layer_names, unquantized_layer_names = self.wrapper_block( block, self.enable_minmax_tuning, @@ -2972,7 +3072,12 @@ def _quantize_block( for tmp_step in range(self.gradient_accumulate_steps): indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size] - current_output = self._get_current_output(output, indices) + if stream_loss: + current_output = self._get_current_output_stream( + frozen_block, input_ids, input_others, indices, loss_device, cache_device=loss_device + ) + else: + current_output = self._get_current_output(output, indices) current_output = to_device(current_output, loss_device) output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device) loss = self._get_loss(output_q, current_output, indices, mse_loss, device) @@ -3055,6 +3160,16 @@ def _quantize_block( return q_outputs, output else: + # When stream_loss is enabled, output is None, so we need to compute it for the next block + if stream_loss and output is None: + 
output = self._get_block_outputs( + block, + input_ids, + input_others, + self.batch_size * self.infer_bs_coeff, + device, + cache_device=self.cache_device, + ) if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) if auto_offload: @@ -3119,6 +3234,14 @@ def _quantize_blocks( input_ids, input_others = self._preprocess_block_inputs(inputs) + # Log detailed memory breakdown for first block + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + input_ids_size = estimate_tensor_size_gb(input_ids) + input_others_size = estimate_tensor_size_gb(input_others) + logger.info( + f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB" + ) + if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) @@ -3135,7 +3258,18 @@ def _quantize_blocks( modules = [get_module(model, n) for n in names] m = WrapperMultiblock(modules) + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + if nblocks == 1: + load_offloaded_block_weights(self, n, get_module(model, n)) + if i == 0: # Log only for first block + block_size = estimate_block_size_gb(get_module(model, n)) + logger.info(f"[Memory] loaded block weights size: {block_size:.2f} GB") + else: + for name in names: + load_offloaded_block_weights(self, name, get_module(model, name)) + m.config = model.config if hasattr(model, "config") else None + q_input, input_ids = self._quantize_block( m, input_ids, @@ -3143,6 +3277,19 @@ def _quantize_blocks( q_input=q_input, device=device, ) + + if self.low_cpu_mem_usage and not self.cpu_stream_offload_blocks: + if nblocks == 1: + offload_block_weights(self, n, get_module(model, n)) + else: + for name in names: + offload_block_weights(self, name, get_module(model, name)) + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + if nblocks == 1: + discard_offloaded_block(self, n) + else: + for name in names: + discard_offloaded_block(self, name) if hasattr(model, "config"): del m.config if self.is_immediate_packing: diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index a16f441bf..2b186529b 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -19,14 +19,19 @@ from functools import lru_cache from itertools import combinations from threading import Lock -from typing import Callable, Union +from typing import Any, Callable, Optional, Union import cpuinfo import psutil import torch from auto_round.logger import logger -from auto_round.utils.model import check_to_quantized, get_block_names, get_layer_features, get_module +from auto_round.utils.model import ( + check_to_quantized, + get_block_names, + get_layer_features, + get_module, +) # Note on HPU usage: # There are two modes available for enabling auto-round on HPU: @@ -501,6 +506,36 @@ def __call__( clear_memory = torch._dynamo.disable()(ClearMemory(device_list=[0])) +def estimate_tensor_size_gb(tensor: Any) -> float: + """Estimate the size of a tensor (or nested tensors) in GB.""" + if tensor is None: + return 0.0 + if isinstance(tensor, torch.Tensor): + return tensor.numel() * tensor.element_size() / (1024**3) + if isinstance(tensor, list): + return sum(estimate_tensor_size_gb(t) for t in tensor) + if isinstance(tensor, dict): + return sum(estimate_tensor_size_gb(v) for v in tensor.values()) + return 0.0 + + +def estimate_inputs_size_gb(all_inputs: dict) -> float: + """Estimate the total size of calibration inputs in GB.""" + total = 0.0 + for _, inputs in all_inputs.items(): + total += 
estimate_tensor_size_gb(inputs) + return total + + +def estimate_model_size_gb(model: torch.nn.Module) -> float: + """Estimate the model weights size in GB.""" + total = 0.0 + for param in model.parameters(): + if param.numel() > 0: # Skip empty tensors + total += param.numel() * param.element_size() / (1024**3) + return total + + def clear_memory_if_reached_threshold(threshold=0.85, device_list=None): """Check all available devices and clear memory if any device is using close to the threshold. diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 94fb7198b..67335163d 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -15,9 +15,11 @@ import json import os import re +import shutil +import tempfile from collections import UserDict from pathlib import Path -from typing import Union +from typing import Any, Optional, Union import psutil import torch @@ -63,6 +65,197 @@ def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None: param.requires_grad = False +def save_module_weights(module: torch.nn.Module, save_path: str) -> dict: + """Save module's weight and bias tensors to disk to reduce CPU RAM usage. + + This function saves the weight and bias tensors of a module to a specified path on disk. + It returns metadata about the saved tensors that can be used later to reload them. + + Args: + module (torch.nn.Module): The module whose weights should be saved. + save_path (str): Path where the weights should be saved. + This should be a unique path for each module. + + Returns: + dict: Metadata containing information about the saved tensors, including: + - 'has_weight': bool indicating if weight was saved + - 'has_bias': bool indicating if bias was saved + - 'weight_shape': shape of the weight tensor + - 'bias_shape': shape of the bias tensor (if exists) + - 'weight_dtype': dtype of the weight tensor + - 'bias_dtype': dtype of the bias tensor (if exists) + - 'weight_device': original device of the weight tensor + - 'bias_device': original device of the bias tensor (if exists) + - 'save_path': the path where tensors were saved + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> # Now module's weights can be cleared to save RAM + >>> clear_module_weights(module) + >>> # Later, weights can be restored + >>> load_module_weights(module, metadata) + """ + if module is None: + return {} + + metadata = {"save_path": save_path} + tensors_to_save = {} + + # Save weight if it exists + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight + if weight.device.type != "meta" and weight.numel() > 0: + tensors_to_save["weight"] = weight.detach().cpu() + metadata["has_weight"] = True + metadata["weight_shape"] = tuple(weight.shape) + metadata["weight_dtype"] = weight.dtype + metadata["weight_device"] = str(weight.device) + else: + metadata["has_weight"] = False + else: + metadata["has_weight"] = False + + # Save bias if it exists + if hasattr(module, "bias") and module.bias is not None: + bias = module.bias + if bias.device.type != "meta" and bias.numel() > 0: + tensors_to_save["bias"] = bias.detach().cpu() + metadata["has_bias"] = True + metadata["bias_shape"] = tuple(bias.shape) + metadata["bias_dtype"] = bias.dtype + metadata["bias_device"] = str(bias.device) + else: + metadata["has_bias"] = False + else: + metadata["has_bias"] = False + + # Save to disk + if tensors_to_save: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + 
torch.save(tensors_to_save, save_path) + + return metadata + + +def load_module_weights(module: torch.nn.Module, metadata: dict) -> None: + """Load module's weight and bias tensors from disk. + + This function reloads weights that were previously saved using save_module_weights(). + The weights are loaded back to their original device and dtype. + + Args: + module (torch.nn.Module): The module whose weights should be restored. + metadata (dict): Metadata returned by save_module_weights(), containing + information about the saved tensors and their path. + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> clear_module_weights(module) + >>> # ... do some work with reduced memory ... + >>> load_module_weights(module, metadata) + >>> # Now module's weights are restored + """ + if module is None or not metadata or "save_path" not in metadata: + return + + save_path = metadata["save_path"] + if not os.path.exists(save_path): + logger.warning(f"Cannot load weights: file {save_path} does not exist") + return + + # Load tensors from disk + try: + tensors = torch.load(save_path, map_location="cpu") + except Exception as e: + logger.warning(f"Failed to load weights from {save_path}: {e}") + return + + # Restore weight + if metadata.get("has_weight", False) and "weight" in tensors: + weight = tensors["weight"] + target_device = metadata.get("weight_device", "cpu") + target_dtype = metadata.get("weight_dtype", weight.dtype) + + # Move to target device and dtype + weight = weight.to(device=target_device, dtype=target_dtype) + + # Set the weight back to the module + if hasattr(module, "weight"): + if isinstance(module.weight, torch.nn.Parameter): + module.weight = torch.nn.Parameter(weight, requires_grad=module.weight.requires_grad) + else: + module.weight = weight + + # Restore bias + if metadata.get("has_bias", False) and "bias" in tensors: + bias = tensors["bias"] + target_device = metadata.get("bias_device", "cpu") + target_dtype = metadata.get("bias_dtype", bias.dtype) + + # Move to target device and dtype + bias = bias.to(device=target_device, dtype=target_dtype) + + # Set the bias back to the module + if hasattr(module, "bias"): + if isinstance(module.bias, torch.nn.Parameter): + module.bias = torch.nn.Parameter(bias, requires_grad=module.bias.requires_grad) + else: + module.bias = bias + + +def clear_module_weights(module: torch.nn.Module, to_meta: bool = False) -> None: + """Clear module's weight and bias to free CPU RAM. + + This function clears the weight and bias of a module to reduce memory usage. + It can either set them to empty tensors (default) or move them to meta device. + + Args: + module (torch.nn.Module): The module whose weights should be cleared. + to_meta (bool): If True, move tensors to meta device. + If False, set them to empty tensors. Default is False. + + Note: + This function should typically be called after save_module_weights() + to preserve the ability to restore weights later. + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> clear_module_weights(module) # Free memory + >>> # ... later ... 
+ >>> load_module_weights(module, metadata) # Restore weights + """ + if module is None: + return + + with torch.no_grad(): + # Clear weight + if hasattr(module, "weight") and module.weight is not None: + if to_meta: + # Move to meta device + if module.weight.device.type != "meta": + module.weight = torch.nn.Parameter( + torch.empty_like(module.weight, device="meta"), requires_grad=module.weight.requires_grad + ) + else: + # Use clean_module_parameter for safety + clean_module_parameter(module, "weight") + + # Clear bias + if hasattr(module, "bias") and module.bias is not None: + if to_meta: + # Move to meta device + if module.bias.device.type != "meta": + module.bias = torch.nn.Parameter( + torch.empty_like(module.bias, device="meta"), requires_grad=module.bias.requires_grad + ) + else: + # Use clean_module_parameter for safety + clean_module_parameter(module, "bias") + + def convert_dtype_str2torch(str_dtype): """Converts a string dtype to its corresponding PyTorch dtype. @@ -1526,3 +1719,226 @@ def is_separate_tensor(model: torch.nn.Module, tensor_name: str) -> bool: return True else: return False + + +# ===================== CPU Offload Utilities ===================== + + +def init_cpu_offload_dir(compressor: Any) -> Optional[str]: + """Initialize a temporary directory for CPU offloading. + + Args: + compressor: The compressor object containing low_cpu_mem_usage flag + and _cpu_offload_tempdir attribute. + + Returns: + Optional[str]: Path to the temporary directory, or None if not enabled. + """ + if not compressor.low_cpu_mem_usage: + return None + if compressor._cpu_offload_tempdir is None: + compressor._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") + return compressor._cpu_offload_tempdir + + +def offload_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + """Offload a block's weights to disk to reduce CPU RAM usage. + + Args: + compressor: The compressor object containing offload state. + block_name: Name of the block being offloaded. + block: The block module whose weights should be offloaded. + """ + if not compressor.low_cpu_mem_usage: + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + metadata = save_module_weights(block, save_path) + compressor._offloaded_blocks[block_name] = metadata + clear_module_weights(block) + + +def stream_offload_blocks(compressor: Any, all_blocks: list[list[str]]) -> None: + """Offload all block weights to disk and clear from memory. + + Args: + compressor: The compressor object containing model and offload state. + all_blocks: List of block name lists to offload. 
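+
+    Example (hedged sketch; assumes ``compressor`` was constructed with
+    ``low_cpu_mem_usage=True`` and ``cpu_stream_offload_blocks=True`` and that
+    ``all_blocks`` holds the same block-name lists the compressor iterates over):
+
+        >>> stream_offload_blocks(compressor, all_blocks)
+        >>> block = get_module(compressor.model, all_blocks[0][0])
+        >>> load_offloaded_block_weights(compressor, all_blocks[0][0], block)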
+ """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + logger.info("stream offloading block weights to disk...") + total_offloaded_gb = 0.0 + for block_names in all_blocks: + for block_name in block_names: + if block_name in compressor._offloaded_blocks: + continue + block = get_module(compressor.model, block_name) + if block is None: + continue + block_size_gb = estimate_block_size_gb(block) + total_offloaded_gb += block_size_gb + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + # Save entire block state_dict (all submodule weights) + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, save_path) + compressor._offloaded_blocks[block_name] = {"save_path": save_path} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + # Import clear_memory here to avoid circular imports + from auto_round.utils.device import clear_memory + + clear_memory(device_list=compressor.device_list) + logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") + + +def load_offloaded_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + """Load block weights from disk back into memory. + + Args: + compressor: The compressor object containing offload state. + block_name: Name of the block to load. + block: The block module to restore weights to. + """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + +def discard_offloaded_block(compressor: Any, block_name: str) -> None: + """Discard the original offload file and re-offload quantized weights. + + Args: + compressor: The compressor object containing model and offload state. + block_name: Name of the block to discard and re-offload. 
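+
+    Note:
+        The pre-quantization offload file for this block is deleted, then the
+        block's current (quantized) weights are re-saved to disk as
+        ``<block_name>_quantized.pt`` so ``restore_offloaded_blocks`` can bring
+        them back later; the in-memory weights are cleared again afterwards.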
+ """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.pop(block_name, None) + if not metadata: + return + save_path = metadata.get("save_path") + if save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except Exception as e: + logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") + + # Re-offload the quantized block weights to disk + block = get_module(compressor.model, block_name) + if block is None: + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") + try: + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, new_save_path) + compressor._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") + + +def restore_offloaded_blocks(compressor: Any) -> None: + """Restore all offloaded block weights back to memory. + + Args: + compressor: The compressor object containing model and offload state. + """ + if not compressor._offloaded_blocks: + return + for block_name, metadata in list(compressor._offloaded_blocks.items()): + try: + block = get_module(compressor.model, block_name) + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") + continue + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to restore offloaded block {block_name}: {e}") + + +def cleanup_cpu_offload_dir(compressor: Any) -> None: + """Clean up the CPU offload temporary directory. + + Args: + compressor: The compressor object containing the tempdir path. + """ + if compressor._cpu_offload_tempdir and os.path.isdir(compressor._cpu_offload_tempdir): + try: + shutil.rmtree(compressor._cpu_offload_tempdir) + except Exception as e: + logger.warning(f"Failed to cleanup cpu offload dir {compressor._cpu_offload_tempdir}: {e}") + compressor._cpu_offload_tempdir = None + + +def estimate_block_size_gb(block: torch.nn.Module) -> float: + """Estimate a block's weights size in GB. + + Args: + block: The block module to estimate size for. + + Returns: + float: Size of the block's parameters in GB. 
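+
+    Example (illustrative doctest-style sketch):
+        >>> block = torch.nn.Linear(4096, 4096)  # ~16.78M fp32 parameters
+        >>> round(estimate_block_size_gb(block), 3)
+        0.063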
+ """ + total = 0.0 + for param in block.parameters(): + if param.numel() > 0: + total += param.numel() * param.element_size() / (1024**3) + return total diff --git a/test/test_cpu/core/test_low_cpu_mem_options.py b/test/test_cpu/core/test_low_cpu_mem_options.py new file mode 100644 index 000000000..721492618 --- /dev/null +++ b/test/test_cpu/core/test_low_cpu_mem_options.py @@ -0,0 +1,195 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for CPU RAM optimization options: +1. cpu_stream_offload_blocks: Offload block weights to disk, load on demand +2. cpu_stream_loss: Compute loss on-the-fly using frozen block copy +""" + +import torch + +from auto_round import AutoRound +from auto_round.compressors import base as base_module +from auto_round.utils import device as device_module +from auto_round.utils import model as model_module +from auto_round.utils.model import stream_offload_blocks + + +class TestCpuStreamOffloadBlocks: + """Tests for cpu_stream_offload_blocks option.""" + + def test_option_stored_correctly(self, tiny_opt_model_path): + """Test that cpu_stream_offload_blocks option is stored correctly.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_offload_blocks is True + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=False, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround2.cpu_stream_offload_blocks is False + + def test_offload_requires_low_cpu_mem_usage(self, tiny_opt_model_path): + """Test that offload only activates when low_cpu_mem_usage is True.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + # Even if cpu_stream_offload_blocks=True, it should not offload + # when low_cpu_mem_usage=False + assert autoround.cpu_stream_offload_blocks is True + assert autoround.low_cpu_mem_usage is False + + def test_stream_offload_blocks_skips_when_disabled(self, tiny_opt_model_path): + """Test that stream_offload_blocks returns early when disabled.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=False, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + stream_offload_blocks(autoround, [["model.layers.0"]]) + assert autoround._offloaded_blocks == {} + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + stream_offload_blocks(autoround2, [["model.layers.0"]]) + assert autoround2._offloaded_blocks == {} + + def test_stream_offload_blocks_records_blocks(self, tiny_opt_model_path, tmp_path, monkeypatch): + """Test that 
stream_offload_blocks records offloaded blocks when enabled.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + + dummy_block = torch.nn.Linear(4, 4) + monkeypatch.setattr(model_module, "get_module", lambda _model, _name: dummy_block) + monkeypatch.setattr(model_module, "init_cpu_offload_dir", lambda _compressor: str(tmp_path)) + monkeypatch.setattr(torch, "save", lambda *args, **kwargs: None) + monkeypatch.setattr(model_module, "clear_module_weights", lambda *_args, **_kwargs: None) + + stream_offload_blocks(autoround, [["model.layers.0"]]) + assert "model.layers.0" in autoround._offloaded_blocks + + +class TestCpuStreamLoss: + """Tests for cpu_stream_loss option.""" + + def test_option_stored_correctly(self, tiny_opt_model_path): + """Test that cpu_stream_loss option is stored correctly.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_loss=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_loss is True + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_loss=False, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround2.cpu_stream_loss is False + + def test_stream_loss_requires_nblocks_1(self, tiny_opt_model_path): + """Test that cpu_stream_loss only works with nblocks=1.""" + # nblocks > 1 should trigger warning and disable stream_loss + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_loss=True, + nblocks=2, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + # The option is stored, but internally it will fall back during quantize + assert autoround.cpu_stream_loss is True + assert autoround.nblocks == 2 + stream_loss = autoround.cpu_stream_loss and autoround.nblocks == 1 + assert stream_loss is False + + +class TestCombinedOptions: + """Tests for combined optimization options.""" + + def test_both_options_enabled(self, tiny_opt_model_path): + """Test that both options can be enabled together.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + cpu_stream_loss=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_offload_blocks is True + assert autoround.cpu_stream_loss is True + assert autoround.low_cpu_mem_usage is True diff --git a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py new file mode 100644 index 000000000..6fb72d7f9 --- /dev/null +++ b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py @@ -0,0 +1,448 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for AutoScheme CPU RAM optimization (low_cpu_mem_usage option). 
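+
+The option is enabled on the AutoScheme dataclass; a minimal sketch, mirroring the
+integration tests below:
+
+    scheme = AutoScheme(avg_bits=4, options="W4A16", low_cpu_mem_usage=True)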
+ +This tests the disk offload mechanism for AutoScheme that reduces CPU RAM usage +by offloading block weights to disk during gradient computation. +""" + +import os +import shutil + +import pytest +import torch + +from auto_round import AutoRound, AutoScheme +from auto_round.auto_scheme.delta_loss import ( + AutoSchemeOffloadContext, + _clear_module_weights, + _group_layers_by_block, +) +from auto_round.auto_scheme.utils import compute_layer_bits +from auto_round.utils import get_block_names, get_module + + +class TestAutoSchemeOffloadContext: + """Tests for AutoSchemeOffloadContext class.""" + + def test_context_init_disabled(self): + """Test that context is properly initialized when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + assert ctx.low_cpu_mem_usage is False + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_context_init_enabled(self): + """Test that context is properly initialized when enabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + assert ctx.low_cpu_mem_usage is True + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_init_offload_dir_disabled(self): + """Test that init_offload_dir returns None when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + result = ctx.init_offload_dir() + assert result is None + assert ctx._offload_tempdir is None + + def test_init_offload_dir_enabled(self): + """Test that init_offload_dir creates temp directory when enabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + result = ctx.init_offload_dir() + assert result is not None + assert os.path.isdir(result) + assert "autoscheme_offload_" in result + finally: + ctx.cleanup() + + def test_offload_block_weights_disabled(self): + """Test that offload_block_weights does nothing when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + module = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module) + assert ctx._offloaded_blocks == {} + + def test_offload_block_weights_enabled(self): + """Test that offload_block_weights saves weights to disk.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module) + + assert "test_block" in ctx._offloaded_blocks + metadata = ctx._offloaded_blocks["test_block"] + assert "save_path" in metadata + assert os.path.exists(metadata["save_path"]) + + # Verify weight was cleared + assert module.weight.numel() == 0 + finally: + ctx.cleanup() + + def test_load_block_weights_disabled(self): + """Test that load_block_weights does nothing when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + module = torch.nn.Linear(4, 4) + ctx.load_block_weights("test_block", module) + + def test_offload_and_load_block_weights(self): + """Test full cycle of offloading and loading block weights.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + original_weight = module.weight.clone() + original_bias = module.bias.clone() + + # Offload + ctx.offload_block_weights("test_block", module) + assert module.weight.numel() == 0 + + # Load back + ctx.load_block_weights("test_block", module) + + # Verify weights are restored + assert module.weight.shape == original_weight.shape + assert torch.allclose(module.weight, original_weight) + assert module.bias.shape == original_bias.shape + assert torch.allclose(module.bias, original_bias) + 
finally: + ctx.cleanup() + + def test_offload_block_weights_idempotent(self): + """Test that offloading same block twice doesn't create duplicate files.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + + ctx.offload_block_weights("test_block", module) + first_path = ctx._offloaded_blocks["test_block"]["save_path"] + + # Create new module and try to offload again + module2 = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module2) + + # Should still have same path (second offload skipped) + assert ctx._offloaded_blocks["test_block"]["save_path"] == first_path + finally: + ctx.cleanup() + + def test_cleanup(self): + """Test that cleanup removes temp directory.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + ctx.init_offload_dir() + tempdir = ctx._offload_tempdir + assert os.path.isdir(tempdir) + + ctx.cleanup() + + assert not os.path.exists(tempdir) + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_save_and_load_original_block_weights(self): + """Test saving and loading original (unwrapped) block weights.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(8, 4) + original_weight = module.weight.data.clone() + + ctx.save_original_block_weights("block.0", module) + assert "block.0" in ctx._original_blocks + + # Clear and load back + _clear_module_weights(module) + assert module.weight.numel() == 0 + + ctx.load_original_block_weights("block.0", module) + assert module.weight.numel() == original_weight.numel() + assert torch.allclose(module.weight.data, original_weight) + finally: + ctx.cleanup() + + def test_save_original_idempotent(self): + """Test that saving original block twice doesn't overwrite.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + path1 = ctx._original_blocks["b"]["save_path"] + + # Modify module and save again — should be a no-op + module.weight.data.fill_(999.0) + ctx.save_original_block_weights("b", module) + path2 = ctx._original_blocks["b"]["save_path"] + assert path1 == path2 + finally: + ctx.cleanup() + + def test_reset_scheme_state(self): + """Test that reset_scheme_state clears wrapped state but keeps originals.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + ctx.offload_block_weights("b", module) + + assert len(ctx._offloaded_blocks) == 1 + assert len(ctx._original_blocks) == 1 + + ctx.reset_scheme_state() + assert len(ctx._offloaded_blocks) == 0 + assert len(ctx._original_blocks) == 1 + finally: + ctx.cleanup() + + def test_cleanup_removes_both_dirs(self): + """Test that cleanup removes both original and offload directories.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + ctx.offload_block_weights("b", module) + + orig_dir = ctx._original_dir + offload_dir = ctx._offload_tempdir + assert os.path.isdir(orig_dir) + assert os.path.isdir(offload_dir) + + ctx.cleanup() + assert not os.path.exists(orig_dir) + assert not os.path.exists(offload_dir) + + +class TestClearModuleWeights: + """Tests for _clear_module_weights helper function.""" + + def test_clear_linear_weights(self): + """Test clearing weights from a Linear module.""" + module = torch.nn.Linear(16, 8) + _clear_module_weights(module) + + assert 
module.weight.numel() == 0 + assert module.bias.numel() == 0 + + def test_clear_caches_numel(self): + """Test that clearing weights caches the original numel.""" + module = torch.nn.Linear(32, 16) + expected = 32 * 16 + _clear_module_weights(module) + assert hasattr(module, "_cached_weight_numel") + assert module._cached_weight_numel == expected + + def test_compute_layer_bits_with_empty_weight(self): + """Test compute_layer_bits works after weight is cleared.""" + layer = torch.nn.Linear(64, 128) + layer.bits = 4 + layer.group_size = 32 + layer.data_type = "int" + layer.sym = True + layer.super_group_size = None + layer.super_bits = None + + bits_before, _ = compute_layer_bits(layer, False) + _clear_module_weights(layer) + bits_after, _ = compute_layer_bits(layer, False) + assert bits_before == bits_after + + def test_clear_module_none(self): + """Test clearing weights with None module doesn't crash.""" + _clear_module_weights(None) # Should not raise + + def test_clear_module_no_weights(self): + """Test clearing module without weights doesn't crash.""" + module = torch.nn.ReLU() + _clear_module_weights(module) # Should not raise + + +class TestAutoSchemeDataclassLowCpuMem: + """Tests for low_cpu_mem_usage parameter in AutoScheme dataclass.""" + + def test_auto_scheme_default_low_cpu_mem_usage(self): + """Test that low_cpu_mem_usage defaults to False.""" + scheme = AutoScheme(avg_bits=4, options="W4A16") + assert scheme.low_cpu_mem_usage is False + + def test_auto_scheme_low_cpu_mem_usage_enabled(self): + """Test that low_cpu_mem_usage can be enabled.""" + scheme = AutoScheme(avg_bits=4, options="W4A16", low_cpu_mem_usage=True) + assert scheme.low_cpu_mem_usage is True + + def test_auto_scheme_low_cpu_mem_usage_with_low_gpu_mem_usage(self): + """Test that both low_cpu_mem_usage and low_gpu_mem_usage can be set.""" + scheme = AutoScheme( + avg_bits=4, + options="W4A16", + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + assert scheme.low_cpu_mem_usage is True + assert scheme.low_gpu_mem_usage is True + + +class TestGroupLayersByBlock: + """Tests for _group_layers_by_block helper.""" + + def test_basic_grouping(self): + layers = ["model.layers.0.attn.q", "model.layers.0.attn.k", "model.layers.1.mlp.fc1", "lm_head"] + blocks = ["model.layers.0", "model.layers.1"] + groups, non_block = _group_layers_by_block(layers, blocks) + assert groups["model.layers.0"] == ["model.layers.0.attn.q", "model.layers.0.attn.k"] + assert groups["model.layers.1"] == ["model.layers.1.mlp.fc1"] + assert non_block == ["lm_head"] + + def test_empty_layers(self): + groups, non_block = _group_layers_by_block([], ["b0"]) + assert groups == {"b0": []} + assert non_block == [] + + +class TestAutoSchemeIntegration: + """Integration tests for AutoScheme with low_cpu_mem_usage.""" + + save_dir = "./saved" + + @pytest.fixture(autouse=True, scope="class") + def setup_and_teardown_class(self): + yield + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_auto_scheme_with_low_cpu_mem_disabled(self, tiny_opt_model_path): + """Test AutoScheme works normally with low_cpu_mem_usage disabled.""" + model_name = tiny_opt_model_path + scheme = AutoScheme( + avg_bits=4, + options="W2A16,W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=False, + low_gpu_mem_usage=True, + ) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + _, layer_config = ar.quantize() + assert layer_config is not None + assert len(layer_config) > 0 + + def 
test_auto_scheme_with_low_cpu_mem_enabled(self, tiny_opt_model_path): + """Test AutoScheme works with low_cpu_mem_usage enabled.""" + model_name = tiny_opt_model_path + scheme = AutoScheme( + avg_bits=4, + options="W2A16,W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + _, layer_config = ar.quantize() + assert layer_config is not None + assert len(layer_config) > 0 + + def test_auto_scheme_low_cpu_mem_results_consistent(self, tiny_opt_model_path): + """Test that results are consistent with and without low_cpu_mem_usage.""" + model_name = tiny_opt_model_path + + # Without low_cpu_mem_usage + scheme1 = AutoScheme( + avg_bits=4, + options="W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=False, + low_gpu_mem_usage=True, + ) + ar1 = AutoRound(model=model_name, scheme=scheme1, iters=0, nsamples=1, seed=42) + _, layer_config1 = ar1.quantize() + + # With low_cpu_mem_usage + scheme2 = AutoScheme( + avg_bits=4, + options="W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + ar2 = AutoRound(model=model_name, scheme=scheme2, iters=0, nsamples=1, seed=42) + _, layer_config2 = ar2.quantize() + + # Layer configs should have same keys + assert set(layer_config1.keys()) == set(layer_config2.keys()) + + +class TestAutoSchemeOffloadContextWithModel: + """Tests for AutoSchemeOffloadContext with actual model blocks.""" + + def test_offload_model_block(self, tiny_opt_model_path): + """Test offloading an actual model block.""" + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(tiny_opt_model_path, torch_dtype=torch.float32) + block_names = get_block_names(model)[0] + + if len(block_names) == 0: + pytest.skip("Model has no blocks") + + block_name = block_names[0] + block = get_module(model, block_name) + + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + # Get original param count + original_params = sum(p.numel() for p in block.parameters()) + + # Offload + ctx.offload_block_weights(block_name, block) + + # Check that weights were cleared + current_params = sum(p.numel() for p in block.parameters()) + assert current_params < original_params + + # Load back + ctx.load_block_weights(block_name, block) + + # Check params are restored + restored_params = sum(p.numel() for p in block.parameters()) + assert restored_params == original_params + finally: + ctx.cleanup() + + def test_offload_all_blocks(self, tiny_opt_model_path): + """Test offloading all model blocks.""" + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(tiny_opt_model_path, torch_dtype=torch.float32) + block_names = get_block_names(model)[0] + + if len(block_names) == 0: + pytest.skip("Model has no blocks") + + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + # Offload all blocks + ctx.offload_all_blocks(model, block_names) + + # Verify all blocks were offloaded + for block_name in block_names: + assert block_name in ctx._offloaded_blocks + block = get_module(model, block_name) + # Check weights are cleared (should have very few params) + block_params = sum(p.numel() for p in block.parameters()) + assert block_params == 0 or block_params < 100 # Allow for some edge cases + finally: + ctx.cleanup()