From dee1db7689729a4e3bb7402985c629bd6b991f92 Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Tue, 3 Feb 2026 02:24:31 +0000 Subject: [PATCH 01/13] Optimize CPU RAM peak memeory during quantization Signed-off-by: lvliang-intel --- auto_round/compressors/base.py | 338 +++++++++++++++++- auto_round/utils/model.py | 193 ++++++++++ .../advanced/test_cpu_ram_optimization.py | 174 +++++++++ 3 files changed, 696 insertions(+), 9 deletions(-) create mode 100644 test/test_cuda/advanced/test_cpu_ram_optimization.py diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index b9537ea08..ee080fd01 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -14,7 +14,10 @@ import copy import os +import re +import shutil import sys +import tempfile import time import traceback from collections import defaultdict @@ -77,6 +80,7 @@ check_seqlen_compatible, check_to_quantized, clear_memory, + clear_module_weights, compile_func, convert_dtype_str2torch, convert_module_to_hp_if_necessary, @@ -95,8 +99,10 @@ is_moe_model, is_quantized_input_module, llm_load_model, + load_module_weights, memory_monitor, mv_module_from_gpu, + save_module_weights, set_amax_for_all_moe_layers, set_module, to_device, @@ -345,6 +351,10 @@ def __init__( self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES self.scale_dtype = convert_dtype_str2torch(scale_dtype) self.low_cpu_mem_usage = low_cpu_mem_usage + self.cpu_stream_offload_blocks = kwargs.pop("cpu_stream_offload_blocks", False) + self.cpu_stream_loss = kwargs.pop("cpu_stream_loss", False) + self._cpu_offload_tempdir = None + self._offloaded_blocks = {} if kwargs: logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.") @@ -917,7 +927,6 @@ def quantize_and_save( output_dir, format=self.formats, inplace=inplace, return_folders=True, **kwargs ) memory_monitor.log_summary() - return model, folders def _get_save_folder_name(self, format: OutputFormat) -> str: @@ -1499,6 +1508,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) if not self.is_immediate_saving: # some modules may have been flushed and set to meta, so we could not move to gpu mv_module_from_gpu(block) + if self.low_cpu_mem_usage: + self._offload_block_weights(block_name, block) if block_name == block_names[-1]: clear_memory(input_ids, device_list=self.device_list) else: @@ -1517,6 +1528,194 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) # if self.is_immediate_saving: # shard_writer(self, is_finalize=True) + def _estimate_tensor_size_gb(self, tensor) -> float: + """Estimate the size of a tensor in GB.""" + if tensor is None: + return 0.0 + if isinstance(tensor, torch.Tensor): + return tensor.numel() * tensor.element_size() / (1024 ** 3) + elif isinstance(tensor, list): + return sum(self._estimate_tensor_size_gb(t) for t in tensor) + elif isinstance(tensor, dict): + return sum(self._estimate_tensor_size_gb(v) for v in tensor.values()) + return 0.0 + + def _estimate_inputs_size_gb(self, all_inputs: dict) -> float: + """Estimate the total size of calibration inputs in GB.""" + total = 0.0 + for name, inputs in all_inputs.items(): + total += self._estimate_tensor_size_gb(inputs) + return total + + def _estimate_model_size_gb(self) -> float: + """Estimate the model weights size in GB.""" + total = 0.0 + for param in self.model.parameters(): + if param.numel() > 0: # Skip empty tensors + total += param.numel() * param.element_size() / (1024 ** 3) + return total + + def 
_estimate_block_size_gb(self, block: torch.nn.Module) -> float: + """Estimate a block's weights size in GB.""" + total = 0.0 + for param in block.parameters(): + if param.numel() > 0: + total += param.numel() * param.element_size() / (1024 ** 3) + return total + + def _init_cpu_offload_dir(self) -> Optional[str]: + if not self.low_cpu_mem_usage: + return None + if self._cpu_offload_tempdir is None: + self._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") + return self._cpu_offload_tempdir + + def _offload_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + if not self.low_cpu_mem_usage: + return + offload_dir = self._init_cpu_offload_dir() + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + metadata = save_module_weights(block, save_path) + self._offloaded_blocks[block_name] = metadata + clear_module_weights(block) + + def _stream_offload_blocks(self, all_blocks: list[list[str]]) -> None: + """Offload all block weights to disk and clear from memory.""" + if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): + return + offload_dir = self._init_cpu_offload_dir() + if offload_dir is None: + return + logger.info("stream offloading block weights to disk...") + total_offloaded_gb = 0.0 + for block_names in all_blocks: + for block_name in block_names: + if block_name in self._offloaded_blocks: + continue + block = get_module(self.model, block_name) + if block is None: + continue + block_size_gb = self._estimate_block_size_gb(block) + total_offloaded_gb += block_size_gb + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + # Save entire block state_dict (all submodule weights) + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, save_path) + self._offloaded_blocks[block_name] = {"save_path": save_path} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + clear_memory(device_list=self.device_list) + logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") + + def _load_offloaded_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Load block weights from disk back into memory.""" + if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): + return + metadata = self._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + def _discard_offloaded_block(self, 
block_name: str) -> None: + """Discard the original offload file and re-offload quantized weights.""" + if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): + return + metadata = self._offloaded_blocks.pop(block_name, None) + if not metadata: + return + save_path = metadata.get("save_path") + if save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except Exception as e: + logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") + + # Re-offload the quantized block weights to disk + block = get_module(self.model, block_name) + if block is None: + return + offload_dir = self._init_cpu_offload_dir() + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") + try: + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, new_save_path) + self._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") + + def _restore_offloaded_blocks(self) -> None: + """Restore all offloaded block weights back to memory.""" + if not self._offloaded_blocks: + return + for block_name, metadata in list(self._offloaded_blocks.items()): + try: + block = get_module(self.model, block_name) + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") + continue + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to restore offloaded block {block_name}: {e}") + + def _cleanup_cpu_offload_dir(self) -> None: + if self._cpu_offload_tempdir and os.path.isdir(self._cpu_offload_tempdir): + try: + shutil.rmtree(self._cpu_offload_tempdir) + except Exception as e: + logger.warning(f"Failed to cleanup cpu offload dir {self._cpu_offload_tempdir}: {e}") + self._cpu_offload_tempdir = None + def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]: keys = inputs.keys() input_id_str = [key for key in keys if key.startswith("hidden_state")] @@ -1652,6 +1851,9 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) + if self.low_cpu_mem_usage: + self._offloaded_blocks = {} + def _should_disable_inplace_due_to_layers_outside_block() -> bool: return self.has_qlayer_outside_block and (self.iters != 0 or (self.iters == 0 and not self.disable_opt_rtn)) @@ -1689,9 +1891,14 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: ) else: logger.info("start to cache block inputs") + all_inputs 
= self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) is_quantized_embedding = self._quantize_embedding_layer() clear_memory(device_list=self.device_list) + # Log memory breakdown for calibration inputs + if self.low_cpu_mem_usage: + inputs_size_gb = self._estimate_inputs_size_gb(all_inputs) + logger.info(f"[Memory] calibration inputs size: {inputs_size_gb:.2f} GB") all_q_inputs = None if is_quantized_embedding: all_inputs = copy.deepcopy(self.inputs) @@ -1704,6 +1911,12 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: accelerate.hooks.remove_hook_from_submodules(self.model) # self.model.hf_device_map has not been changed logger.info("caching done") + # Log memory breakdown for model weights + if self.low_cpu_mem_usage: + model_size_gb = self._estimate_model_size_gb() + logger.info(f"[Memory] model weights size: {model_size_gb:.2f} GB") + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + self._stream_offload_blocks(all_blocks) if len(all_blocks) > 1: pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) else: @@ -1749,6 +1962,10 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: if self.is_immediate_saving: shard_writer(self, is_finalize=True) + if self.low_cpu_mem_usage: + self._restore_offloaded_blocks() + self._cleanup_cpu_offload_dir() + end_time = time.time() cost_time = end_time - start_time logger.info(f"quantization tuning time {cost_time}") @@ -1829,6 +2046,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: memory_monitor.update() memory_monitor.log_summary() return + q_layer_inputs = None enable_quanted_input = self.enable_quanted_input has_gguf = False @@ -2657,6 +2875,29 @@ def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) -> current_output = torch.cat(current_output, dim=self.batch_dim) return current_output + def _get_current_output_stream( + self, + block: torch.nn.Module, + input_ids: list[torch.Tensor], + input_others: dict, + indices: list[int], + device: str, + cache_device: str = "cpu", + ) -> torch.Tensor: + current_input_ids, current_input_others = self._sampling_inputs( + input_ids, + input_others, + indices, + seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.shared_cache_keys, + ) + with torch.no_grad(): + output = self.block_forward( + block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device + ) + return output.to(cache_device) + def _get_current_q_output( self, block: torch.nn.Module, @@ -2792,24 +3033,51 @@ def _quantize_block( hook = AlignDevicesHook(m.tuning_device, io_same_device=True) add_hook_to_module(m, hook, True) + stream_loss = self.cpu_stream_loss and self.nblocks == 1 + if self.cpu_stream_loss and self.nblocks != 1: + logger.warning("cpu_stream_loss only supports nblocks=1; falling back to cached outputs.") + stream_loss = False + if q_input is None: hook_handles = self._register_act_max_hook(block) - output = self._get_block_outputs( - block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device - ) + if stream_loss: + output = None + self._get_block_outputs( + block, + input_ids, + input_others, + self.batch_size * self.infer_bs_coeff, + device, + self.cache_device, + save_output=False, + ) + else: + output = self._get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, 
self.cache_device + ) + + # Log output cache size for first block + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks and not hasattr(self, "_logged_output_size"): + output_size = self._estimate_tensor_size_gb(output) + logger.info(f"[Memory] block output cache size: {output_size:.2f} GB") + self._logged_output_size = True for handle in hook_handles: handle.remove() else: - output = self._get_block_outputs( - block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device - ) + if stream_loss: + output = None + # Skip pre-computation in stream_loss mode - targets will be computed on-the-fly with frozen_block + else: + output = self._get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device + ) hook_handles = self._register_act_max_hook(block) if hook_handles: self._get_block_outputs( block, - q_input, + q_input if q_input is not None else input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, @@ -2827,6 +3095,13 @@ def _quantize_block( clear_memory(device_list=self.device_list) input_ids = q_input + frozen_block = None + if stream_loss: + frozen_block = copy.deepcopy(block).to(device) + frozen_block.eval() + for p in frozen_block.parameters(): + p.requires_grad_(False) + quantized_layer_names, unquantized_layer_names = self.wrapper_block( block, self.enable_minmax_tuning, @@ -2923,7 +3198,12 @@ def _quantize_block( for tmp_step in range(self.gradient_accumulate_steps): indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size] - current_output = self._get_current_output(output, indices) + if stream_loss: + current_output = self._get_current_output_stream( + frozen_block, input_ids, input_others, indices, loss_device, cache_device=loss_device + ) + else: + current_output = self._get_current_output(output, indices) current_output = to_device(current_output, loss_device) output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device) loss = self._get_loss(output_q, current_output, indices, mse_loss, device) @@ -3006,6 +3286,16 @@ def _quantize_block( return q_outputs, output else: + # When stream_loss is enabled, output is None, so we need to compute it for the next block + if stream_loss and output is None: + output = self._get_block_outputs( + block, + input_ids, + input_others, + self.batch_size * self.infer_bs_coeff, + device, + cache_device=self.cache_device, + ) if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) if auto_offload: @@ -3070,6 +3360,12 @@ def _quantize_blocks( input_ids, input_others = self._preprocess_block_inputs(inputs) + # Log detailed memory breakdown for first block + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + input_ids_size = self._estimate_tensor_size_gb(input_ids) + input_others_size = self._estimate_tensor_size_gb(input_others) + logger.info(f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB") + if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) @@ -3086,7 +3382,18 @@ def _quantize_blocks( modules = [get_module(model, n) for n in names] m = WrapperMultiblock(modules) + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + if nblocks == 1: + self._load_offloaded_block_weights(n, get_module(model, n)) + if i == 0: # Log only for first block + block_size = self._estimate_block_size_gb(get_module(model, n)) + logger.info(f"[Memory] loaded block weights 
size: {block_size:.2f} GB") + else: + for name in names: + self._load_offloaded_block_weights(name, get_module(model, name)) + m.config = model.config if hasattr(model, "config") else None + q_input, input_ids = self._quantize_block( m, input_ids, @@ -3094,6 +3401,19 @@ def _quantize_blocks( q_input=q_input, device=device, ) + + if self.low_cpu_mem_usage and not self.cpu_stream_offload_blocks: + if nblocks == 1: + self._offload_block_weights(n, get_module(model, n)) + else: + for name in names: + self._offload_block_weights(name, get_module(model, name)) + if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: + if nblocks == 1: + self._discard_offloaded_block(n) + else: + for name in names: + self._discard_offloaded_block(name) if hasattr(model, "config"): del m.config if self.is_immediate_packing: diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 0371caf2f..9adb650f2 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -63,6 +63,199 @@ def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None: param.requires_grad = False +def save_module_weights(module: torch.nn.Module, save_path: str) -> dict: + """Save module's weight and bias tensors to disk to reduce CPU RAM usage. + + This function saves the weight and bias tensors of a module to a specified path on disk. + It returns metadata about the saved tensors that can be used later to reload them. + + Args: + module (torch.nn.Module): The module whose weights should be saved. + save_path (str): Path where the weights should be saved. + This should be a unique path for each module. + + Returns: + dict: Metadata containing information about the saved tensors, including: + - 'has_weight': bool indicating if weight was saved + - 'has_bias': bool indicating if bias was saved + - 'weight_shape': shape of the weight tensor + - 'bias_shape': shape of the bias tensor (if exists) + - 'weight_dtype': dtype of the weight tensor + - 'bias_dtype': dtype of the bias tensor (if exists) + - 'weight_device': original device of the weight tensor + - 'bias_device': original device of the bias tensor (if exists) + - 'save_path': the path where tensors were saved + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> # Now module's weights can be cleared to save RAM + >>> clear_module_weights(module) + >>> # Later, weights can be restored + >>> load_module_weights(module, metadata) + """ + if module is None: + return {} + + metadata = {'save_path': save_path} + tensors_to_save = {} + + # Save weight if it exists + if hasattr(module, 'weight') and module.weight is not None: + weight = module.weight + if weight.device.type != 'meta' and weight.numel() > 0: + tensors_to_save['weight'] = weight.detach().cpu() + metadata['has_weight'] = True + metadata['weight_shape'] = tuple(weight.shape) + metadata['weight_dtype'] = weight.dtype + metadata['weight_device'] = str(weight.device) + else: + metadata['has_weight'] = False + else: + metadata['has_weight'] = False + + # Save bias if it exists + if hasattr(module, 'bias') and module.bias is not None: + bias = module.bias + if bias.device.type != 'meta' and bias.numel() > 0: + tensors_to_save['bias'] = bias.detach().cpu() + metadata['has_bias'] = True + metadata['bias_shape'] = tuple(bias.shape) + metadata['bias_dtype'] = bias.dtype + metadata['bias_device'] = str(bias.device) + else: + metadata['has_bias'] = False + else: + metadata['has_bias'] = False + + # Save to disk + if 
tensors_to_save: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save(tensors_to_save, save_path) + + return metadata + + +def load_module_weights(module: torch.nn.Module, metadata: dict) -> None: + """Load module's weight and bias tensors from disk. + + This function reloads weights that were previously saved using save_module_weights(). + The weights are loaded back to their original device and dtype. + + Args: + module (torch.nn.Module): The module whose weights should be restored. + metadata (dict): Metadata returned by save_module_weights(), containing + information about the saved tensors and their path. + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> clear_module_weights(module) + >>> # ... do some work with reduced memory ... + >>> load_module_weights(module, metadata) + >>> # Now module's weights are restored + """ + if module is None or not metadata or 'save_path' not in metadata: + return + + save_path = metadata['save_path'] + if not os.path.exists(save_path): + logger.warning(f"Cannot load weights: file {save_path} does not exist") + return + + # Load tensors from disk + try: + tensors = torch.load(save_path, map_location='cpu') + except Exception as e: + logger.warning(f"Failed to load weights from {save_path}: {e}") + return + + # Restore weight + if metadata.get('has_weight', False) and 'weight' in tensors: + weight = tensors['weight'] + target_device = metadata.get('weight_device', 'cpu') + target_dtype = metadata.get('weight_dtype', weight.dtype) + + # Move to target device and dtype + weight = weight.to(device=target_device, dtype=target_dtype) + + # Set the weight back to the module + if hasattr(module, 'weight'): + if isinstance(module.weight, torch.nn.Parameter): + module.weight = torch.nn.Parameter(weight, requires_grad=module.weight.requires_grad) + else: + module.weight = weight + + # Restore bias + if metadata.get('has_bias', False) and 'bias' in tensors: + bias = tensors['bias'] + target_device = metadata.get('bias_device', 'cpu') + target_dtype = metadata.get('bias_dtype', bias.dtype) + + # Move to target device and dtype + bias = bias.to(device=target_device, dtype=target_dtype) + + # Set the bias back to the module + if hasattr(module, 'bias'): + if isinstance(module.bias, torch.nn.Parameter): + module.bias = torch.nn.Parameter(bias, requires_grad=module.bias.requires_grad) + else: + module.bias = bias + + +def clear_module_weights(module: torch.nn.Module, to_meta: bool = False) -> None: + """Clear module's weight and bias to free CPU RAM. + + This function clears the weight and bias of a module to reduce memory usage. + It can either set them to empty tensors (default) or move them to meta device. + + Args: + module (torch.nn.Module): The module whose weights should be cleared. + to_meta (bool): If True, move tensors to meta device. + If False, set them to empty tensors. Default is False. + + Note: + This function should typically be called after save_module_weights() + to preserve the ability to restore weights later. + + Example: + >>> module = torch.nn.Linear(10, 5) + >>> metadata = save_module_weights(module, "/tmp/module_weights.pt") + >>> clear_module_weights(module) # Free memory + >>> # ... later ... 
+ >>> load_module_weights(module, metadata) # Restore weights + """ + if module is None: + return + + with torch.no_grad(): + # Clear weight + if hasattr(module, 'weight') and module.weight is not None: + if to_meta: + # Move to meta device + if module.weight.device.type != 'meta': + module.weight = torch.nn.Parameter( + torch.empty_like(module.weight, device='meta'), + requires_grad=module.weight.requires_grad + ) + else: + # Use clean_module_parameter for safety + clean_module_parameter(module, 'weight') + + # Clear bias + if hasattr(module, 'bias') and module.bias is not None: + if to_meta: + # Move to meta device + if module.bias.device.type != 'meta': + module.bias = torch.nn.Parameter( + torch.empty_like(module.bias, device='meta'), + requires_grad=module.bias.requires_grad + ) + else: + # Use clean_module_parameter for safety + clean_module_parameter(module, 'bias') + + def convert_dtype_str2torch(str_dtype): """Converts a string dtype to its corresponding PyTorch dtype. diff --git a/test/test_cuda/advanced/test_cpu_ram_optimization.py b/test/test_cuda/advanced/test_cpu_ram_optimization.py new file mode 100644 index 000000000..317dac991 --- /dev/null +++ b/test/test_cuda/advanced/test_cpu_ram_optimization.py @@ -0,0 +1,174 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Quantize Qwen/Qwen3-4B-Instruct-2507 with AutoRound (4-bit) +and compare CPU RAM peak usage with different optimization options. + +Optimization options: +1. cpu_stream_offload_blocks: Offload block weights to disk, load on demand +2. 
cpu_stream_loss: Compute loss on-the-fly using frozen block copy +""" + +import gc +import time +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound +from auto_round.utils.device import memory_monitor + + +def get_rss_gb() -> float: + """Return process RSS in GB (Linux).""" + try: + with open("/proc/self/status", "r", encoding="utf-8") as f: + for line in f: + if line.startswith("VmRSS:"): + parts = line.split() + kb = int(parts[1]) + return kb / 1024 / 1024 + except Exception: + return -1.0 + return -1.0 + + +def log_rss(tag: str) -> None: + rss_gb = get_rss_gb() + if rss_gb >= 0: + print(f"[RAM] {tag}: {rss_gb:.2f} GB") + else: + print(f"[RAM] {tag}: N/A") + + +def cleanup(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def run_quantization( + label: str, + cpu_stream_offload_blocks: bool = False, + cpu_stream_loss: bool = False, +) -> tuple[float, float]: + model_name = "Qwen/Qwen3-4B-Instruct-2507" + print("\n" + "=" * 60) + print(label) + print("=" * 60) + print(f" cpu_stream_offload_blocks={cpu_stream_offload_blocks}") + print(f" cpu_stream_loss={cpu_stream_loss}") + + memory_monitor.reset() + log_rss("before model load") + + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16, + device_map="cpu", + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + log_rss("after model load") + + # Determine if any optimization is enabled + any_optimization = cpu_stream_offload_blocks or cpu_stream_loss + + autoround = AutoRound( + model, + tokenizer, + bits=4, + group_size=128, + low_gpu_mem_usage=True, + low_cpu_mem_usage=any_optimization, + cpu_stream_offload_blocks=cpu_stream_offload_blocks, + cpu_stream_loss=cpu_stream_loss, + iters=200, + nsamples=512, + seqlen=2048, + ) + + print("Start 4-bit quantization...") + t0 = time.time() + quantized_model, _ = autoround.quantize() + t1 = time.time() + elapsed = t1 - t0 + print(f"Quantization finished in {elapsed:.1f}s") + + print(f"[PEAK] {memory_monitor.get_summary()}") + log_rss("after quantization") + + del quantized_model + del autoround + del model + del tokenizer + cleanup() + + return memory_monitor.peak_ram, elapsed + + +def main(): + print("=" * 60) + print("AutoRound 4-bit Quantization - CPU RAM Optimization Test") + print("=" * 60) + + results = [] + + # Test 1: Baseline (no optimization) + peak, elapsed = run_quantization( + "Test 1: Baseline (no optimization)", + cpu_stream_offload_blocks=False, + cpu_stream_loss=False, + ) + results.append(("Baseline", peak, elapsed)) + + # Test 2: cpu_stream_offload_blocks only + peak, elapsed = run_quantization( + "Test 2: cpu_stream_offload_blocks only", + cpu_stream_offload_blocks=True, + cpu_stream_loss=False, + ) + results.append(("+ offload_blocks", peak, elapsed)) + + # Test 3: cpu_stream_loss only + peak, elapsed = run_quantization( + "Test 3: cpu_stream_loss only", + cpu_stream_offload_blocks=False, + cpu_stream_loss=True, + ) + results.append(("+ stream_loss", peak, elapsed)) + + # Test 4: offload_blocks + stream_loss (All optimizations) + peak, elapsed = run_quantization( + "Test 4: All optimizations (offload_blocks + stream_loss)", + cpu_stream_offload_blocks=True, + cpu_stream_loss=True, + ) + results.append(("All optimizations", peak, elapsed)) + + # Summary + print("\n" + "=" * 60) + print("Summary: Peak RAM Comparison") + print("=" * 60) + print(f"{'Configuration':<25} {'Peak RAM (GB)':<15} {'Time (s)':<10} {'RAM Saved':<12}") + print("-" * 62) + 
baseline_ram = results[0][1] + for name, peak, elapsed in results: + saved = baseline_ram - peak + saved_str = f"-{saved:.2f} GB" if saved > 0 else "baseline" + print(f"{name:<25} {peak:<15.2f} {elapsed:<10.1f} {saved_str:<12}") + + +if __name__ == "__main__": + main() \ No newline at end of file From 2a78a1890e523aa62c32ea8598679a9878456dc8 Mon Sep 17 00:00:00 2001 From: Weiwei Date: Tue, 3 Feb 2026 11:08:37 +0800 Subject: [PATCH 02/13] rm duplicate args of the quantization extra config (#1334) Signed-off-by: WeiweiZhang1 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../export/export_to_autoround/export.py | 9 ++++-- .../export_to_autoround/export_to_fp8.py | 2 +- .../export_to_nvfp_mxfp.py | 11 +++++-- .../quantization/test_act_quantization.py | 12 ------- test/test_cpu/quantization/test_mix_bits.py | 31 ++++++++---------- test/test_cuda/integrations/test_sglang.py | 16 ++++++++-- test/test_cuda/quantization/test_mix_bits.py | 32 +++++++++---------- 7 files changed, 59 insertions(+), 54 deletions(-) diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 52b35f9a4..8d78d0bee 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -294,9 +294,14 @@ def save_quantized_as_autoround( regex_config = quantization_config.pop("regex_config") if regex_config is not None: - for name in regex_config.keys(): + for name, cfg in regex_config.items(): regex_name = to_standard_regex(name) - extra_config[regex_name] = {**{k: regex_config[name][k] for k in scheme_keys}} + neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + if len(neq_keys) > 0: + extra_config[regex_name] = {} + for key in neq_keys: + if cfg.get(key) is not None: + extra_config[regex_name][key] = cfg[key] if len(extra_config) > 0: quantization_config["extra_config"] = extra_config diff --git a/auto_round/export/export_to_autoround/export_to_fp8.py b/auto_round/export/export_to_autoround/export_to_fp8.py index b97a5db5a..38bea79f3 100644 --- a/auto_round/export/export_to_autoround/export_to_fp8.py +++ b/auto_round/export/export_to_autoround/export_to_fp8.py @@ -191,7 +191,7 @@ def save_quantized_as_autoround( neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) if len(neq_keys) > 0: extra_config[layer_name] = {} - for key in scheme_keys: + for key in neq_keys: if cfg[key] is not None: extra_config[layer_name][key] = cfg[key] diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index 7265941d3..1320846dc 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -210,15 +210,20 @@ def save_quantized_as_fp( neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) if len(neq_keys) > 0: extra_config[layer_name] = {} - for key in scheme_keys: + for key in neq_keys: if cfg.get(key, None) is not None: extra_config[layer_name][key] = cfg.get(key, None) regex_config = quantization_config.pop("regex_config") if regex_config is not None: - for name in regex_config.keys(): + for name, cfg in regex_config.items(): regex_name = to_standard_regex(name) - extra_config[regex_name] = {**{k: regex_config[name][k] for k in scheme_keys}} + neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys}) + if 
len(neq_keys) > 0: + extra_config[regex_name] = {} + for key in neq_keys: + if cfg.get(key) is not None: + extra_config[regex_name][key] = cfg[key] if len(extra_config) > 0: quantization_config["extra_config"] = extra_config diff --git a/test/test_cpu/quantization/test_act_quantization.py b/test/test_cpu/quantization/test_act_quantization.py index 47cda3599..014d80f6f 100644 --- a/test/test_cpu/quantization/test_act_quantization.py +++ b/test/test_cpu/quantization/test_act_quantization.py @@ -119,14 +119,8 @@ def test_act_config_MXFP4_saving(self, tiny_opt_model_path, dataloader): # check inblock layer config values kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"] - assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp" assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 8 - assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 32 - assert "act_sym" in kproj_config.keys() and kproj_config["act_sym"] - assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "mx_fp" assert "bits" in kproj_config.keys() and kproj_config["bits"] == 8 - assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 32 - assert "sym" in kproj_config.keys() and kproj_config["sym"] shutil.rmtree(quantized_model_path, ignore_errors=True) def test_act_config_NVFP4_saving(self, tiny_opt_model_path, dataloader): @@ -144,14 +138,8 @@ def test_act_config_NVFP4_saving(self, tiny_opt_model_path, dataloader): autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"] - assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "nv_fp4_with_static_gs" assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 16 - assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 16 - assert "act_sym" in kproj_config.keys() and kproj_config["act_sym"] - assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "nv_fp" assert "bits" in kproj_config.keys() and kproj_config["bits"] == 16 - assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 16 - assert "sym" in kproj_config.keys() and kproj_config["sym"] shutil.rmtree(quantized_model_path, ignore_errors=True) def test_WOQ_config_INT_saving(self, tiny_opt_model_path, dataloader): diff --git a/test/test_cpu/quantization/test_mix_bits.py b/test/test_cpu/quantization/test_mix_bits.py index 5db3053cb..e5a8a7c42 100644 --- a/test/test_cpu/quantization/test_mix_bits.py +++ b/test/test_cpu/quantization/test_mix_bits.py @@ -112,7 +112,17 @@ def test_mixed_autoround_format(self, dataloader): layer_config=layer_config, ) quantized_model_path = "./saved" - compressed_model = autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + config_file = Path(quantized_model_path) / "config.json" + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + quant_config = config.get("quantization_config", {}) + extra_config = quant_config.get("extra_config", {}) + # check extra_config only saved attributes differing from Scheme values + assert "act_bits" not in extra_config[".*fc1.*"].keys() ## 
TODO refine this assert + assert "group_size" not in extra_config[".*fc1.*"].keys() + assert "act_bits" not in extra_config["model.decoder.layers.0.self_attn.k_proj"].keys() + assert "group_size" not in extra_config["model.decoder.layers.0.self_attn.k_proj"].keys() model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8 assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3 @@ -167,26 +177,10 @@ def test_mixed_ar_format_part_name_hf_loading(self, dataloader): # remove old extra_config(which contains full name layer configs), only test regex config loading new_extra_config = { ".*fc1.*": { # standard regex - "act_bits": 16, - "act_data_type": "float", - "act_dynamic": True, - "act_group_size": 128, - "act_sym": True, "bits": 16, - "data_type": "int", - "group_size": 128, - "sym": True, }, "k_proj": { # part name - "act_bits": 16, - "act_data_type": "float", - "act_dynamic": True, - "act_group_size": 128, - "act_sym": True, "bits": 8, - "data_type": "int", - "group_size": 128, - "sym": True, }, } config_file = Path(quantized_model_path) / "config.json" @@ -194,6 +188,9 @@ def test_mixed_ar_format_part_name_hf_loading(self, dataloader): config = json.load(f) quant_config = config.get("quantization_config", {}) old_extra_config = quant_config.get("extra_config", {}) + # check extra_config only saved attributes differing from Scheme values + assert "act_bits" not in old_extra_config[".*fc1.*"].keys() + assert "group_size" not in old_extra_config[".*fc1.*"].keys() quant_config["extra_config"] = new_extra_config config["quantization_config"] = quant_config with open(config_file, "w", encoding="utf-8") as f: diff --git a/test/test_cuda/integrations/test_sglang.py b/test/test_cuda/integrations/test_sglang.py index ac96bed74..ddbb3f0f4 100644 --- a/test/test_cuda/integrations/test_sglang.py +++ b/test/test_cuda/integrations/test_sglang.py @@ -1,3 +1,4 @@ +import json import shutil import sys from pathlib import Path @@ -59,8 +60,8 @@ def test_ar_format_sglang(self, dataloader): def test_mixed_ar_format_sglang(self, dataloader): layer_config = { - "self_attn": {"bits": 16, "act_bits": 16}, - "lm_head": {"bits": 16, "act_bits": 16}, + "self_attn": {"bits": 8}, + "lm_head": {"bits": 16}, "fc1": {"bits": 16, "act_bits": 16}, } @@ -78,7 +79,16 @@ def test_mixed_ar_format_sglang(self, dataloader): inplace=True, format="auto_round", ) - + config_file = Path(self.save_dir) / "config.json" + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + quant_config = config.get("quantization_config", {}) + extra_config = quant_config.get("extra_config", {}) + # check extra_config only saved attributes differing from Scheme values + assert "act_bits" not in extra_config[".*fc1.*"].keys() + assert "group_size" not in extra_config[".*fc1.*"].keys() + assert "bits" in extra_config[".*fc1.*"].keys() and extra_config[".*fc1.*"]["bits"] == 16 + assert "bits" in extra_config[".*self_attn.*"].keys() and extra_config[".*self_attn.*"]["bits"] == 8 generated_text = self._run_sglang_inference(self.save_dir) print(generated_text) diff --git a/test/test_cuda/quantization/test_mix_bits.py b/test/test_cuda/quantization/test_mix_bits.py index 9daa9b727..eadaf8f53 100644 --- a/test/test_cuda/quantization/test_mix_bits.py +++ b/test/test_cuda/quantization/test_mix_bits.py @@ -108,6 +108,18 @@ def test_mixed_autoround_format(self, tiny_opt_model_path, dataloader): ) quantized_model_path = "self.save_dir" 
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + config_file = Path(quantized_model_path) / "config.json" + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + quant_config = config.get("quantization_config", {}) + extra_config = quant_config.get("extra_config", {}) + # check extra_config only saved attributes differing from Scheme values + assert "act_bits" not in extra_config[".*fc1.*"].keys() ## TODO refine this assert + assert "group_size" not in extra_config[".*fc1.*"].keys() + assert "act_bits" not in extra_config["model.decoder.layers.0.self_attn.k_proj"].keys() + assert "group_size" not in extra_config["model.decoder.layers.0.self_attn.k_proj"].keys() + assert "group_size" not in extra_config["model.decoder.layers.1.self_attn.q_proj"].keys() + assert "bits" in extra_config["model.decoder.layers.1.self_attn.q_proj"].keys() model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto") assert model.model.decoder.layers[0].self_attn.k_proj.bits == 8 assert model.model.decoder.layers[0].self_attn.q_proj.bits == 3 @@ -164,26 +176,10 @@ def test_mixed_ar_format_part_name_hf_loading(self, tiny_opt_model_path, dataloa # remove old extra_config(which contains full name layer configs), only test regex config loading new_extra_config = { ".*fc1.*": { # standard regex - "act_bits": 16, - "act_data_type": "float", - "act_dynamic": True, - "act_group_size": 128, - "act_sym": True, "bits": 16, - "data_type": "int", - "group_size": 128, - "sym": True, }, "k_proj": { # part name - "act_bits": 16, - "act_data_type": "float", - "act_dynamic": True, - "act_group_size": 128, - "act_sym": True, "bits": 8, - "data_type": "int", - "group_size": 128, - "sym": True, }, } config_file = Path(quantized_model_path) / "config.json" @@ -191,6 +187,10 @@ def test_mixed_ar_format_part_name_hf_loading(self, tiny_opt_model_path, dataloa config = json.load(f) quant_config = config.get("quantization_config", {}) old_extra_config = quant_config.get("extra_config", {}) + # check extra_config only saved attributes differing from Scheme values + assert "sym" not in old_extra_config[".*fc1.*"].keys() + assert "act_dynamic" not in old_extra_config[".*fc1.*"].keys() + assert "group_size" not in old_extra_config[".*fc1.*"].keys() quant_config["extra_config"] = new_extra_config config["quantization_config"] = quant_config with open(config_file, "w", encoding="utf-8") as f: From e00c176c5916a0161ebc23ab6ceb04b453d72b35 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 3 Feb 2026 13:08:41 +0800 Subject: [PATCH 03/13] fix --device_map cuda and xpu issue (#1383) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- auto_round/auto_scheme/utils.py | 2 ++ auto_round/compressors/base.py | 2 ++ test/test_cuda/export/test_gguf.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py index c64b17e88..3c19acc42 100644 --- a/auto_round/auto_scheme/utils.py +++ b/auto_round/auto_scheme/utils.py @@ -241,6 +241,8 @@ def dispatch_model_by_all_available_devices( device = int(device.split(":")[-1]) elif device == "cpu": device = "cpu" + elif isinstance(device, str): + device = 0 else: raise ValueError(f"Unsupported device {device} in device_map: {device_map}") new_max_memory[device] = max_memory[device] diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index ee080fd01..68aefd954 100644 --- 
a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2368,6 +2368,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l device = int(device.split(":")[-1]) elif device == "cpu": device = "cpu" + elif isinstance(device, str): + device = 0 else: raise ValueError(f"Unsupported device {device} in device_map: {self.device_map}") # Use 90% of the reported max memory to leave headroom for activations, diff --git a/test/test_cuda/export/test_gguf.py b/test/test_cuda/export/test_gguf.py index e71af74c4..60182aa48 100644 --- a/test/test_cuda/export/test_gguf.py +++ b/test/test_cuda/export/test_gguf.py @@ -66,7 +66,7 @@ def test_gguf_format(self, tiny_qwen_model_path, dataloader): save_dir = os.path.join(os.path.dirname(__file__), "saved") res = os.system( f"PYTHONPATH='{AUTO_ROUND_PATH}:$PYTHONPATH' {sys.executable} -m auto_round --model {tiny_qwen_model_path} --iter 2 " - f"--output_dir {save_dir} --nsample 2 --format gguf:q4_0 --device 0" + f"--output_dir {save_dir} --nsample 2 --format gguf:q4_0 --device cuda" ) print(save_dir) assert not (res > 0 or res == -1), "qwen2 tuning fail" From d6d9f775065638be50fbe80868a20fd151ca9257 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 07:06:33 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 20 +++-- auto_round/utils/model.py | 88 +++++++++---------- .../advanced/test_cpu_ram_optimization.py | 3 +- 3 files changed, 59 insertions(+), 52 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 68aefd954..9a620b778 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1533,7 +1533,7 @@ def _estimate_tensor_size_gb(self, tensor) -> float: if tensor is None: return 0.0 if isinstance(tensor, torch.Tensor): - return tensor.numel() * tensor.element_size() / (1024 ** 3) + return tensor.numel() * tensor.element_size() / (1024**3) elif isinstance(tensor, list): return sum(self._estimate_tensor_size_gb(t) for t in tensor) elif isinstance(tensor, dict): @@ -1552,7 +1552,7 @@ def _estimate_model_size_gb(self) -> float: total = 0.0 for param in self.model.parameters(): if param.numel() > 0: # Skip empty tensors - total += param.numel() * param.element_size() / (1024 ** 3) + total += param.numel() * param.element_size() / (1024**3) return total def _estimate_block_size_gb(self, block: torch.nn.Module) -> float: @@ -1560,7 +1560,7 @@ def _estimate_block_size_gb(self, block: torch.nn.Module) -> float: total = 0.0 for param in block.parameters(): if param.numel() > 0: - total += param.numel() * param.element_size() / (1024 ** 3) + total += param.numel() * param.element_size() / (1024**3) return total def _init_cpu_offload_dir(self) -> Optional[str]: @@ -1702,7 +1702,9 @@ def _restore_offloaded_blocks(self) -> None: if hasattr(target, param_name): old_param = getattr(target, param_name) if isinstance(old_param, torch.nn.Parameter): - setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + setattr( + target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad) + ) else: setattr(target, param_name, param) except Exception as e: @@ -3060,7 +3062,11 @@ def _quantize_block( ) # Log output cache size for first block - if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks and not hasattr(self, 
"_logged_output_size"): + if ( + self.low_cpu_mem_usage + and self.cpu_stream_offload_blocks + and not hasattr(self, "_logged_output_size") + ): output_size = self._estimate_tensor_size_gb(output) logger.info(f"[Memory] block output cache size: {output_size:.2f} GB") self._logged_output_size = True @@ -3366,7 +3372,9 @@ def _quantize_blocks( if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: input_ids_size = self._estimate_tensor_size_gb(input_ids) input_others_size = self._estimate_tensor_size_gb(input_others) - logger.info(f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB") + logger.info( + f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB" + ) if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 9adb650f2..7377404b4 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -71,7 +71,7 @@ def save_module_weights(module: torch.nn.Module, save_path: str) -> dict: Args: module (torch.nn.Module): The module whose weights should be saved. - save_path (str): Path where the weights should be saved. + save_path (str): Path where the weights should be saved. This should be a unique path for each module. Returns: @@ -97,36 +97,36 @@ def save_module_weights(module: torch.nn.Module, save_path: str) -> dict: if module is None: return {} - metadata = {'save_path': save_path} + metadata = {"save_path": save_path} tensors_to_save = {} # Save weight if it exists - if hasattr(module, 'weight') and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None: weight = module.weight - if weight.device.type != 'meta' and weight.numel() > 0: - tensors_to_save['weight'] = weight.detach().cpu() - metadata['has_weight'] = True - metadata['weight_shape'] = tuple(weight.shape) - metadata['weight_dtype'] = weight.dtype - metadata['weight_device'] = str(weight.device) + if weight.device.type != "meta" and weight.numel() > 0: + tensors_to_save["weight"] = weight.detach().cpu() + metadata["has_weight"] = True + metadata["weight_shape"] = tuple(weight.shape) + metadata["weight_dtype"] = weight.dtype + metadata["weight_device"] = str(weight.device) else: - metadata['has_weight'] = False + metadata["has_weight"] = False else: - metadata['has_weight'] = False + metadata["has_weight"] = False # Save bias if it exists - if hasattr(module, 'bias') and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None: bias = module.bias - if bias.device.type != 'meta' and bias.numel() > 0: - tensors_to_save['bias'] = bias.detach().cpu() - metadata['has_bias'] = True - metadata['bias_shape'] = tuple(bias.shape) - metadata['bias_dtype'] = bias.dtype - metadata['bias_device'] = str(bias.device) + if bias.device.type != "meta" and bias.numel() > 0: + tensors_to_save["bias"] = bias.detach().cpu() + metadata["has_bias"] = True + metadata["bias_shape"] = tuple(bias.shape) + metadata["bias_dtype"] = bias.dtype + metadata["bias_device"] = str(bias.device) else: - metadata['has_bias'] = False + metadata["has_bias"] = False else: - metadata['has_bias'] = False + metadata["has_bias"] = False # Save to disk if tensors_to_save: @@ -155,48 +155,48 @@ def load_module_weights(module: torch.nn.Module, metadata: dict) -> None: >>> load_module_weights(module, metadata) >>> # Now module's weights are restored """ - if module is None or not metadata or 'save_path' not in metadata: + if module is None 
or not metadata or "save_path" not in metadata: return - save_path = metadata['save_path'] + save_path = metadata["save_path"] if not os.path.exists(save_path): logger.warning(f"Cannot load weights: file {save_path} does not exist") return # Load tensors from disk try: - tensors = torch.load(save_path, map_location='cpu') + tensors = torch.load(save_path, map_location="cpu") except Exception as e: logger.warning(f"Failed to load weights from {save_path}: {e}") return # Restore weight - if metadata.get('has_weight', False) and 'weight' in tensors: - weight = tensors['weight'] - target_device = metadata.get('weight_device', 'cpu') - target_dtype = metadata.get('weight_dtype', weight.dtype) + if metadata.get("has_weight", False) and "weight" in tensors: + weight = tensors["weight"] + target_device = metadata.get("weight_device", "cpu") + target_dtype = metadata.get("weight_dtype", weight.dtype) # Move to target device and dtype weight = weight.to(device=target_device, dtype=target_dtype) # Set the weight back to the module - if hasattr(module, 'weight'): + if hasattr(module, "weight"): if isinstance(module.weight, torch.nn.Parameter): module.weight = torch.nn.Parameter(weight, requires_grad=module.weight.requires_grad) else: module.weight = weight # Restore bias - if metadata.get('has_bias', False) and 'bias' in tensors: - bias = tensors['bias'] - target_device = metadata.get('bias_device', 'cpu') - target_dtype = metadata.get('bias_dtype', bias.dtype) + if metadata.get("has_bias", False) and "bias" in tensors: + bias = tensors["bias"] + target_device = metadata.get("bias_device", "cpu") + target_dtype = metadata.get("bias_dtype", bias.dtype) # Move to target device and dtype bias = bias.to(device=target_device, dtype=target_dtype) # Set the bias back to the module - if hasattr(module, 'bias'): + if hasattr(module, "bias"): if isinstance(module.bias, torch.nn.Parameter): module.bias = torch.nn.Parameter(bias, requires_grad=module.bias.requires_grad) else: @@ -211,11 +211,11 @@ def clear_module_weights(module: torch.nn.Module, to_meta: bool = False) -> None Args: module (torch.nn.Module): The module whose weights should be cleared. - to_meta (bool): If True, move tensors to meta device. + to_meta (bool): If True, move tensors to meta device. If False, set them to empty tensors. Default is False. Note: - This function should typically be called after save_module_weights() + This function should typically be called after save_module_weights() to preserve the ability to restore weights later. 
Example: @@ -230,30 +230,28 @@ def clear_module_weights(module: torch.nn.Module, to_meta: bool = False) -> None with torch.no_grad(): # Clear weight - if hasattr(module, 'weight') and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None: if to_meta: # Move to meta device - if module.weight.device.type != 'meta': + if module.weight.device.type != "meta": module.weight = torch.nn.Parameter( - torch.empty_like(module.weight, device='meta'), - requires_grad=module.weight.requires_grad + torch.empty_like(module.weight, device="meta"), requires_grad=module.weight.requires_grad ) else: # Use clean_module_parameter for safety - clean_module_parameter(module, 'weight') + clean_module_parameter(module, "weight") # Clear bias - if hasattr(module, 'bias') and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None: if to_meta: # Move to meta device - if module.bias.device.type != 'meta': + if module.bias.device.type != "meta": module.bias = torch.nn.Parameter( - torch.empty_like(module.bias, device='meta'), - requires_grad=module.bias.requires_grad + torch.empty_like(module.bias, device="meta"), requires_grad=module.bias.requires_grad ) else: # Use clean_module_parameter for safety - clean_module_parameter(module, 'bias') + clean_module_parameter(module, "bias") def convert_dtype_str2torch(str_dtype): diff --git a/test/test_cuda/advanced/test_cpu_ram_optimization.py b/test/test_cuda/advanced/test_cpu_ram_optimization.py index 317dac991..9399b08e5 100644 --- a/test/test_cuda/advanced/test_cpu_ram_optimization.py +++ b/test/test_cuda/advanced/test_cpu_ram_optimization.py @@ -23,6 +23,7 @@ import gc import time + import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -171,4 +172,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From ca55ae843d1d4a9928fba0e519103b1734a4b44f Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Tue, 3 Feb 2026 08:19:37 +0000 Subject: [PATCH 05/13] refine test case Signed-off-by: lvliang-intel --- .../test_cpu/core/test_low_cpu_mem_options.py | 192 ++++++++++++++++++ .../advanced/test_cpu_ram_optimization.py | 175 ---------------- 2 files changed, 192 insertions(+), 175 deletions(-) create mode 100644 test/test_cpu/core/test_low_cpu_mem_options.py delete mode 100644 test/test_cuda/advanced/test_cpu_ram_optimization.py diff --git a/test/test_cpu/core/test_low_cpu_mem_options.py b/test/test_cpu/core/test_low_cpu_mem_options.py new file mode 100644 index 000000000..26eacfdc6 --- /dev/null +++ b/test/test_cpu/core/test_low_cpu_mem_options.py @@ -0,0 +1,192 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for CPU RAM optimization options: +1. cpu_stream_offload_blocks: Offload block weights to disk, load on demand +2. 
cpu_stream_loss: Compute loss on-the-fly using frozen block copy +""" + +import torch + +from auto_round import AutoRound +from auto_round.compressors import base as base_module + + +class TestCpuStreamOffloadBlocks: + """Tests for cpu_stream_offload_blocks option.""" + + def test_option_stored_correctly(self, tiny_opt_model_path): + """Test that cpu_stream_offload_blocks option is stored correctly.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_offload_blocks is True + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=False, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround2.cpu_stream_offload_blocks is False + + def test_offload_requires_low_cpu_mem_usage(self, tiny_opt_model_path): + """Test that offload only activates when low_cpu_mem_usage is True.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + # Even if cpu_stream_offload_blocks=True, it should not offload + # when low_cpu_mem_usage=False + assert autoround.cpu_stream_offload_blocks is True + assert autoround.low_cpu_mem_usage is False + + def test_stream_offload_blocks_skips_when_disabled(self, tiny_opt_model_path): + """Test that _stream_offload_blocks returns early when disabled.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=False, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + autoround._stream_offload_blocks([["model.layers.0"]]) + assert autoround._offloaded_blocks == {} + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + autoround2._stream_offload_blocks([["model.layers.0"]]) + assert autoround2._offloaded_blocks == {} + + def test_stream_offload_blocks_records_blocks(self, tiny_opt_model_path, tmp_path, monkeypatch): + """Test that _stream_offload_blocks records offloaded blocks when enabled.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + + dummy_block = torch.nn.Linear(4, 4) + monkeypatch.setattr(base_module, "get_module", lambda _model, _name: dummy_block) + monkeypatch.setattr(autoround, "_init_cpu_offload_dir", lambda: str(tmp_path)) + monkeypatch.setattr(torch, "save", lambda *args, **kwargs: None) + monkeypatch.setattr(base_module, "clear_module_weights", lambda *_args, **_kwargs: None) + + autoround._stream_offload_blocks([["model.layers.0"]]) + assert "model.layers.0" in autoround._offloaded_blocks + + +class TestCpuStreamLoss: + """Tests for cpu_stream_loss option.""" + + def test_option_stored_correctly(self, tiny_opt_model_path): + """Test that cpu_stream_loss option is stored correctly.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_loss=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_loss is True + + autoround2 = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=False, + cpu_stream_loss=False, + iters=0, + disable_opt_rtn=True, 
+ nsamples=1, + seqlen=32, + ) + assert autoround2.cpu_stream_loss is False + + def test_stream_loss_requires_nblocks_1(self, tiny_opt_model_path): + """Test that cpu_stream_loss only works with nblocks=1.""" + # nblocks > 1 should trigger warning and disable stream_loss + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_loss=True, + nblocks=2, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + # The option is stored, but internally it will fall back during quantize + assert autoround.cpu_stream_loss is True + assert autoround.nblocks == 2 + stream_loss = autoround.cpu_stream_loss and autoround.nblocks == 1 + assert stream_loss is False + + +class TestCombinedOptions: + """Tests for combined optimization options.""" + + def test_both_options_enabled(self, tiny_opt_model_path): + """Test that both options can be enabled together.""" + autoround = AutoRound( + tiny_opt_model_path, + bits=4, + low_cpu_mem_usage=True, + cpu_stream_offload_blocks=True, + cpu_stream_loss=True, + iters=0, + disable_opt_rtn=True, + nsamples=1, + seqlen=32, + ) + assert autoround.cpu_stream_offload_blocks is True + assert autoround.cpu_stream_loss is True + assert autoround.low_cpu_mem_usage is True diff --git a/test/test_cuda/advanced/test_cpu_ram_optimization.py b/test/test_cuda/advanced/test_cpu_ram_optimization.py deleted file mode 100644 index 9399b08e5..000000000 --- a/test/test_cuda/advanced/test_cpu_ram_optimization.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Quantize Qwen/Qwen3-4B-Instruct-2507 with AutoRound (4-bit) -and compare CPU RAM peak usage with different optimization options. - -Optimization options: -1. cpu_stream_offload_blocks: Offload block weights to disk, load on demand -2. 
cpu_stream_loss: Compute loss on-the-fly using frozen block copy -""" - -import gc -import time - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from auto_round import AutoRound -from auto_round.utils.device import memory_monitor - - -def get_rss_gb() -> float: - """Return process RSS in GB (Linux).""" - try: - with open("/proc/self/status", "r", encoding="utf-8") as f: - for line in f: - if line.startswith("VmRSS:"): - parts = line.split() - kb = int(parts[1]) - return kb / 1024 / 1024 - except Exception: - return -1.0 - return -1.0 - - -def log_rss(tag: str) -> None: - rss_gb = get_rss_gb() - if rss_gb >= 0: - print(f"[RAM] {tag}: {rss_gb:.2f} GB") - else: - print(f"[RAM] {tag}: N/A") - - -def cleanup(): - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def run_quantization( - label: str, - cpu_stream_offload_blocks: bool = False, - cpu_stream_loss: bool = False, -) -> tuple[float, float]: - model_name = "Qwen/Qwen3-4B-Instruct-2507" - print("\n" + "=" * 60) - print(label) - print("=" * 60) - print(f" cpu_stream_offload_blocks={cpu_stream_offload_blocks}") - print(f" cpu_stream_loss={cpu_stream_loss}") - - memory_monitor.reset() - log_rss("before model load") - - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - device_map="cpu", - ) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - log_rss("after model load") - - # Determine if any optimization is enabled - any_optimization = cpu_stream_offload_blocks or cpu_stream_loss - - autoround = AutoRound( - model, - tokenizer, - bits=4, - group_size=128, - low_gpu_mem_usage=True, - low_cpu_mem_usage=any_optimization, - cpu_stream_offload_blocks=cpu_stream_offload_blocks, - cpu_stream_loss=cpu_stream_loss, - iters=200, - nsamples=512, - seqlen=2048, - ) - - print("Start 4-bit quantization...") - t0 = time.time() - quantized_model, _ = autoround.quantize() - t1 = time.time() - elapsed = t1 - t0 - print(f"Quantization finished in {elapsed:.1f}s") - - print(f"[PEAK] {memory_monitor.get_summary()}") - log_rss("after quantization") - - del quantized_model - del autoround - del model - del tokenizer - cleanup() - - return memory_monitor.peak_ram, elapsed - - -def main(): - print("=" * 60) - print("AutoRound 4-bit Quantization - CPU RAM Optimization Test") - print("=" * 60) - - results = [] - - # Test 1: Baseline (no optimization) - peak, elapsed = run_quantization( - "Test 1: Baseline (no optimization)", - cpu_stream_offload_blocks=False, - cpu_stream_loss=False, - ) - results.append(("Baseline", peak, elapsed)) - - # Test 2: cpu_stream_offload_blocks only - peak, elapsed = run_quantization( - "Test 2: cpu_stream_offload_blocks only", - cpu_stream_offload_blocks=True, - cpu_stream_loss=False, - ) - results.append(("+ offload_blocks", peak, elapsed)) - - # Test 3: cpu_stream_loss only - peak, elapsed = run_quantization( - "Test 3: cpu_stream_loss only", - cpu_stream_offload_blocks=False, - cpu_stream_loss=True, - ) - results.append(("+ stream_loss", peak, elapsed)) - - # Test 4: offload_blocks + stream_loss (All optimizations) - peak, elapsed = run_quantization( - "Test 4: All optimizations (offload_blocks + stream_loss)", - cpu_stream_offload_blocks=True, - cpu_stream_loss=True, - ) - results.append(("All optimizations", peak, elapsed)) - - # Summary - print("\n" + "=" * 60) - print("Summary: Peak RAM Comparison") - print("=" * 60) - print(f"{'Configuration':<25} {'Peak RAM (GB)':<15} {'Time (s)':<10} {'RAM Saved':<12}") - print("-" * 62) - 
baseline_ram = results[0][1] - for name, peak, elapsed in results: - saved = baseline_ram - peak - saved_str = f"-{saved:.2f} GB" if saved > 0 else "baseline" - print(f"{name:<25} {peak:<15.2f} {elapsed:<10.1f} {saved_str:<12}") - - -if __name__ == "__main__": - main() From 7a3dcacc475ed3018a279451fd3b4c980599fde4 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Tue, 3 Feb 2026 16:08:33 +0800 Subject: [PATCH 06/13] Disable replace `FP8Expert` (#1379) Signed-off-by: yiliu30 --- auto_round/experimental/qmodules/fp4_utils.py | 2 +- auto_round/modeling/__init__.py | 2 + auto_round/modeling/fp8_quant.py | 114 ++++++++++++++++++ auto_round/utils/common.py | 8 ++ auto_round/utils/device.py | 31 +++-- auto_round/utils/model.py | 36 +++--- test/test_cuda/models/test_fp8_model.py | 55 +++++++++ 7 files changed, 218 insertions(+), 30 deletions(-) create mode 100644 auto_round/modeling/fp8_quant.py create mode 100644 test/test_cuda/models/test_fp8_model.py diff --git a/auto_round/experimental/qmodules/fp4_utils.py b/auto_round/experimental/qmodules/fp4_utils.py index e755a2b38..aa9cae421 100644 --- a/auto_round/experimental/qmodules/fp4_utils.py +++ b/auto_round/experimental/qmodules/fp4_utils.py @@ -65,7 +65,7 @@ def _unpack_fp4_from_uint8_cpu( return _unpack_fp4_from_uint8(a, m, n, dtype) -@torch.compile(fullgraph=True, dynamic=True) +# @torch.compile(fullgraph=True, dynamic=True) def _unpack_fp4_from_uint8_cuda( a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 ) -> torch.Tensor: diff --git a/auto_round/modeling/__init__.py b/auto_round/modeling/__init__.py index 14a492441..d1bc25269 100644 --- a/auto_round/modeling/__init__.py +++ b/auto_round/modeling/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .fp8_quant import * diff --git a/auto_round/modeling/fp8_quant.py b/auto_round/modeling/fp8_quant.py new file mode 100644 index 000000000..da4cae175 --- /dev/null +++ b/auto_round/modeling/fp8_quant.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers.quantizers.quantizer_finegrained_fp8 import ( + FineGrainedFP8HfQuantizer as OriginalFineGrainedFP8HfQuantizer, +) + +from auto_round.utils import is_transformers_version_greater_or_equal_5 +from auto_round.utils import logger as auto_round_logger +from auto_round.utils.device import override_cuda_device_capability + + +# Patching replace_with_fp8_linear to disable expert replacement +# https://github.com/huggingface/transformers/blob/78bb85146c59258a0710c8d08311d98d52303c38/src/transformers/integrations/finegrained_fp8.py#L720 +def oot_replace_with_fp8_linear( + model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False +): + """ + A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules. 
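    Unlike the upstream transformers helper, this variant never replaces an `.experts`
    module with `FP8Expert`; only plain `torch.nn.Linear` layers are converted to
    `FP8Linear`. The patch is installed automatically (when transformers >= 5.0 and CUDA
    is available) by `apply_fp8_expert_replacement_patch()` at the bottom of this file,
    which runs as a side effect of, e.g.:

        import auto_round.modeling  # re-exports fp8_quant and applies the patch on import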
+ + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`): + Names of the modules to not convert. In practice we keep the `lm_head` + in full precision for numerical stability reasons. + quantization_config (`FbgemmFp8Config`): + The quantization config object that contains the quantization parameters. + pre_quantized (`bool`, defaults to `False`): + Whether the model is pre-quantized or not + """ + from transformers.integrations.finegrained_fp8 import ( + FP8Linear, + logger, + should_convert_module, + ) + + if quantization_config.dequantize: + return model + + has_been_replaced = False + for module_name, module in model.named_modules(): + if not should_convert_module(module_name, modules_to_not_convert): + continue + # we need this to correctly materialize the weights during quantization + module_kwargs = {} if pre_quantized else {"dtype": None} + new_module = None + with torch.device("meta"): + # Note: Disable replacing experts, as we do not want concatenated experts + # if module_name.endswith(".experts"): + # new_module = FP8Expert( + # config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs + # ) + # elif isinstance(module, nn.Linear): + if isinstance(module, torch.nn.Linear): + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + **module_kwargs, + ) + if new_module is not None: + model.set_submodule(module_name, new_module) + has_been_replaced = True + + if not has_been_replaced: + logger.warning( + "You are loading your model using fp8 but no linear modules were found in your model." + " Please double check your model architecture." + ) + return model + + +_orig_validate_environment = OriginalFineGrainedFP8HfQuantizer.validate_environment + + +@override_cuda_device_capability() +def oot_validate_environment(self, *args, **kwargs): + return _orig_validate_environment(self, *args, **kwargs) + + +def apply_fp8_expert_replacement_patch(): + if is_transformers_version_greater_or_equal_5() and torch.cuda.is_available(): + try: + import transformers.integrations.finegrained_fp8 as transformers_fp8 + + transformers_fp8.replace_with_fp8_linear = oot_replace_with_fp8_linear + auto_round_logger.debug("Applied FP8 expert replacement patch to transformers.") + OriginalFineGrainedFP8HfQuantizer.validate_environment = oot_validate_environment + auto_round_logger.debug( + ( + "Patched FineGrainedFP8HfQuantizer.validate_environment to bypass device " + "capability check for loading FP8 models on unsupported GPUs." 
+ ) + ) + except ImportError as e: + auto_round_logger.warning(f"Could not apply FP8 expert replacement patch as {e}.") + + +apply_fp8_expert_replacement_patch() diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index b8d7ab9fb..2ae2a78e4 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -390,3 +390,11 @@ class GlobalState: global_state = GlobalState() + + +@lru_cache(None) +def is_transformers_version_greater_or_equal_5(): + import transformers + from packaging import version + + return version.parse(transformers.__version__) >= version.parse("5.0.0") diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index b25a6888e..a16f441bf 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -15,7 +15,7 @@ import gc import os import re -from contextlib import contextmanager +from contextlib import ContextDecorator, contextmanager from functools import lru_cache from itertools import combinations from threading import Lock @@ -315,17 +315,28 @@ def set_cuda_visible_devices(device): os.environ["CUDA_VISIBLE_DEVICES"] = device -def set_fake_cuda_device_capability(func=None): - if func is not None: - torch.cuda.get_device_capability = func - return func +class override_cuda_device_capability(ContextDecorator): + """Context manager/decorator to temporarily override CUDA capability checks.""" - def fake_cuda(): - return 100, 1 + def __init__(self, major: int = 100, minor: int = 1) -> None: + self.major = major + self.minor = minor + self._orig_func = None - orig_func = torch.cuda.get_device_capability - torch.cuda.get_device_capability = fake_cuda - return orig_func + def __enter__(self): + self._orig_func = torch.cuda.get_device_capability + + def _override_capability(*_args, **_kwargs): + return self.major, self.minor + + torch.cuda.get_device_capability = _override_capability + return self + + def __exit__(self, exc_type, exc, exc_tb): + if self._orig_func is not None: + torch.cuda.get_device_capability = self._orig_func + self._orig_func = None + return False def get_packing_device(device: str | torch.device | None = "auto") -> torch.device: diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 7377404b4..1fe13a458 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -447,6 +447,7 @@ def llm_load_model( device: str = "cpu", **kwargs, ): + assert platform.lower() in [ "hf", "model_scope", @@ -460,11 +461,10 @@ def llm_load_model( from modelscope import AutoModel, AutoModelForCausalLM, AutoTokenizer # pylint: disable=E0401 else: from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer - from auto_round.utils.device import ( _use_hpu_compile_mode, get_device_and_parallelism, - set_fake_cuda_device_capability, + override_cuda_device_capability, ) device_str, use_auto_mapping = get_device_and_parallelism(device) @@ -497,14 +497,13 @@ def llm_load_model( ) except ValueError as e: if "FP8 quantized" in str(e): - orig_func = set_fake_cuda_device_capability() - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - ) - torch.cuda.get_device_capability = orig_func + with override_cuda_device_capability(): + model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + device_map="auto" if use_auto_mapping else None, + ) logger.warning("the support for fp8 model as input 
is experimental, please use with caution.") else: raise @@ -557,7 +556,7 @@ def mllm_load_model( base_lib = transformers - from auto_round.utils.device import get_device_and_parallelism, set_fake_cuda_device_capability + from auto_round.utils.device import get_device_and_parallelism, override_cuda_device_capability device_str, use_auto_mapping = get_device_and_parallelism(device) torch_dtype = "auto" @@ -630,14 +629,13 @@ def mllm_load_model( ) except ValueError as e: if "FP8 quantized" in str(e): - orig_func = set_fake_cuda_device_capability() - model = cls.from_pretrained( - pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - torch_dtype=torch_dtype, - device_map="auto" if use_auto_mapping else None, - ) - torch.cuda.get_device_capability = orig_func + with override_cuda_device_capability(): + model = cls.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map="auto" if use_auto_mapping else None, + ) logger.warning("the support for fp8 model as input is experimental, please use with caution.") else: raise diff --git a/test/test_cuda/models/test_fp8_model.py b/test/test_cuda/models/test_fp8_model.py new file mode 100644 index 000000000..5b7962f7b --- /dev/null +++ b/test/test_cuda/models/test_fp8_model.py @@ -0,0 +1,55 @@ +import shutil + +import pytest +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer + +from auto_round import AutoRound + + +@pytest.fixture +def setup_qwen_fp8(): + """Fixture to set up the GPT-OSS model and tokenizer.""" + model_name = "INC4AI/Qwen3-30B-A3B-Instruct-2507-FP8-2Layers" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + output_dir = "test_quantized_qwen3_fp8_moe_mxfp" + return tokenizer, output_dir, config, model_name + + +def test_qwen3_fp8_moe_mxfp(setup_qwen_fp8): + tokenizer, output_dir, config, model_name = setup_qwen_fp8 + autoround = AutoRound( + model_name, + scheme="MXFP4", + nsamples=2, + seqlen=32, + iters=0, + ) + quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "Quantized model should not be None." + loaded_model = AutoModelForCausalLM.from_pretrained(output_dir) + loaded_model.to("cuda") + quantized_model.to("cuda") + for n, m in quantized_model.named_modules(): + if m.__class__.__name__ == "QuantLinear": + loaded_m = loaded_model.get_submodule(n) + assert (loaded_m.weight_packed == m.weight_packed).all() + # Expect all linear in experts are quantized + for n, m in quantized_model.named_modules(): + if "experts" in m.__class__.__name__.lower(): + for sub_n, sub_m in m.named_modules(): + assert sub_m.__class__.__name__ == "QuantLinear", f"Module {n}.{sub_n} is not quantized." 
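    # Side note, illustrative only and not exercised by this test: loading an FP8 checkpoint
    # on a GPU without native FP8 support goes through override_cuda_device_capability from
    # auto_round.utils.device. Since it is a ContextDecorator, both forms below work
    # (the from_pretrained call is a sketch, not the exact call used in llm_load_model):
    #
    #     from auto_round.utils.device import override_cuda_device_capability
    #
    #     with override_cuda_device_capability():     # context-manager form
    #         model = AutoModelForCausalLM.from_pretrained(model_name)
    #
    #     @override_cuda_device_capability()          # decorator form, as used for
    #     def validate_environment(*args, **kwargs):  # FineGrainedFP8HfQuantizer
    #         ...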
+ inp = torch.randint(0, 100, (1, 64)).to("cuda") + with torch.inference_mode(): + loaded_out = loaded_model(inp) + + # test generation + tokenizer = AutoTokenizer.from_pretrained(output_dir) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device) + print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0])) + # clean the output directory after test + shutil.rmtree(output_dir, ignore_errors=True) From 082bf4cafced03111f4a63486d36ffeed6494b39 Mon Sep 17 00:00:00 2001 From: Liang Lv Date: Tue, 3 Feb 2026 16:41:53 +0800 Subject: [PATCH 07/13] Support general MOE replacement for MOE models (Transformers 5.0 compatible) (#1374) Signed-off-by: lvliang-intel --- auto_round/experimental/qmodules/nvfp4.py | 7 + .../export_to_nvfp_mxfp.py | 6 + auto_round/export/utils.py | 64 +++ auto_round/modeling/fused_moe/__init__.py | 25 + .../fused_moe/moe_experts_interface.py | 447 ++++++++++++++++++ .../modeling/fused_moe/replace_modules.py | 100 ++++ auto_round/utils/model.py | 240 +++++++++- auto_round/wrapper.py | 18 + test/test_cpu/quantization/test_mxfp_nvfp.py | 49 +- .../models/test_moe_experts_interface.py | 239 ++++++++++ 10 files changed, 1157 insertions(+), 38 deletions(-) create mode 100644 auto_round/modeling/fused_moe/moe_experts_interface.py create mode 100644 test/test_cuda/models/test_moe_experts_interface.py diff --git a/auto_round/experimental/qmodules/nvfp4.py b/auto_round/experimental/qmodules/nvfp4.py index 88ea577e2..81aea8b54 100644 --- a/auto_round/experimental/qmodules/nvfp4.py +++ b/auto_round/experimental/qmodules/nvfp4.py @@ -71,6 +71,7 @@ def __init__( self.config = config self.dtype = dtype self.pre_dequantized = False + self._cached_weight = None # Validate dtype assert ( @@ -165,6 +166,12 @@ def dequant_weight_online(self) -> torch.Tensor: dq_weight = self._dequant_nvfp4_tensor(self.weight_packed, self.weight_scale) return dq_weight + @property + def weight(self) -> torch.Tensor: + if not hasattr(self, "_cached_weight") or self._cached_weight is None: + self._cached_weight = self.dequant_weight_online() + return self._cached_weight + def qdq_input(self, activation: torch.Tensor): original_dtype = activation.dtype temp_qdq_act = _nvfp4_qdq(activation.to(torch.float32), self.config, self.input_global_scale) diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index 1320846dc..186eea053 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -175,6 +175,12 @@ def save_quantized_as_fp( set_module(model, n, orig_layer) if is_nv_fp(act_data_type) and "static_gs" in str(act_data_type).lower(): + # Ensure all MOE layers have act_max set (needed after deep copy or for uncalibrated layers) + from auto_round.utils.model import is_moe_model, set_amax_for_all_moe_layers + + if is_moe_model(model): + set_amax_for_all_moe_layers(model) + # generate static input_global_scale for n, m in model.named_modules(): if type(m) in SUPPORTED_LAYER_TYPES: diff --git a/auto_round/export/utils.py b/auto_round/export/utils.py index 6f9490dad..42e6b4f9e 100644 --- a/auto_round/export/utils.py +++ b/auto_round/export/utils.py @@ -19,6 +19,59 @@ from auto_round.utils import copy_python_files_from_model_cache, logger, unsupported_meta_device +def _has_unfused_moe_experts(model: nn.Module) -> bool: + """Check if model has unfused MOE experts 
(nn.ModuleList instead of 3D Parameter). + + This is used to detect if we need to bypass transformers' weight conversion + during save_pretrained. + """ + for module in model.modules(): + if hasattr(module, "gate_up_proj") and isinstance(module.gate_up_proj, nn.ModuleList): + if hasattr(module, "down_proj") and isinstance(module.down_proj, nn.ModuleList): + return True + return False + + +def _save_model_state_dict( + model: nn.Module, + save_dir: str, + max_shard_size: str = "5GB", + safe_serialization: bool = True, +): + """Save model using state_dict directly, bypassing transformers' weight conversion. + + This is needed for models with unfused MOE experts where transformers' + revert_weight_conversion expects 3D tensor format but we have ModuleList. + """ + import torch + from safetensors.torch import save_file + + os.makedirs(save_dir, exist_ok=True) + + # Save config + if hasattr(model, "config") and model.config is not None: + model.config.save_pretrained(save_dir) + + if hasattr(model, "generation_config") and model.generation_config is not None: + try: + model.generation_config.save_pretrained(save_dir) + except Exception: + pass # generation_config save can fail for some models + + # Get state dict + state_dict = model.state_dict() + + # Save weights + if safe_serialization: + # Save as safetensors + weights_file = os.path.join(save_dir, "model.safetensors") + save_file(state_dict, weights_file) + else: + # Save as pytorch bin + weights_file = os.path.join(save_dir, "pytorch_model.bin") + torch.save(state_dict, weights_file) + + def save_model( model: nn.Module, save_dir: str, @@ -47,12 +100,22 @@ def save_model( Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ os.makedirs(save_dir, exist_ok=True) + + # Check if model has unfused MOE experts - if so, we need to bypass transformers' + # weight conversion which expects original 3D tensor format + has_unfused_experts = _has_unfused_moe_experts(model) + if unsupported_meta_device(model): if hasattr(model, "config") and model.config is not None: model.config.save_pretrained(save_dir) if hasattr(model, "generation_config") and model.generation_config is not None: model.generation_config.save_pretrained(save_dir) + elif has_unfused_experts: + # For models with unfused MOE experts, save state_dict directly to avoid + # transformers' revert_weight_conversion which expects 3D tensor format + logger.info("Saving model with unfused MOE experts using state_dict (bypassing weight conversion)") + _save_model_state_dict(model, save_dir, max_shard_size, safe_serialization) else: try: model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) @@ -68,6 +131,7 @@ def save_model( data["torch_dtype"] = str(dtype).split(".")[-1] with open(config_path, "w") as file: json.dump(data, file, indent=2) + config_file = "quantization_config.json" if hasattr(model, "config") and hasattr(model.config, "quantization_config"): with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f: diff --git a/auto_round/modeling/fused_moe/__init__.py b/auto_round/modeling/fused_moe/__init__.py index 14a492441..deacabb31 100644 --- a/auto_round/modeling/fused_moe/__init__.py +++ b/auto_round/modeling/fused_moe/__init__.py @@ -11,3 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
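# A minimal sketch (assuming a transformers>=5.0 MOE checkpoint already loaded as `model`)
# of how the names re-exported below fit together. apply_replacements with the default
# auto_detect_moe=True falls back to prepare_model_for_moe_quantization, which unfuses
# fused 3D expert weights into nn.ModuleList[nn.Linear] so they can be quantized:
#
#     from auto_round.modeling.fused_moe import apply_replacements, is_linear_loop_available
#
#     if is_linear_loop_available():
#         apply_replacements(model)  # unfuses experts unless a custom replacement is registered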
+ +from auto_round.modeling.fused_moe.replace_modules import ( + ReplacementModuleBase, + apply_replacements, + materialize_model_, + release_original_module_, +) +from auto_round.modeling.fused_moe.moe_experts_interface import ( + is_linear_loop_available, + linear_loop_experts_forward, + prepare_model_for_moe_quantization, + register_linear_loop_experts, +) + +__all__ = [ + "ReplacementModuleBase", + "apply_replacements", + "materialize_model_", + "release_original_module_", + # Transformers-native MOE integration (transformers 5.0+) + "linear_loop_experts_forward", + "register_linear_loop_experts", + "prepare_model_for_moe_quantization", + "is_linear_loop_available", +] diff --git a/auto_round/modeling/fused_moe/moe_experts_interface.py b/auto_round/modeling/fused_moe/moe_experts_interface.py new file mode 100644 index 000000000..aaef9ab8b --- /dev/null +++ b/auto_round/modeling/fused_moe/moe_experts_interface.py @@ -0,0 +1,447 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Custom experts implementation for transformers' MOE integration. + +This module provides a `linear_loop` experts implementation that uses +individual nn.Linear layers per expert instead of fused 3D Parameters. +This enables proper quantization of MOE expert weights. + +The implementation integrates with transformers' `use_experts_implementation` +decorator and `ALL_EXPERTS_FUNCTIONS` registry. + +Usage: + + from auto_round.modeling.fused_moe.moe_experts_interface import prepare_model_for_moe_quantization + # Before quantization + prepare_model_for_moe_quantization(model) + + # Now the model uses linear_loop forward which supports quantized nn.Linear layers +""" + +import torch +from torch import nn + +from auto_round.utils import clear_memory, logger + +try: + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + HAS_EXPERTS_INTERFACE = True +except ImportError: + HAS_EXPERTS_INTERFACE = False + ALL_EXPERTS_FUNCTIONS = None + +# Expert implementation name - change this if transformers want to use a different name +LINEAR_LOOP_IMPL = "linear_loop" + +# Known expert projection patterns for reference +# These are used as hints when auto-detection needs to infer projection properties +# Format: proj_name -> {"is_input_proj": bool, "output_multiplier": int} +# is_input_proj: True if takes hidden_dim as input, False if takes intermediate_dim +# output_multiplier: output dimension multiplier (e.g., 2 for fused gate+up projection) +KNOWN_PROJECTION_PATTERNS = { + # Transformers 5.0+ standard (Qwen3-MoE, etc.) 
+ "gate_up_proj": {"is_input_proj": True, "output_multiplier": 2}, # hidden -> 2*intermediate + "down_proj": {"is_input_proj": False, "output_multiplier": 1}, # intermediate -> hidden + # Mixtral-style + "w1": {"is_input_proj": True, "output_multiplier": 1}, # gate: hidden -> intermediate + "w2": {"is_input_proj": False, "output_multiplier": 1}, # down: intermediate -> hidden + "w3": {"is_input_proj": True, "output_multiplier": 1}, # up: hidden -> intermediate + # DBRX-style + "v1": {"is_input_proj": True, "output_multiplier": 1}, + "w1_proj": {"is_input_proj": True, "output_multiplier": 1}, + "w2_proj": {"is_input_proj": False, "output_multiplier": 1}, +} + + +def linear_loop_experts_forward( + self: nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + """Forward using individual nn.Linear layers per expert. + + This implementation loops over experts and uses self.gate_up_proj[i] and + self.down_proj[i] as nn.Linear layers (or quantized equivalents), enabling + proper quantization support. + + Expected module attributes: + - gate_up_proj: nn.ModuleList of nn.Linear (in_features=hidden_dim, out_features=2*intermediate_dim) + - down_proj: nn.ModuleList of nn.Linear (in_features=intermediate_dim, out_features=hidden_dim) + - act_fn: activation function + - num_experts: number of experts + - _apply_gate: optional custom gating function + + Args: + self: The experts module + hidden_states: Input tensor of shape (num_tokens, hidden_dim) + top_k_index: Selected expert indices of shape (num_tokens, top_k) + top_k_weights: Expert weights of shape (num_tokens, top_k) + + Returns: + final_hidden_states: Output tensor of shape (num_tokens, hidden_dim) + """ + logger.debug(f"Using {LINEAR_LOOP_IMPL} experts forward for {self.__class__.__name__}") + + # Handle [batch_size, seq_len, hidden_dim] input format + if hidden_states.dim() == 3: + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) # [bs * seq_len, hidden_dim] + top_k_index = top_k_index.view(-1, top_k_index.size(-1)) # [bs * seq_len, top_k] + top_k_weights = top_k_weights.view(-1, top_k_weights.size(-1)) # [bs * seq_len, top_k] + else: + batch_size, seq_len = None, None + hidden_dim = hidden_states.size(-1) + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + num_experts = self.num_experts + + # Reshape for easier indexing + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1).to(hidden_states.dtype) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Get current hidden states for selected samples + selected_hidden_states = hidden_states[token_idx] # (S, hidden_dim) + + # Allocate output tensor + out_per_sample = torch.zeros(token_idx.size(0), hidden_dim, device=device, dtype=hidden_states.dtype) + + # Process each expert + for expert_idx in range(num_experts): + # Find samples routed to this expert + mask = expert_ids == expert_idx + if not mask.any(): + continue + + expert_input = selected_hidden_states[mask] # (num_samples_for_expert, hidden_dim) + + # Use nn.Linear layers for this expert + gate_up_out = self.gate_up_proj[expert_idx](expert_input) # (num_samples, 2*intermediate_dim) + + # Apply gating + if hasattr(self, "_apply_gate"): + gated_out = self._apply_gate(gate_up_out) # 
(num_samples, intermediate_dim) + else: + gate, up = gate_up_out.chunk(2, dim=-1) + gated_out = self.act_fn(gate) * up + + # Down projection + expert_out = self.down_proj[expert_idx](gated_out) # (num_samples, hidden_dim) + + # Store results + out_per_sample[mask] = expert_out.to(out_per_sample.dtype) + + # Apply routing weights + out_per_sample = out_per_sample * sample_weights.unsqueeze(-1) # (S, hidden_dim) + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = out_per_sample.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + # Reshape back to original format if input was [batch_size, seq_len, hidden_dim] + if batch_size is not None: + final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) + + return final_hidden_states + + +def register_linear_loop_experts() -> bool: + """Register the linear_loop experts implementation with transformers. + + Returns: + True if registration was successful, False otherwise. + """ + if not HAS_EXPERTS_INTERFACE: + logger.warning( + "transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS not available. " + "linear_loop experts implementation not registered. " + "Requires transformers >= 5.0.0" + ) + return False + + if LINEAR_LOOP_IMPL not in ALL_EXPERTS_FUNCTIONS._global_mapping: + ALL_EXPERTS_FUNCTIONS._global_mapping[LINEAR_LOOP_IMPL] = linear_loop_experts_forward + logger.debug(f"Registered '{LINEAR_LOOP_IMPL}' experts implementation") + + return True + + +def _experts_supports_decorator(module: nn.Module) -> bool: + """Check if experts module supports @use_experts_implementation decorator. + + Only experts classes decorated with @use_experts_implementation will use + our linear_loop forward. Others need full module replacement. + """ + forward_method = getattr(module.__class__, "forward", None) + if forward_method is None: + return False + # @use_experts_implementation sets __wrapped__ on the decorated method + return hasattr(forward_method, "__wrapped__") + + +def _detect_expert_projections(module: nn.Module) -> dict[str, dict]: + """Detect which expert projections exist in the module. + + This function scans the module for any 3D nn.Parameter attributes. + It first checks known projection names, then discovers any unknown 3D parameters. + + Returns: + Dict mapping projection names to their config, only for projections that exist + as 3D nn.Parameter in the module. + """ + detected = {} + + # First, check known projection patterns + for proj_name, config in KNOWN_PROJECTION_PATTERNS.items(): + param = getattr(module, proj_name, None) + if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3: + detected[proj_name] = config + + # If no known patterns found, scan for any 3D Parameter (future-proofing) + if not detected: + for attr_name in dir(module): + if attr_name.startswith("_"): + continue + param = getattr(module, attr_name, None) + if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3: + # Use default config for unknown projections + logger.debug(f"Discovered unknown 3D projection: {attr_name}") + detected[attr_name] = {"is_input_proj": True, "output_multiplier": 1} + + return detected + + +def _infer_dimensions(param: nn.Parameter, config: dict, is_transposed: bool) -> tuple[int, int]: + """Infer input and output dimensions for a projection. 
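    For example, a transposed gate_up_proj parameter of shape (num_experts, 2048, 512)
    with output_multiplier=2 maps to in_features=2048 and out_features=512 (512 is
    already a multiple of 2, so the multiplier adjustment leaves it unchanged); the
    shapes here are illustrative.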
+ + Args: + param: The 3D parameter (num_experts, dim1, dim2) + config: Projection config with is_input_proj and output_multiplier + is_transposed: Whether weights are stored transposed + + Returns: + (in_features, out_features) for the Linear layer + """ + dim1, dim2 = param.shape[1], param.shape[2] + multiplier = config.get("output_multiplier", 1) + + if is_transposed: + # transposed: (num_experts, in_features, out_features) + in_features, out_features = dim1, dim2 + else: + # not transposed: (num_experts, out_features, in_features) + out_features, in_features = dim1, dim2 + + # Adjust for multiplier (e.g., gate_up has 2x intermediate) + if multiplier > 1: + out_features = out_features // multiplier * multiplier # ensure divisible + + return in_features, out_features + + +def _unfuse_single_projection( + module: nn.Module, + proj_name: str, + num_experts: int, + is_transposed: bool, + dtype: torch.dtype, + target_device: torch.device, +) -> nn.ModuleList | None: + """Unfuse a single projection from 3D Parameter to ModuleList of Linear layers. + + Args: + module: The experts module + proj_name: Name of the projection attribute + num_experts: Number of experts + is_transposed: Whether weights are stored transposed + dtype: Data type for the Linear layers + target_device: Device for the Linear layers + + Returns: + ModuleList of Linear layers, or None if projection doesn't exist + """ + param = getattr(module, proj_name, None) + if param is None or not isinstance(param, nn.Parameter) or param.dim() != 3: + return None + + # Get projection config + config = KNOWN_PROJECTION_PATTERNS.get(proj_name, {"is_input_proj": True, "output_multiplier": 1}) + + # Infer dimensions + in_features, out_features = _infer_dimensions(param, config, is_transposed) + + # Check for bias + bias_name = f"{proj_name}_bias" + bias_param = getattr(module, bias_name, None) + has_bias = bias_param is not None + + # Create ModuleList + linears = nn.ModuleList() + source_device = param.device + + for i in range(num_experts): + linear = nn.Linear(in_features, out_features, bias=has_bias, dtype=dtype, device=target_device) + + # Copy weights if not on meta device + if source_device.type != "meta": + if is_transposed: + linear.weight.data.copy_(param[i].t()) + else: + linear.weight.data.copy_(param[i]) + + if has_bias: + linear.bias.data.copy_(bias_param[i]) + + linears.append(linear) + + # Release original parameter memory using to_empty + if source_device.type != "meta": + try: + param.data = param.data.to_empty(device="meta") + logger.debug(f"Released memory for {proj_name} using to_empty(device='meta')") + except Exception: + # Fallback: just delete + pass + + return linears + + +def _unfuse_experts_weights_inplace( + module: nn.Module, + check_decorator: bool = True, + projection_names: list[str] | None = None, +) -> bool: + """Convert fused 3D expert weights to nn.ModuleList of nn.Linear layers. + + This function modifies the module in-place, replacing fused 3D Parameters + with nn.ModuleList[nn.Linear] for each detected projection. + + Args: + module: The experts module to unfuse + check_decorator: If True, only unfuse if the module supports + @use_experts_implementation decorator. Default is True. + projection_names: Optional list of projection names to unfuse. + If None, auto-detects from KNOWN_PROJECTION_PATTERNS. 
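    A minimal sketch of the effect on a toy module (shapes are illustrative, and
    check_decorator is disabled because the toy class is not decorated with
    @use_experts_implementation):

        import torch
        from torch import nn

        class ToyExperts(nn.Module):
            def __init__(self):
                super().__init__()
                # fused, "transposed" layout: (num_experts, in_features, out_features)
                self.gate_up_proj = nn.Parameter(torch.randn(8, 16, 2 * 32))
                self.down_proj = nn.Parameter(torch.randn(8, 32, 16))

        experts = ToyExperts()
        _unfuse_experts_weights_inplace(experts, check_decorator=False)
        assert isinstance(experts.gate_up_proj, nn.ModuleList)  # 8 x nn.Linear(16, 64)
        assert isinstance(experts.down_proj, nn.ModuleList)     # 8 x nn.Linear(32, 16)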
+ + Returns: + True if unfusing was successful, False if module doesn't match pattern + """ + # Detect available projections + if projection_names: + detected_projections = { + name: config for name, config in KNOWN_PROJECTION_PATTERNS.items() if name in projection_names + } + else: + detected_projections = _detect_expert_projections(module) + + if not detected_projections: + return False + + # Only unfuse if the module supports the decorator (unless check_decorator is False) + if check_decorator and not _experts_supports_decorator(module): + logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation") + return False + + # Get first projection to determine num_experts and layout + first_proj_name = next(iter(detected_projections)) + first_param = getattr(module, first_proj_name) + num_experts = first_param.shape[0] + + # Detect if transposed + is_transposed = getattr(module, "is_transposed", None) + if is_transposed is None: + # Infer from shape: typically hidden_dim < intermediate_dim + dim1, dim2 = first_param.shape[1], first_param.shape[2] + is_transposed = dim1 < dim2 + + dtype = first_param.dtype + target_device = first_param.device if first_param.device.type != "meta" else "cpu" + + # Unfuse each projection + unfused_count = 0 + for proj_name in detected_projections: + linears = _unfuse_single_projection(module, proj_name, num_experts, is_transposed, dtype, target_device) + if linears is not None: + # Delete original parameter and set new ModuleList + delattr(module, proj_name) + setattr(module, proj_name, linears) + unfused_count += 1 + logger.debug(f"Unfused {proj_name}: {num_experts} experts") + + # Ensure num_experts is set + if not hasattr(module, "num_experts"): + module.num_experts = num_experts + + return unfused_count > 0 + + +def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = LINEAR_LOOP_IMPL) -> list[str]: + """Prepare a model for MOE quantization using transformers' experts interface. + + This function: + 1. Registers the linear_loop experts implementation with transformers + 2. Sets model.config._experts_implementation = implementation + 3. Unfuses all fused MOE expert weights into nn.ModuleList[nn.Linear] + + After calling this function, the model's forward pass will use individual + nn.Linear layers per expert, which can be quantized normally. + + Args: + model: The model to prepare + implementation: The experts implementation to use (default: "linear_loop") + + Returns: + List of module names that were unfused + """ + # Register our custom implementation + if not register_linear_loop_experts(): + raise RuntimeError( + "Failed to register linear_loop experts implementation. " + "This requires transformers >= 5.0.0 with MOE integration support." 
+ ) + + # Unfuse all fused experts modules (only those supporting @use_experts_implementation) + unfused_modules = [] + for name, module in model.named_modules(): + if _unfuse_experts_weights_inplace(module): + unfused_modules.append(name) + logger.debug(f"Unfused expert weights in: {name}") + + # Only set config if we actually unfused something + # Models that don't support the decorator (like Llama4) won't have anything unfused + # and should use full module replacement instead + if unfused_modules: + logger.info(f"Unfused {len(unfused_modules)} MOE experts modules for quantization") + clear_memory() + + # Set config for linear_loop forward + if hasattr(model, "config"): + saved_impl = getattr(model.config, "experts_implementation", None) + impl_to_set = saved_impl if saved_impl else implementation + model.config._experts_implementation = impl_to_set + logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'") + + return unfused_modules + + +def is_linear_loop_available() -> bool: + """Check if linear_loop experts implementation is available.""" + return HAS_EXPERTS_INTERFACE diff --git a/auto_round/modeling/fused_moe/replace_modules.py b/auto_round/modeling/fused_moe/replace_modules.py index 66222f414..06845e93c 100644 --- a/auto_round/modeling/fused_moe/replace_modules.py +++ b/auto_round/modeling/fused_moe/replace_modules.py @@ -20,6 +20,8 @@ from tqdm import tqdm from transformers import PreTrainedModel +# Import constant for expert implementation name +from auto_round.modeling.fused_moe.moe_experts_interface import LINEAR_LOOP_IMPL from auto_round.utils import LazyImport, dump_mem_usage, dump_memory_usage_ctx, global_state, logger BUILTIN_MODULES = { @@ -30,6 +32,55 @@ } +def _has_custom_replacement_only(model: torch.nn.Module) -> bool: + """Check if all MOE-like modules in model have custom replacements registered. + + If all MOE modules are covered by BUILTIN_MODULES, we can skip the linear_loop path entirely. + """ + for _, module in model.named_modules(): + class_name = module.__class__.__name__ + # Check if this looks like an MOE experts module (has fused 3D weights) + if ( + hasattr(module, "gate_up_proj") + and isinstance(module.gate_up_proj, torch.nn.Parameter) + and module.gate_up_proj.dim() == 3 + ): + # This is a fused experts module - check if it has a custom replacement + if class_name not in BUILTIN_MODULES and not ReplacementModuleBase.is_registered(class_name): + # Found a fused experts module without custom replacement - need linear_loop + return False + return True + + +def _handle_moe_modules(model: torch.nn.Module) -> list[str]: + """Handle fused MOE modules using transformers' linear_loop backend. + + Args: + model: The model to process + + Returns: + List of module names that were processed + """ + from auto_round.modeling.fused_moe.moe_experts_interface import ( + is_linear_loop_available, + prepare_model_for_moe_quantization, + ) + + if not is_linear_loop_available(): + logger.warning( + "transformers' linear_loop experts interface not available (requires transformers 5.0+). " + "MOE modules with @use_experts_implementation decorator will fall back to custom replacements " + "if registered." 
+ ) + return [] + + # Use transformers' experts interface + unfused = prepare_model_for_moe_quantization(model) + if unfused: + logger.info(f"Prepared {len(unfused)} MOE modules for quantization") + return unfused + + def _import_required_replacements(model: torch.nn.Module) -> None: """Scan model and trigger lazy imports for registered replacement modules.""" imported = set() @@ -44,6 +95,33 @@ def _import_required_replacements(model: torch.nn.Module) -> None: logger.debug(f"Loaded replacement module for {class_name}") +def _should_skip_moe_replacement(module: torch.nn.Module, model: torch.nn.Module) -> bool: + """Skip MOE replacement if linear_loop experts are already unfused. + + This is only true when: + 1. model.config._experts_implementation == "linear_loop" (set by prepare_model_for_moe_quantization) + 2. The experts' gate_up_proj and down_proj are already nn.ModuleList + + Note: _experts_implementation is only set when the experts class supports + @use_experts_implementation decorator, so we don't need to check that again here. + """ + if not hasattr(model, "config"): + return False + if getattr(model.config, "_experts_implementation", None) != LINEAR_LOOP_IMPL: + return False + experts = getattr(module, "experts", None) + if experts is None: + return False + gate_up = getattr(experts, "gate_up_proj", None) + down = getattr(experts, "down_proj", None) + result = isinstance(gate_up, torch.nn.ModuleList) and isinstance(down, torch.nn.ModuleList) + logger.debug( + f"_should_skip_moe_replacement for {module.__class__.__name__}: " + f"gate_up type={type(gate_up).__name__}, down type={type(down).__name__}, skip={result}" + ) + return result + + @dump_mem_usage("Materializing model", log_level="debug") def materialize_model_(model: torch.nn.Module) -> None: def _materialize_module(module: torch.nn.Module) -> None: @@ -219,6 +297,7 @@ def post_process_materialization(self) -> None: def apply_replacements( model: torch.nn.Module, + auto_detect_moe: bool = True, ) -> torch.nn.Module: """ Function to apply module replacements to a model. @@ -230,11 +309,19 @@ def apply_replacements( Args: model: The model to apply module replacement to (modified in-place). + auto_detect_moe: If True, automatically detect and handle fused MOE modules + (transformers 5.0+ pattern). Default is True. Returns: The model with modules replaced. """ _import_required_replacements(model) + + # Auto-detect and handle fused MOE modules if enabled + # Skip if all MOE modules already have custom replacements registered + if auto_detect_moe and not _has_custom_replacement_only(model): + _handle_moe_modules(model) + replaced = [] # Step 1: Collect all modules that need replacement @@ -245,6 +332,9 @@ def apply_replacements( if isinstance(module, ReplacementModuleBase): continue class_name = module.__class__.__name__ + if class_name in BUILTIN_MODULES and _should_skip_moe_replacement(module, model): + logger.debug(f"Skipping replacement for {name}: linear_loop experts already unfused") + continue if ReplacementModuleBase.is_registered(class_name) and ReplacementModuleBase.get_replacement_class( class_name ).is_to_be_replaced(module): @@ -255,8 +345,18 @@ def apply_replacements( logger.info(f"Found {len(modules_to_replace)} modules to replace") for name, module, class_name in tqdm(modules_to_replace, desc="Replacing modules"): module = model.get_submodule(name) + # The module might have been replaced earlier in the loop (parent-first replacement). + # Skip if the class has changed or it no longer matches replacement criteria. 
+ if module.__class__.__name__ != class_name: + logger.debug( + f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}" + ) + continue with dump_memory_usage_ctx(f"Replacing module {name}", log_level="debug"): replacement_cls = ReplacementModuleBase.get_replacement_class(class_name) + if not replacement_cls.is_to_be_replaced(module): + logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria") + continue replacement = replacement_cls.from_original( module, model.config, diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 1fe13a458..34190593e 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -1313,34 +1313,65 @@ def set_amax_for_uncalibrated_experts( uncalibrated_experts: a list of uncalibrated experts """ uncalibrated_experts = [] + + def _get_attr(module, name): + """Get attribute from module or its orig_layer.""" + if hasattr(module, name): + return getattr(module, name) + if hasattr(module, "orig_layer") and hasattr(module.orig_layer, name): + return getattr(module.orig_layer, name) + return None + + def _get_amax_value(module): + value = get_nested_attr(module, attr_name) + if value is None and hasattr(module, "orig_layer"): + value = get_nested_attr(module.orig_layer, attr_name) + return value + # get the max amax value from all experts if set_amax_value is None: - amax_values = [ - get_nested_attr(module, attr_name) for module in experts if get_nested_attr(module, attr_name) is not None - ] + amax_values = [_get_amax_value(m) for m in experts if _get_amax_value(m) is not None] if len(amax_values) == 0: + # Check if any expert actually needs act_max (act_bits < 8, not dynamic, not already quantized) + sample = next((m for m in experts if m is not None), None) + if sample is not None: + act_bits = _get_attr(sample, "act_bits") + act_dynamic = _get_attr(sample, "act_dynamic") + is_quantized = "Quant" in sample.__class__.__name__ or hasattr(sample, "is_mx") + needs_warning = ( + not is_quantized and isinstance(act_bits, (int, float)) and act_bits < 8 and not act_dynamic + ) + if needs_warning: + logger.warning_once( + f"All {len(experts)} expert layers are missing '{attr_name}' values. " + f"This may indicate calibration hooks were not attached to expert layers." + ) return uncalibrated_experts - # Flatten all tensors to 1D before concatenation - flat_values = [t.reshape(-1) for t in amax_values] - all_values = torch.cat(flat_values) - set_amax_value = torch.max(all_values) + else: + # Flatten all tensors to 1D before concatenation + flat_values = [t.reshape(-1) for t in amax_values] + all_values = torch.cat(flat_values) + set_amax_value = torch.max(all_values) for module in experts: - current_amax = get_nested_attr(module, attr_name) + current_amax = _get_amax_value(module) # Set amax if it's None (uncalibrated) OR if unify_all is True if current_amax is None or unify_all: - if current_amax is None: - logger.warning_once( - "Missing amax value of expert layers." - "This typically occurs in MoE models when certain experts are not activated during calibration. " - "Consider increasing your calibration dataset size to ensure all experts are exercised." 
- ) # Use float32 dtype explicitly to ensure we create a floating point tensor if not isinstance(set_amax_value, torch.Tensor): set_amax_value = torch.tensor(set_amax_value, dtype=torch.float32) - set_nested_attr(module, attr_name, set_amax_value) - # uncalibrated_experts.append(module) + set_nested_attr(module, attr_name, set_amax_value.clone()) + if current_amax is None: + uncalibrated_experts.append(module) + + if uncalibrated_experts: + logger.info_once( + f"Found {len(uncalibrated_experts)} uncalibrated expert layers. " + "Using max amax from calibrated experts to fill missing values. " + ) + + return uncalibrated_experts # Please refer to: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/ @@ -1357,13 +1388,31 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na for name, sub_module in model.named_modules(): if not (is_moe_layer(sub_module) and hasattr(sub_module, "experts")): continue + + # Handle router (gate) layer - it's a Linear layer used for token routing + # It needs act_max for quantization but may not be calibrated if it wasn't exercised + _set_amax_for_moe_auxiliary_layers(sub_module, attr_name=attr_name) + expert_linear_names = get_expert_linear_names(sub_module) # Get input projection names for FP8 dispatch unification expert_input_proj_names = get_expert_input_proj_names(sub_module) - for linear_name in expert_linear_names: - if isinstance(sub_module.experts, collections.abc.Iterable): - # For other MoE models (like Mixtral) with iterable experts + # Check experts structure and handle accordingly + if _is_unfused_experts_module(sub_module.experts): + # Unfused experts: gate_up_proj/down_proj are nn.ModuleList + _set_amax_for_unfused_experts(sub_module.experts, attr_name=attr_name) + elif _is_fused_experts_module(sub_module.experts): + # Fused experts: 3D Parameters (e.g., DeepseekV2Experts) + # For fused experts, act_max is set on the parent MOE module, not individual experts + # Skip processing here as they don't have individual Linear layers to calibrate + logger.debug( + f"Skipping act_max setting for fused experts module '{name}': " + f"fused experts use parent module's act_max" + ) + continue + elif isinstance(sub_module.experts, collections.abc.Iterable): + # Iterable experts: list of expert modules (e.g., Mixtral) + for linear_name in expert_linear_names: try: # Determine if this is an input projection that needs scale unification unify_scale = linear_name in expert_input_proj_names and envs.AR_ENABLE_UNIFY_MOE_INPUT_SCALE @@ -1385,12 +1434,153 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na f"to be updated for this model architecture. " f"Original error: {e}" ) from e - else: - # Unsupported MoE model structure - raise NotImplementedError( - f"MoE model with experts type '{type(sub_module.experts).__name__}' is not supported in export." - f"Please file an issue or add support for this model architecture." - ) + else: + # Unknown experts structure + logger.warning( + f"Unknown experts structure in '{name}': type={type(sub_module.experts).__name__}. " + f"Skipping act_max setting. This may cause issues during export." 
+ ) + + +def _is_unfused_experts_module(module: torch.nn.Module) -> bool: + """Check if the module is an unfused experts module (has ModuleList gate_up_proj/down_proj).""" + if not hasattr(module, "gate_up_proj") or not hasattr(module, "down_proj"): + return False + return isinstance(module.gate_up_proj, torch.nn.ModuleList) and isinstance(module.down_proj, torch.nn.ModuleList) + + +def _is_fused_experts_module(module: torch.nn.Module) -> bool: + """Check if the module is a fused experts module (has 3D Parameter gate_up_proj/down_proj).""" + if not hasattr(module, "gate_up_proj") or not hasattr(module, "down_proj"): + return False + return ( + isinstance(module.gate_up_proj, torch.nn.Parameter) + and isinstance(module.down_proj, torch.nn.Parameter) + and module.gate_up_proj.dim() == 3 + and module.down_proj.dim() == 3 + ) + + +def _set_amax_for_unfused_experts(experts_module: torch.nn.Module, attr_name: str = "act_max"): + """Set amax for unfused experts module with ModuleList attributes. + + This handles experts modules that have been unfused to have: + - gate_up_proj: nn.ModuleList of nn.Linear (input projections, unified scale) + - down_proj: nn.ModuleList of nn.Linear (output projections) + """ + if hasattr(experts_module, "gate_up_proj") and isinstance(experts_module.gate_up_proj, torch.nn.ModuleList): + unify_scale = envs.AR_ENABLE_UNIFY_MOE_INPUT_SCALE + set_amax_for_uncalibrated_experts( + list(experts_module.gate_up_proj), + attr_name=attr_name, + unify_all=unify_scale, + ) + + if hasattr(experts_module, "down_proj") and isinstance(experts_module.down_proj, torch.nn.ModuleList): + set_amax_for_uncalibrated_experts( + list(experts_module.down_proj), + attr_name=attr_name, + unify_all=False, + ) + + +def _set_amax_for_moe_auxiliary_layers(moe_module: torch.nn.Module, attr_name: str = "act_max"): + """Set amax for auxiliary layers in MOE modules (gate/router, shared_experts). + + These layers are not part of the experts structure but are siblings in the MOE module. + They need act_max for quantization but may be missing if not all paths were exercised + during calibration. 
+ + Args: + moe_module: The MOE module (e.g., DeepseekV2MoE) + attr_name: The attribute name for amax (default: "act_max") + """ + # Collect all Linear layers that have act_bits set but missing act_max + layers_needing_amax = [] + + # Check gate (router) layer - it's typically a Linear layer for token routing + if hasattr(moe_module, "gate") and isinstance(moe_module.gate, torch.nn.Linear): + gate = moe_module.gate + if hasattr(gate, "act_bits") and gate.act_bits < 8: + if get_nested_attr(gate, attr_name) is None: + layers_needing_amax.append(gate) + + # Check shared_experts - may have Linear layers that need act_max + if hasattr(moe_module, "shared_experts"): + shared_experts = moe_module.shared_experts + if shared_experts is not None: + for child_name, child in shared_experts.named_modules(): + if isinstance(child, torch.nn.Linear): + if hasattr(child, "act_bits") and child.act_bits < 8: + if get_nested_attr(child, attr_name) is None: + layers_needing_amax.append(child) + + if not layers_needing_amax: + return + + # Try to get reference amax from calibrated experts + reference_amax = _get_reference_amax_from_experts(moe_module, attr_name) + + if reference_amax is not None: + for layer in layers_needing_amax: + if not isinstance(reference_amax, torch.Tensor): + reference_amax = torch.tensor(reference_amax, dtype=torch.float32) + set_nested_attr(layer, attr_name, reference_amax.clone()) + logger.info_once( + f"Set act_max for {len(layers_needing_amax)} MOE auxiliary layers (gate/shared_experts) " + f"using reference value from calibrated experts." + ) + else: + logger.warning_once( + f"Cannot set act_max for {len(layers_needing_amax)} MOE auxiliary layers: " + f"no calibrated experts found to use as reference." + ) + + +def _get_reference_amax_from_experts(moe_module: torch.nn.Module, attr_name: str = "act_max"): + """Get a reference amax value from calibrated expert layers. 
+ + Args: + moe_module: The MOE module containing experts + attr_name: The attribute name for amax + + Returns: + A reference amax tensor, or None if no calibrated experts found + """ + amax_values = [] + + if not hasattr(moe_module, "experts"): + return None + + experts = moe_module.experts + + # Handle unfused experts (ModuleList) + if _is_unfused_experts_module(experts): + for proj_list in [getattr(experts, "gate_up_proj", None), getattr(experts, "down_proj", None)]: + if proj_list is not None and isinstance(proj_list, torch.nn.ModuleList): + for layer in proj_list: + amax = get_nested_attr(layer, attr_name) + if amax is not None: + amax_values.append(amax) + + # Handle iterable experts (list of modules) + elif isinstance(experts, collections.abc.Iterable): + expert_linear_names = get_expert_linear_names(moe_module) + for expert in experts: + for linear_name in expert_linear_names: + layer = getattr(expert, linear_name, None) + if layer is not None: + amax = get_nested_attr(layer, attr_name) + if amax is not None: + amax_values.append(amax) + + if not amax_values: + return None + + # Return max of all amax values + flat_values = [t.reshape(-1) for t in amax_values] + all_values = torch.cat(flat_values) + return torch.max(all_values) # Adapted from https://github.com/vllm-project/llm-compressor/blob/ diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 24836d85b..89f903827 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -141,6 +141,14 @@ def __init__( else: self.orig_forward = self.linear_forward if type(self.orig_layer) == torch.nn.Linear else self.conv1d_forward + @property + def weight(self): + return self.orig_layer.weight + + @property + def bias(self): + return self.orig_layer.bias + def _init_tuning_params_and_quant_func(self): """Initializes tuning parameters and quantization functions. 
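Note: the wrapper.py hunk above (and the matching one for WrapperWALayer just below) forwards `.weight` and `.bias` to `orig_layer`, so generic helpers that probe `module.weight` / `module.bias` (such as the block weight-offloading utilities added earlier in this series) can treat a wrapped layer like a plain Linear. A minimal standalone sketch of that passthrough pattern follows; the `PassthroughWrapper` name is illustrative only and is not part of auto-round.

import torch
from torch import nn


class PassthroughWrapper(nn.Module):
    # Illustrative only: mirrors the @property passthrough added to the auto-round wrappers.
    def __init__(self, orig_layer: nn.Linear):
        super().__init__()
        self.orig_layer = orig_layer

    @property
    def weight(self):
        # Code that inspects `module.weight` sees the wrapped layer's parameter.
        return self.orig_layer.weight

    @property
    def bias(self):
        return self.orig_layer.bias

    def forward(self, x):
        return self.orig_layer(x)


if __name__ == "__main__":
    wrapped = PassthroughWrapper(nn.Linear(8, 8))
    assert wrapped.weight is wrapped.orig_layer.weight
    assert wrapped.bias is wrapped.orig_layer.bias
    out = wrapped(torch.randn(2, 8))
    assert out.shape == (2, 8)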
@@ -502,6 +510,16 @@ def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"): self.act_quant_func = compile_func(self.act_quant_func, self.device) self.extra_repr_org = orig_layer.extra_repr + @property + def weight(self): + """Exposes the weight of the wrapped layer for external access.""" + return self.orig_layer.weight + + @property + def bias(self): + """Exposes the bias of the wrapped layer for external access.""" + return self.orig_layer.bias + def forward(self, x): act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None x, _, _ = self.orig_layer.act_quant_func( diff --git a/test/test_cpu/quantization/test_mxfp_nvfp.py b/test/test_cpu/quantization/test_mxfp_nvfp.py index 3c2e9bcce..37b8a8a4d 100644 --- a/test/test_cpu/quantization/test_mxfp_nvfp.py +++ b/test/test_cpu/quantization/test_mxfp_nvfp.py @@ -1,3 +1,4 @@ +import collections import os import shutil @@ -32,9 +33,6 @@ def teardown_class(self): shutil.rmtree("./saved", ignore_errors=True) shutil.rmtree("runs", ignore_errors=True) - @pytest.mark.skipif( - transformers_version >= version.parse("5.0.0"), reason="transformers v5 MOE model has breaking changes" - ) def test_nvfp4_moe_actmax_rtn(self, tiny_deepseek_v2_model_path, dataloader): model_name = tiny_deepseek_v2_model_path layer_config = { @@ -55,16 +53,47 @@ def test_nvfp4_moe_actmax_rtn(self, tiny_deepseek_v2_model_path, dataloader): layer_config=layer_config, ) compressed_model, _ = autoround.quantize() - assert hasattr(compressed_model.model.layers[1].mlp.experts[0].gate_proj.orig_layer, "act_max") + moe = compressed_model.model.layers[1].mlp + experts = moe.experts + + def _has_act_max(layer): + if layer is None: + return False + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + return hasattr(layer, "act_max") + + found_act_max = False + if hasattr(experts, "gate_up_proj") and isinstance(experts.gate_up_proj, torch.nn.ModuleList): + if len(experts.gate_up_proj) > 0: + found_act_max = _has_act_max(experts.gate_up_proj[0]) + elif isinstance(experts, collections.abc.Iterable): + first_expert = next(iter(experts), None) + if first_expert is not None: + for linear_name in [ + "gate_proj", + "up_proj", + "down_proj", + "linear_fc1", + "linear_fc2", + "w1", + "w2", + "w3", + ]: + if hasattr(first_expert, linear_name): + found_act_max = _has_act_max(getattr(first_expert, linear_name)) + if found_act_max: + break + elif hasattr(moe, "act_max"): + found_act_max = True + + assert found_act_max, "Missing act_max on MOE expert layers" lm_head = compressed_model.lm_head assert hasattr(lm_head, "orig_layer") and hasattr( lm_head.orig_layer, "act_max" ), "Illegal NVFP4 quantization for lm_head layer" shutil.rmtree(self.save_dir, ignore_errors=True) - @pytest.mark.skipif( - transformers_version >= version.parse("5.0.0"), reason="transformers v5 MOE model has breaking changes" - ) def test_nvfp4_moe_actmax_ar(self, tiny_deepseek_v2_model_path, dataloader): model_name = tiny_deepseek_v2_model_path layer_config = { @@ -97,9 +126,6 @@ def test_nvfp4_moe_actmax_ar(self, tiny_deepseek_v2_model_path, dataloader): assert is_model_outputs_similar(model_name, quantized_model_path) shutil.rmtree(self.save_dir, ignore_errors=True) - @pytest.mark.skipif( - transformers_version >= version.parse("5.0.0"), reason="transformers v5 MOE model has breaking changes" - ) def test_mxfp4_moe_ar(self, tiny_deepseek_v2_model_path, dataloader): model_name = tiny_deepseek_v2_model_path layer_config = { @@ -333,9 +359,6 @@ def 
test_nvfp4_autoround_save_quantized(self, tiny_opt_model_path, dataloader): ), "Illegal NVFP4 packing name or data_type or shape" shutil.rmtree(quantized_model_path, ignore_errors=True) - @pytest.mark.skipif( - transformers_version >= version.parse("5.0.0"), reason="transformers v5 MOE model has breaking changes" - ) def test_qwen_moe_quant_infer(self, tiny_qwen_moe_model_path, dataloader): model_name = tiny_qwen_moe_model_path layer_config = { diff --git a/test/test_cuda/models/test_moe_experts_interface.py b/test/test_cuda/models/test_moe_experts_interface.py new file mode 100644 index 000000000..a292d4294 --- /dev/null +++ b/test/test_cuda/models/test_moe_experts_interface.py @@ -0,0 +1,239 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script for moe_experts_interface.py - linear_loop experts implementation. + +This verifies that: +1. linear_loop is registered with transformers 'ALL_EXPERTS_FUNCTIONS' +2. Fused expert weights are correctly unfused to nn.Linear layers +3. The forward pass produces correct results +""" + +import pytest +import torch +from torch import nn + + +def _skip_if_no_linear_loop(): + from auto_round.modeling.fused_moe.moe_experts_interface import is_linear_loop_available + + if not is_linear_loop_available(): + pytest.skip("transformers MOE integration not available") + + +def test_linear_loop_registration(): + """Test that linear_loop is registered with transformers.""" + from auto_round.modeling.fused_moe.moe_experts_interface import ( + register_linear_loop_experts, + ) + + _skip_if_no_linear_loop() + success = register_linear_loop_experts() + assert success, "Failed to register linear_loop" + + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + assert "linear_loop" in ALL_EXPERTS_FUNCTIONS._global_mapping + + +def test_unfuse_experts_weights(): + """Test unfusing fused expert weights to nn.Linear layers.""" + from auto_round.modeling.fused_moe.moe_experts_interface import _unfuse_experts_weights_inplace + + # Create a mock fused experts module (Mixtral style - not transposed) + num_experts = 4 + hidden_dim = 64 + intermediate_dim = 128 + + class MockFusedExperts(nn.Module): + def __init__(self): + super().__init__() + # Not transposed: (num_experts, 2*intermediate, hidden) + self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)) + # (num_experts, hidden, intermediate) + self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_dim, intermediate_dim)) + self.act_fn = nn.SiLU() + self.num_experts = num_experts + self.has_bias = False + self.is_transposed = False + + module = MockFusedExperts() + + # Store original weights for comparison + original_gate_up = module.gate_up_proj.data.clone() + original_down = module.down_proj.data.clone() + + # Unfuse (check_decorator=False since mock module doesn't have the decorator) + success = _unfuse_experts_weights_inplace(module, check_decorator=False) + assert success, "Failed to unfuse weights" + + # Verify 
structure + assert isinstance(module.gate_up_proj, nn.ModuleList) + assert isinstance(module.down_proj, nn.ModuleList) + assert len(module.gate_up_proj) == num_experts + assert len(module.down_proj) == num_experts + + # Verify weights are preserved + for i in range(num_experts): + # gate_up: original (2*intermediate, hidden), linear.weight should be same + assert torch.allclose( + module.gate_up_proj[i].weight.data, original_gate_up[i], atol=1e-6 + ), f"gate_up weight mismatch for expert {i}" + + # down: original (hidden, intermediate), linear.weight should be same + assert torch.allclose( + module.down_proj[i].weight.data, original_down[i], atol=1e-6 + ), f"down weight mismatch for expert {i}" + + +def test_unfuse_experts_weights_transposed(): + """Test unfusing transposed expert weights (Llama4/GptOss style).""" + from auto_round.modeling.fused_moe.moe_experts_interface import _unfuse_experts_weights_inplace + + num_experts = 4 + hidden_dim = 64 + intermediate_dim = 128 + + class MockFusedExpertsTransposed(nn.Module): + def __init__(self): + super().__init__() + # Transposed: (num_experts, hidden, 2*intermediate) + self.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden_dim, 2 * intermediate_dim)) + # Transposed: (num_experts, intermediate, hidden) + self.down_proj = nn.Parameter(torch.randn(num_experts, intermediate_dim, hidden_dim)) + self.act_fn = nn.SiLU() + self.num_experts = num_experts + self.is_transposed = True + self.has_bias = False + + module = MockFusedExpertsTransposed() + + # Store original weights for comparison + original_gate_up = module.gate_up_proj.data.clone() + original_down = module.down_proj.data.clone() + + # Unfuse (check_decorator=False since mock module doesn't have the decorator) + success = _unfuse_experts_weights_inplace(module, check_decorator=False) + assert success, "Failed to unfuse transposed weights" + + # Verify structure + assert isinstance(module.gate_up_proj, nn.ModuleList) + assert isinstance(module.down_proj, nn.ModuleList) + + # Verify weights are correctly transposed + for i in range(num_experts): + # gate_up: original (hidden, 2*intermediate), should be transposed to (2*intermediate, hidden) + assert torch.allclose( + module.gate_up_proj[i].weight.data, original_gate_up[i].t(), atol=1e-6 + ), f"gate_up weight mismatch for expert {i}" + + # down: original (intermediate, hidden), should be transposed to (hidden, intermediate) + assert torch.allclose( + module.down_proj[i].weight.data, original_down[i].t(), atol=1e-6 + ), f"down weight mismatch for expert {i}" + + +def test_linear_loop_forward(): + """Test that linear_loop forward produces correct results.""" + from auto_round.modeling.fused_moe.moe_experts_interface import ( + _unfuse_experts_weights_inplace, + linear_loop_experts_forward, + ) + + num_experts = 4 + hidden_dim = 64 + intermediate_dim = 128 + num_tokens = 10 + top_k = 2 + + # Create module with unfused weights + class MockExperts(nn.Module): + def __init__(self): + super().__init__() + self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)) + self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_dim, intermediate_dim)) + self.act_fn = nn.SiLU() + self.num_experts = num_experts + self.has_bias = False + self.is_transposed = False + + module = MockExperts() + + # Unfuse weights (check_decorator=False since mock module doesn't have the decorator) + _unfuse_experts_weights_inplace(module, check_decorator=False) + + # Create inputs + hidden_states = torch.randn(num_tokens, 
hidden_dim) + top_k_index = torch.randint(0, num_experts, (num_tokens, top_k)) + top_k_weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1) + + # Run forward + output = linear_loop_experts_forward(module, hidden_states, top_k_index, top_k_weights) + + # Verify output shape + assert output.shape == hidden_states.shape, f"Output shape mismatch: {output.shape} vs {hidden_states.shape}" + + # Verify output is not all zeros (sanity check) + assert not torch.allclose(output, torch.zeros_like(output)), "Output is all zeros" + + +def test_prepare_model_for_moe_quantization(): + """Test the full prepare_model_for_moe_quantization flow.""" + from auto_round.modeling.fused_moe.moe_experts_interface import ( + prepare_model_for_moe_quantization, + ) + + _skip_if_no_linear_loop() + + num_experts = 4 + hidden_dim = 64 + intermediate_dim = 128 + + # Create a mock model with fused experts + class MockConfig: + def __init__(self): + self._experts_implementation = "eager" + + class MockExpertsModule(nn.Module): + def __init__(self): + super().__init__() + self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 2 * intermediate_dim, hidden_dim)) + self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_dim, intermediate_dim)) + self.act_fn = nn.SiLU() + self.num_experts = num_experts + self.has_bias = False + self.is_transposed = False + + def forward(self, x): + pass + + # Add __wrapped__ to simulate @use_experts_implementation decorator + MockExpertsModule.forward.__wrapped__ = True + + class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.config = MockConfig() + self.layer = nn.ModuleDict({"experts": MockExpertsModule()}) + + model = MockModel() + + # Prepare for quantization + unfused = prepare_model_for_moe_quantization(model) + + # Verify + assert model.config._experts_implementation == "linear_loop" + assert len(unfused) == 1 + assert isinstance(model.layer["experts"].gate_up_proj, nn.ModuleList) From dd45c31038490fd8330b00355816a5167e0555d9 Mon Sep 17 00:00:00 2001 From: Heng Guo Date: Wed, 4 Feb 2026 09:21:59 +0800 Subject: [PATCH 08/13] fix cuda ut fail (#1370) Signed-off-by: n1ck-guo --- auto_round/compressors/base.py | 18 ++++++++---------- test/test_cuda/models/test_get_block_name.py | 2 +- test/test_cuda/models/test_support_vlms.py | 8 +++++++- test/test_cuda/models/test_vlms.py | 2 +- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 9a620b778..f0fe6f160 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -514,20 +514,18 @@ def __init__( except (ImportError, ModuleNotFoundError): logger.error("algorithm extension import error, fallback to default mode") - def _gen_auto_scheme( - self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device] - ) -> dict[str, dict]: + def _gen_auto_scheme(self) -> dict[str, dict]: if self.mllm: logger.info("AutoScheme is not yet supported for multimodal LLMs.") sys.exit(-1) - if is_quantized_input_module(model): + if is_quantized_input_module(self.model): logger.info("AutoScheme does not currently support quantized input models (e.g., FP8).") sys.exit(-1) all_dtypes = [] all_gguf = True - for option in scheme.options: + for option in self.orig_scheme.options: # Resolve the quantization scheme or data type dtype = "int" if isinstance(option, str): @@ -579,15 +577,15 @@ def _gen_auto_scheme( # mainly using quant_layers and fixed by users from 
auto_round.auto_scheme.gen_auto_scheme import GenScheme - if not self.enable_torch_compile and self.super_bits is None and not scheme.low_gpu_mem_usage: + if not self.enable_torch_compile and self.super_bits is None and not self.orig_scheme.low_gpu_mem_usage: logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") self.scheme_generator = GenScheme( - scheme, + self.orig_scheme, self.model, quant_layer_names, fixed_layer_scheme_new, - dataset, - device_map=device_map, + self.dataset, + device_map=self.device_map, tokenizer=self.tokenizer, enable_torch_compile=self.enable_torch_compile, ) @@ -1744,7 +1742,7 @@ def configure_layer_config(self, enable_gguf_official_mixed: None | bool = True) self.ignore_layers += "," + tmp_str if self.is_auto_scheme: - self.layer_config = self._gen_auto_scheme(self.model, self.orig_scheme, self.dataset, self.device_map) + self.layer_config = self._gen_auto_scheme() fill_default_value = True if self.is_auto_scheme: diff --git a/test/test_cuda/models/test_get_block_name.py b/test/test_cuda/models/test_get_block_name.py index 1850ec5c7..e8928d1f5 100644 --- a/test/test_cuda/models/test_get_block_name.py +++ b/test/test_cuda/models/test_get_block_name.py @@ -165,7 +165,7 @@ def test_gemma3(self): assert not is_pure_text_model(model) def test_Mistral3(self): - model_name = "/models/Mistral-Small-3.1-24B-Instruct-2503" + model_name = "/models/Mistral-Small-3.2-24B-Instruct-2506" model = Mistral3ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) block_names = get_block_names(model) self.check_block_names(block_names, ["model.language_model.layers"], [40]) diff --git a/test/test_cuda/models/test_support_vlms.py b/test/test_cuda/models/test_support_vlms.py index 798674c83..67296925a 100644 --- a/test/test_cuda/models/test_support_vlms.py +++ b/test/test_cuda/models/test_support_vlms.py @@ -4,11 +4,14 @@ import pytest import requests +from packaging import version from PIL import Image from transformers import AutoRoundConfig # # must import for auto-round format from auto_round.testing_utils import require_gptqmodel, require_package_version_ut, require_vlm_env +from ...helpers import transformers_version + AUTO_ROUND_PATH = __file__.split("/") AUTO_ROUND_PATH = "/".join(AUTO_ROUND_PATH[: AUTO_ROUND_PATH.index("test")]) @@ -177,7 +180,10 @@ def test_phi3_vision_awq(self): print(response) shutil.rmtree(quantized_model_path, ignore_errors=True) - @require_package_version_ut("transformers", "<4.54.0") + @pytest.mark.skipif( + transformers_version >= version.parse("4.57.0"), + reason="transformers api changed", + ) def test_glm(self): model_path = "/models/glm-4v-9b/" ## test tune diff --git a/test/test_cuda/models/test_vlms.py b/test/test_cuda/models/test_vlms.py index b0aa4a327..ad7fcaa11 100644 --- a/test/test_cuda/models/test_vlms.py +++ b/test/test_cuda/models/test_vlms.py @@ -136,7 +136,7 @@ def test_mllm_detect(self): "/models/Phi-3.5-vision-instruct", "/models/Qwen2-VL-2B-Instruct", "/models/SmolVLM-256M-Instruct", - "/models/Mistral-Small-3.1-24B-Instruct-2503", + "/models/Mistral-Small-3.2-24B-Instruct-2506", "/models/InternVL3-1B", "/models/pixtral-12b", ]: From 10028e83aedec044eedc2b4ef60d07235770b007 Mon Sep 17 00:00:00 2001 From: Xin He Date: Wed, 4 Feb 2026 09:23:32 +0800 Subject: [PATCH 09/13] [Regression] Detach scale tensor to prevent holding computation graph references (#1389) --- auto_round/experimental/attention.py | 2 +- auto_round/experimental/kv_cache.py 
| 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/auto_round/experimental/attention.py b/auto_round/experimental/attention.py index 348d5524d..0a2361423 100644 --- a/auto_round/experimental/attention.py +++ b/auto_round/experimental/attention.py @@ -96,7 +96,7 @@ def forward( ) update_parameter_data(module, query_max, QUERY_MAX_NAME) _, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max) - update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME) + update_parameter_data(module, query_scale.squeeze(0).detach(), QUERY_SCALE_NAME) # original attention return ALL_ATTENTION_FUNCTIONS[self._original_impl]( module, diff --git a/auto_round/experimental/kv_cache.py b/auto_round/experimental/kv_cache.py index dbbe7e7eb..1774315d7 100644 --- a/auto_round/experimental/kv_cache.py +++ b/auto_round/experimental/kv_cache.py @@ -173,7 +173,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_ scales = self.v_scales qdq_tensor, scale = per_tensor_fp8_qdq(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0)) + # Detach scale to prevent holding computation graph references + _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0).detach()) return qdq_tensor From b2dff81074e33caa5f794f1bdf5c272d16b0ec0c Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 4 Feb 2026 10:21:28 +0800 Subject: [PATCH 10/13] fix layer config (#1373) Signed-off-by: n1ck-guo Signed-off-by: WeiweiZhang1 Co-authored-by: n1ck-guo Co-authored-by: WeiweiZhang1 Signed-off-by: lvliang-intel --- auto_round/compressors/base.py | 461 +++++++----------- .../export_to_nvfp_mxfp.py | 3 +- .../export_to_llmcompressor/export_to_fp.py | 3 +- auto_round/schemes.py | 13 +- auto_round/utils/device.py | 212 +++++++- auto_round/utils/model.py | 10 +- .../test_cpu/core/test_low_cpu_mem_options.py | 7 +- test/test_cpu/schemes/test_scheme.py | 27 + 8 files changed, 441 insertions(+), 295 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index f0fe6f160..93b892163 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -15,9 +15,7 @@ import copy import os import re -import shutil import sys -import tempfile import time import traceback from collections import defaultdict @@ -28,13 +26,11 @@ import accelerate import torch -import transformers from accelerate.big_modeling import dispatch_model, infer_auto_device_map from accelerate.utils import get_balanced_memory, get_max_memory -from packaging import version from torch import autocast from tqdm import tqdm -from transformers import AutoConfig, set_seed +from transformers import set_seed from auto_round import envs from auto_round.auto_scheme.gen_auto_scheme import AutoScheme @@ -80,7 +76,6 @@ check_seqlen_compatible, check_to_quantized, clear_memory, - clear_module_weights, compile_func, convert_dtype_str2torch, convert_module_to_hp_if_necessary, @@ -102,7 +97,6 @@ load_module_weights, memory_monitor, mv_module_from_gpu, - save_module_weights, set_amax_for_all_moe_layers, set_module, to_device, @@ -110,11 +104,22 @@ unsupported_meta_device, ) from auto_round.utils.device import ( + cleanup_cpu_offload_dir, clear_memory_if_reached_threshold, + discard_offloaded_block, + estimate_block_size_gb, + estimate_inputs_size_gb, + estimate_model_size_gb, + estimate_tensor_size_gb, get_major_device, + init_cpu_offload_dir, + load_offloaded_block_weights, + offload_block_weights, parse_available_devices, + restore_offloaded_blocks, 
set_auto_device_map_for_block_with_tuning, set_non_auto_device_map, + stream_offload_blocks, ) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block @@ -259,6 +264,16 @@ def __init__( ... # ... ... } """ + scheme_fields = [f.name for f in fields(QuantizationScheme)] + + # 1. Pre-extract user-specified overrides from kwargs + # This ensures we know exactly what the user wants to "force" + self.user_scheme_overrides = {} + for k in scheme_fields: + if k in kwargs: + value = kwargs.pop(k) + if value is not None: + self.user_scheme_overrides[k] = value # Model related model_dtype = kwargs.pop("model_dtype", None) @@ -313,8 +328,9 @@ def __init__( # should be set after loading model and set layer_config, cause some special scheme need these. # Preserve the original, unparsed scheme for later use in auto scheme generation # within `configure_layer_config` (which may need the raw value instead of `self.scheme`). + default_scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, self.user_scheme_overrides) self.orig_scheme = copy.deepcopy(scheme) - self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, kwargs) + self.scheme = default_scheme gguf_scheme_name = get_gguf_scheme(self.scheme) # GGUF uses fp32 scale dtype as default @@ -616,127 +632,152 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None: else: raise TypeError(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}") - def _parse_and_set_scheme( - self, scheme: Union[str, dict, QuantizationScheme], kwargs - ) -> tuple[QuantizationScheme, bool]: - """Parse and set the quantization scheme.""" - - def _parse_and_set(scheme, kwargs): - if kwargs.get("data_type", None) and kwargs["data_type"].endswith("_dq") and not scheme.startswith("gguf"): - if "bits" not in kwargs: - data_type = kwargs["data_type"] - raise KeyError( - f"please set bits when setting data_type={data_type}, or using scheme as an alternative." - ) - bits = kwargs["bits"] - scheme = f"gguf:q{bits}_k" if bits == 6 else f"gguf:q{bits}_k_s" - res = None - if isinstance(scheme, QuantizationScheme): - scheme = asdict(scheme) - elif isinstance(scheme, dict): - scheme = scheme - elif isinstance(scheme, str): - # We’d better keep the string scheme instead of the dict config, - # since GGUF uses different mixed-bit strategies for q4_k_s and q4_k_m - # even though they share the same scheme dict. - scheme = scheme.strip("'\" ") - res = scheme - scheme = scheme.upper() - self.layer_config = _handle_special_schemes( - scheme, - self.layer_config, - self.model, - supported_types=self.supported_types, - inner_supported_types=self.inner_supported_types, - quant_lm_head=self.quant_lm_head, - mllm=getattr(self, "mllm", False), - ) - scheme = asdict(preset_name_to_scheme(scheme)) - scheme_keys = [f.name for f in fields(QuantizationScheme)] - for key in scheme_keys: - if key in kwargs and kwargs[key] is not None: - setattr(self, key, kwargs[key]) - else: - setattr(self, key, scheme.get(key, None)) - # kwargs.pop(key, None) - if self.act_dynamic is None: - self.act_dynamic = True + def _reconcile_bits_and_dtype(self, config: dict, prefix: str = ""): + """ + Harmonizes 'bits' and 'data_type' for weights or activations. + Ensures internal consistency by prioritizing data_type inference. 
+ """ + dt_key = f"{prefix}data_type" + bits_key = f"{prefix}bits" + + if config.get(dt_key) is None: + return + + # Infer the correct bit-width based on the data_type string + inferred_bits = infer_bits_by_data_type(config[dt_key]) - tmp_bits = infer_bits_by_data_type(self.data_type) - if tmp_bits is not None and tmp_bits < 16 and tmp_bits != self.bits: + if inferred_bits is not None and inferred_bits < 16: + # Check for conflict between user-specified bits and inferred bits + if inferred_bits != config.get(bits_key): logger.warning( - f"'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}." + f"'{dt_key}' does not match '{bits_key}'. " f"Resetting '{bits_key}' to {inferred_bits}." ) - self.bits = tmp_bits - if tmp_bits is not None and tmp_bits < 16: - for ( - supported_dtype - ) in SUPPORTED_DTYPES: # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}} - if self.data_type.startswith(supported_dtype): - if supported_dtype + str(tmp_bits) == self.data_type: # could not replace FP8_e4m3 - self.data_type = supported_dtype - break + config[bits_key] = inferred_bits - self.act_group_size = self.act_group_size if self.act_group_size is not None else self.group_size - self.act_bits = self.act_bits if self.act_bits is not None else 16 - self.act_sym = self.act_sym if self.act_sym is not None else self.sym + # Normalize data_type (e.g., 'mx_fp4' -> 'mx') + for supported in SUPPORTED_DTYPES: + if config[dt_key] == f"{supported}{inferred_bits}": + config[dt_key] = supported + break - if self.act_data_type is None: - if self.data_type in SUPPORTED_DTYPES and self.act_bits < 16: - self.act_data_type = self.data_type - logger.info(f"activation adopts {self.data_type}") - else: - self.act_data_type = "float" - tmp_act_bits = infer_bits_by_data_type(self.act_data_type) - if tmp_act_bits is not None and tmp_act_bits < 16 and tmp_act_bits != self.act_bits: - self.act_bits = tmp_act_bits - logger.warning( - f"`act_data_type` do not" - f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}." + def _override_scheme_with_user_specify( + self, scheme: Union[str, dict, QuantizationScheme], user_scheme_overrides: dict[str, Any], return_str=True + ) -> Union[str, QuantizationScheme]: + """ + Updates a base quantization scheme with user-provided overrides. + Handles GGUF formatting and synchronizes weight/activation parameters. + """ + # 1. GGUF special handling: map data_type suffix to GGUF scheme names + dt_override = user_scheme_overrides.get("data_type", "") + if ( + isinstance(scheme, QuantizationScheme) or (isinstance(scheme, str) and not scheme.startswith("gguf")) + ) and dt_override.endswith("_dq"): + if "bits" not in user_scheme_overrides: + raise KeyError(f"Must specify 'bits' when using data_type={dt_override}") + + bits = user_scheme_overrides["bits"] + suffix = "k" if bits == 6 else "k_s" + scheme = f"gguf:q{bits}_{suffix}" + + # 2. Convert input scheme to a dictionary for processing + if isinstance(scheme, QuantizationScheme): + scheme_dict = asdict(scheme) + elif isinstance(scheme, str): + normalized_name = scheme.strip("'\" ").upper() + if normalized_name.startswith("GGUF") and len(user_scheme_overrides) > 0: + logger.warning_once( + "When using GGUF scheme, user-specified overrides will be ignored to ensure format compatibility." 
) - if tmp_act_bits is not None and tmp_act_bits < 16: - for ( - supported_dtype - ) in SUPPORTED_DTYPES: # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}} - if self.act_data_type.startswith(supported_dtype): - if supported_dtype + str(tmp_act_bits) == self.act_data_type: # Could not replace FP8_e4m3 - self.act_data_type = supported_dtype - break - for key in scheme_keys: - scheme[key] = getattr(self, key) - if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res): - return res + user_scheme_overrides = {} + # If no overrides exist, return the normalized string immediately + if not user_scheme_overrides and return_str: + return normalized_name + scheme_dict = asdict(preset_name_to_scheme(normalized_name)) + else: + scheme_dict = scheme.copy() + + # 3. Apply overrides and define default behaviors + scheme_dict.update(user_scheme_overrides) + + if scheme_dict.get("act_dynamic") is None: + scheme_dict["act_dynamic"] = True + + # 4. Reconcile weight settings (bits vs data_type) + self._reconcile_bits_and_dtype(scheme_dict) + + # 5. Fallback logic: Inherit activation settings from weight settings + scheme_dict["act_group_size"] = ( + scheme_dict.get("act_group_size") + if scheme_dict.get("act_group_size") is not None + else scheme_dict.get("group_size") + ) + scheme_dict["act_bits"] = scheme_dict.get("act_bits") or 16 + scheme_dict["act_sym"] = ( + scheme_dict.get("act_sym") if scheme_dict.get("act_sym") is not None else scheme_dict.get("sym") + ) + + # 6. Activation data_type logic + if scheme_dict.get("act_data_type") is None: + is_supported = scheme_dict["data_type"] in SUPPORTED_DTYPES + if is_supported and scheme_dict["act_bits"] < 16: + scheme_dict["act_data_type"] = scheme_dict["data_type"] + logger.info(f"Activation adopting weight data_type: {scheme_dict['data_type']}") + else: + scheme_dict["act_data_type"] = "float" + + # 7. Reconcile activation settings + self._reconcile_bits_and_dtype(scheme_dict, prefix="act_") + + return QuantizationScheme.from_dict(scheme_dict) + + def _parse_and_set_scheme( + self, scheme: Union[str, dict, QuantizationScheme, AutoScheme], user_scheme_overrides: dict[str, Any] + ) -> tuple[Union[str, QuantizationScheme], bool]: + """ + Parses the final scheme and binds all resulting attributes to 'self'. 
+ """ + + is_auto_scheme = isinstance(scheme, AutoScheme) + + if is_auto_scheme: + if not scheme.options: + raise ValueError("AutoScheme options cannot be empty") else: - return QuantizationScheme.from_dict(scheme) - - if isinstance(scheme, AutoScheme): - if len(scheme.options) <= 0: - raise ValueError("options of AutoScheme must not be empty") - options = [] - for option in scheme.options: - new_option = _parse_and_set(option, kwargs) - options.append(new_option) - scheme.options = options - for opt in options: - if isinstance(opt, str) and opt == "BF16": + for option in scheme.options: + if isinstance(option, str): + if "mixed" in option: + raise ValueError(f"Mixed option {option} is not supported") + + # Map user overrides across all auto-scheme options + scheme.options = [ + self._override_scheme_with_user_specify(opt, user_scheme_overrides) for opt in scheme.options + ] + + # Select the primary scheme for attribute binding (skipping BF16) + default_scheme = scheme.options[0] + for opt in scheme.options: + if opt == "BF16": continue if isinstance(opt, QuantizationScheme): - if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16): - continue - self.scheme = opt # Choose the first one that not 16 bits - break - # apply scheme to set default bits - scheme = _parse_and_set(self.scheme, kwargs) - is_auto_scheme = True + if opt.bits < 16 or (opt.act_bits and opt.act_bits < 16): + default_scheme = opt + break else: - scheme = _parse_and_set(scheme, kwargs) - is_auto_scheme = False + default_scheme = self._override_scheme_with_user_specify(scheme, user_scheme_overrides) - scheme_keys = [f.name for f in fields(QuantizationScheme)] - for key in scheme_keys: - kwargs.pop(key, None) + # Extract attributes from the chosen default_scheme + if isinstance(default_scheme, str): + final_attrs = self._override_scheme_with_user_specify( + default_scheme, user_scheme_overrides, return_str=False + ) + else: + final_attrs = asdict(default_scheme) + + # Bind attributes to self for easy instance-level access + for key, value in final_attrs.items(): + setattr(self, key, value) - return scheme, is_auto_scheme + return default_scheme, is_auto_scheme def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: """Sets the torch compile configuration for the tuning.""" @@ -1528,193 +1569,44 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) def _estimate_tensor_size_gb(self, tensor) -> float: """Estimate the size of a tensor in GB.""" - if tensor is None: - return 0.0 - if isinstance(tensor, torch.Tensor): - return tensor.numel() * tensor.element_size() / (1024**3) - elif isinstance(tensor, list): - return sum(self._estimate_tensor_size_gb(t) for t in tensor) - elif isinstance(tensor, dict): - return sum(self._estimate_tensor_size_gb(v) for v in tensor.values()) - return 0.0 + return estimate_tensor_size_gb(tensor) def _estimate_inputs_size_gb(self, all_inputs: dict) -> float: """Estimate the total size of calibration inputs in GB.""" - total = 0.0 - for name, inputs in all_inputs.items(): - total += self._estimate_tensor_size_gb(inputs) - return total + return estimate_inputs_size_gb(all_inputs) def _estimate_model_size_gb(self) -> float: """Estimate the model weights size in GB.""" - total = 0.0 - for param in self.model.parameters(): - if param.numel() > 0: # Skip empty tensors - total += param.numel() * param.element_size() / (1024**3) - return total + return estimate_model_size_gb(self.model) def _estimate_block_size_gb(self, block: torch.nn.Module) -> float: 
"""Estimate a block's weights size in GB.""" - total = 0.0 - for param in block.parameters(): - if param.numel() > 0: - total += param.numel() * param.element_size() / (1024**3) - return total + return estimate_block_size_gb(block) def _init_cpu_offload_dir(self) -> Optional[str]: - if not self.low_cpu_mem_usage: - return None - if self._cpu_offload_tempdir is None: - self._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") - return self._cpu_offload_tempdir + return init_cpu_offload_dir(self) def _offload_block_weights(self, block_name: str, block: torch.nn.Module) -> None: - if not self.low_cpu_mem_usage: - return - offload_dir = self._init_cpu_offload_dir() - if offload_dir is None: - return - safe_name = block_name.replace(".", "_") - save_path = os.path.join(offload_dir, f"{safe_name}.pt") - metadata = save_module_weights(block, save_path) - self._offloaded_blocks[block_name] = metadata - clear_module_weights(block) + offload_block_weights(self, block_name, block) def _stream_offload_blocks(self, all_blocks: list[list[str]]) -> None: """Offload all block weights to disk and clear from memory.""" - if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): - return - offload_dir = self._init_cpu_offload_dir() - if offload_dir is None: - return - logger.info("stream offloading block weights to disk...") - total_offloaded_gb = 0.0 - for block_names in all_blocks: - for block_name in block_names: - if block_name in self._offloaded_blocks: - continue - block = get_module(self.model, block_name) - if block is None: - continue - block_size_gb = self._estimate_block_size_gb(block) - total_offloaded_gb += block_size_gb - safe_name = block_name.replace(".", "_") - save_path = os.path.join(offload_dir, f"{safe_name}.pt") - # Save entire block state_dict (all submodule weights) - state_dict = {k: v.cpu() for k, v in block.state_dict().items()} - torch.save(state_dict, save_path) - self._offloaded_blocks[block_name] = {"save_path": save_path} - # Clear all submodule weights - for submodule in block.modules(): - if hasattr(submodule, "weight") and submodule.weight is not None: - clear_module_weights(submodule) - if hasattr(submodule, "bias") and submodule.bias is not None: - clear_module_weights(submodule) - clear_memory(device_list=self.device_list) - logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") + stream_offload_blocks(self, all_blocks) def _load_offloaded_block_weights(self, block_name: str, block: torch.nn.Module) -> None: """Load block weights from disk back into memory.""" - if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): - return - metadata = self._offloaded_blocks.get(block_name) - if not metadata: - return - save_path = metadata.get("save_path") - if not save_path or not os.path.exists(save_path): - logger.warning(f"Cannot load block weights: file {save_path} does not exist") - return - try: - state_dict = torch.load(save_path, map_location="cpu") - # Manually assign parameters to handle empty tensor replacement - for name, param in state_dict.items(): - parts = name.split(".") - target = block - for part in parts[:-1]: - target = getattr(target, part) - param_name = parts[-1] - if hasattr(target, param_name): - old_param = getattr(target, param_name) - if isinstance(old_param, torch.nn.Parameter): - setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) - else: - setattr(target, param_name, param) - except Exception as e: - logger.warning(f"Failed to load 
block weights from {save_path}: {e}") + load_offloaded_block_weights(self, block_name, block) def _discard_offloaded_block(self, block_name: str) -> None: """Discard the original offload file and re-offload quantized weights.""" - if not (self.low_cpu_mem_usage and self.cpu_stream_offload_blocks): - return - metadata = self._offloaded_blocks.pop(block_name, None) - if not metadata: - return - save_path = metadata.get("save_path") - if save_path and os.path.exists(save_path): - try: - os.remove(save_path) - except Exception as e: - logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") - - # Re-offload the quantized block weights to disk - block = get_module(self.model, block_name) - if block is None: - return - offload_dir = self._init_cpu_offload_dir() - if offload_dir is None: - return - safe_name = block_name.replace(".", "_") - new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") - try: - state_dict = {k: v.cpu() for k, v in block.state_dict().items()} - torch.save(state_dict, new_save_path) - self._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} - # Clear all submodule weights - for submodule in block.modules(): - if hasattr(submodule, "weight") and submodule.weight is not None: - clear_module_weights(submodule) - if hasattr(submodule, "bias") and submodule.bias is not None: - clear_module_weights(submodule) - except Exception as e: - logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") + discard_offloaded_block(self, block_name) def _restore_offloaded_blocks(self) -> None: """Restore all offloaded block weights back to memory.""" - if not self._offloaded_blocks: - return - for block_name, metadata in list(self._offloaded_blocks.items()): - try: - block = get_module(self.model, block_name) - save_path = metadata.get("save_path") - if not save_path or not os.path.exists(save_path): - logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") - continue - state_dict = torch.load(save_path, map_location="cpu") - # Manually assign parameters to handle empty tensor replacement - for name, param in state_dict.items(): - parts = name.split(".") - target = block - for part in parts[:-1]: - target = getattr(target, part) - param_name = parts[-1] - if hasattr(target, param_name): - old_param = getattr(target, param_name) - if isinstance(old_param, torch.nn.Parameter): - setattr( - target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad) - ) - else: - setattr(target, param_name, param) - except Exception as e: - logger.warning(f"Failed to restore offloaded block {block_name}: {e}") + restore_offloaded_blocks(self) def _cleanup_cpu_offload_dir(self) -> None: - if self._cpu_offload_tempdir and os.path.isdir(self._cpu_offload_tempdir): - try: - shutil.rmtree(self._cpu_offload_tempdir) - except Exception as e: - logger.warning(f"Failed to cleanup cpu offload dir {self._cpu_offload_tempdir}: {e}") - self._cpu_offload_tempdir = None + cleanup_cpu_offload_dir(self) def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]: keys = inputs.keys() @@ -1743,6 +1635,16 @@ def configure_layer_config(self, enable_gguf_official_mixed: None | bool = True) if self.is_auto_scheme: self.layer_config = self._gen_auto_scheme() + else: + self.layer_config = _handle_special_schemes( + self.orig_scheme, + self.layer_config, + self.model, + supported_types=self.supported_types, + inner_supported_types=self.inner_supported_types, + 
quant_lm_head=self.quant_lm_head, + mllm=getattr(self, "mllm", False), + ) fill_default_value = True if self.is_auto_scheme: @@ -2652,9 +2554,8 @@ def get_act_max_hook(module, input, output): else: act_max = act_max.to(module.act_max.device) if is_nv_fp(self.act_data_type): ## for nvfp per-tensor input_global_scale calculation usage - module.act_max = torch.max( - torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device) - ) + max_val = torch.max(act_max.max(), module.act_max.max()) + module.act_max = max_val.unsqueeze(0) if max_val.dim() == 0 else max_val else: module.act_max = torch.max(act_max, module.act_max) diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index 186eea053..b7382a317 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -73,7 +73,8 @@ def pack_layer(name, model, backend, device=None): sym = layer.sym if is_nv_fp(act_data_type) and act_bits <= 8: - if not getattr(layer, "input_global_scale", None): + input_global_scale = getattr(layer, "input_global_scale", None) + if input_global_scale is None: assert hasattr(layer, "act_max") from auto_round.data_type.nvfp import calculate_gparam diff --git a/auto_round/export/export_to_llmcompressor/export_to_fp.py b/auto_round/export/export_to_llmcompressor/export_to_fp.py index e7835309f..05c01ee3a 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_fp.py @@ -72,7 +72,8 @@ def pack_layer(name, model, device=None): sym = layer.sym if is_nv_fp(act_data_type) and act_bits <= 8: - if not getattr(layer, "input_global_scale", None): + input_global_scale = getattr(layer, "input_global_scale", None) + if input_global_scale is None: assert hasattr(layer, "act_max") from auto_round.data_type.nvfp import calculate_gparam diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 54824270e..e9026450d 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -176,6 +176,7 @@ def is_preset_scheme(name: str) -> bool: "act_data_type": "mx_fp", "act_group_size": 32, "act_sym": True, + "act_dynamic": True, } ) @@ -188,6 +189,7 @@ def is_preset_scheme(name: str) -> bool: "act_data_type": "mx_fp_rceil", "act_group_size": 32, "act_sym": True, + "act_dynamic": True, } ) @@ -201,6 +203,7 @@ def is_preset_scheme(name: str) -> bool: "act_data_type": "mx_fp", "act_group_size": 32, "act_sym": True, + "act_dynamic": True, } ) @@ -213,6 +216,7 @@ def is_preset_scheme(name: str) -> bool: "act_data_type": "mx_fp_rceil", "act_group_size": 32, "act_sym": True, + "act_dynamic": True, } ) @@ -226,6 +230,7 @@ def is_preset_scheme(name: str) -> bool: "act_data_type": "nv_fp4_with_static_gs", "act_group_size": 16, "act_sym": True, + "act_dynamic": True, } ) @@ -312,6 +317,8 @@ def _handle_special_schemes( Provide some special auto_round recipes. 
""" + if not isinstance(scheme_name, str): + return layer_config if layer_config is None: layer_config = {} if scheme_name.lower() == "gguf:q2_k_mixed": @@ -340,13 +347,13 @@ def _handle_special_schemes( continue if type(m) in supported_types or type(m) in inner_supported_types: if "expert" in n and "shared" not in n: - layer_config[n] = {"bits": 4} + layer_config[n] = {"bits": 4, "data_type": "int"} elif n != lm_head_name and mllm: layer_config[n] = {"bits": 16} elif n != lm_head_name: - layer_config[n] = {"bits": 8} + layer_config[n] = {"bits": 8, "data_type": "int"} elif n == lm_head_name and quant_lm_head: - layer_config[n] = {"bits": 8} + layer_config[n] = {"bits": 8, "data_type": "int"} return layer_config diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index a16f441bf..713023c38 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -15,18 +15,27 @@ import gc import os import re +import shutil +import tempfile from contextlib import ContextDecorator, contextmanager from functools import lru_cache from itertools import combinations from threading import Lock -from typing import Callable, Union +from typing import Any, Callable, Optional, Union import cpuinfo import psutil import torch from auto_round.logger import logger -from auto_round.utils.model import check_to_quantized, get_block_names, get_layer_features, get_module +from auto_round.utils.model import ( + check_to_quantized, + clear_module_weights, + get_block_names, + get_layer_features, + get_module, + save_module_weights, +) # Note on HPU usage: # There are two modes available for enabling auto-round on HPU: @@ -501,6 +510,205 @@ def __call__( clear_memory = torch._dynamo.disable()(ClearMemory(device_list=[0])) +def estimate_tensor_size_gb(tensor: Any) -> float: + """Estimate the size of a tensor (or nested tensors) in GB.""" + if tensor is None: + return 0.0 + if isinstance(tensor, torch.Tensor): + return tensor.numel() * tensor.element_size() / (1024**3) + if isinstance(tensor, list): + return sum(estimate_tensor_size_gb(t) for t in tensor) + if isinstance(tensor, dict): + return sum(estimate_tensor_size_gb(v) for v in tensor.values()) + return 0.0 + + +def estimate_inputs_size_gb(all_inputs: dict) -> float: + """Estimate the total size of calibration inputs in GB.""" + total = 0.0 + for _, inputs in all_inputs.items(): + total += estimate_tensor_size_gb(inputs) + return total + + +def estimate_model_size_gb(model: torch.nn.Module) -> float: + """Estimate the model weights size in GB.""" + total = 0.0 + for param in model.parameters(): + if param.numel() > 0: # Skip empty tensors + total += param.numel() * param.element_size() / (1024**3) + return total + + +def estimate_block_size_gb(block: torch.nn.Module) -> float: + """Estimate a block's weights size in GB.""" + total = 0.0 + for param in block.parameters(): + if param.numel() > 0: + total += param.numel() * param.element_size() / (1024**3) + return total + + +def init_cpu_offload_dir(compressor: Any) -> Optional[str]: + if not compressor.low_cpu_mem_usage: + return None + if compressor._cpu_offload_tempdir is None: + compressor._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") + return compressor._cpu_offload_tempdir + + +def offload_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + if not compressor.low_cpu_mem_usage: + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + save_path = 
os.path.join(offload_dir, f"{safe_name}.pt") + metadata = save_module_weights(block, save_path) + compressor._offloaded_blocks[block_name] = metadata + clear_module_weights(block) + + +def stream_offload_blocks(compressor: Any, all_blocks: list[list[str]]) -> None: + """Offload all block weights to disk and clear from memory.""" + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + logger.info("stream offloading block weights to disk...") + total_offloaded_gb = 0.0 + for block_names in all_blocks: + for block_name in block_names: + if block_name in compressor._offloaded_blocks: + continue + block = get_module(compressor.model, block_name) + if block is None: + continue + block_size_gb = estimate_block_size_gb(block) + total_offloaded_gb += block_size_gb + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + # Save entire block state_dict (all submodule weights) + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, save_path) + compressor._offloaded_blocks[block_name] = {"save_path": save_path} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + clear_memory(device_list=compressor.device_list) + logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") + + +def load_offloaded_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + """Load block weights from disk back into memory.""" + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + +def discard_offloaded_block(compressor: Any, block_name: str) -> None: + """Discard the original offload file and re-offload quantized weights.""" + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.pop(block_name, None) + if not metadata: + return + save_path = metadata.get("save_path") + if save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except Exception as e: + logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") + + # Re-offload the quantized block weights to disk + block = get_module(compressor.model, block_name) + if block is None: + return + offload_dir = init_cpu_offload_dir(compressor) + if 
offload_dir is None: + return + safe_name = block_name.replace(".", "_") + new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") + try: + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, new_save_path) + compressor._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") + + +def restore_offloaded_blocks(compressor: Any) -> None: + """Restore all offloaded block weights back to memory.""" + if not compressor._offloaded_blocks: + return + for block_name, metadata in list(compressor._offloaded_blocks.items()): + try: + block = get_module(compressor.model, block_name) + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") + continue + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to restore offloaded block {block_name}: {e}") + + +def cleanup_cpu_offload_dir(compressor: Any) -> None: + if compressor._cpu_offload_tempdir and os.path.isdir(compressor._cpu_offload_tempdir): + try: + shutil.rmtree(compressor._cpu_offload_tempdir) + except Exception as e: + logger.warning(f"Failed to cleanup cpu offload dir {compressor._cpu_offload_tempdir}: {e}") + compressor._cpu_offload_tempdir = None + + def clear_memory_if_reached_threshold(threshold=0.85, device_list=None): """Check all available devices and clear memory if any device is using close to the threshold. diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 34190593e..487f621f8 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -1347,11 +1347,11 @@ def _get_amax_value(module): f"This may indicate calibration hooks were not attached to expert layers." 
) return uncalibrated_experts - else: - # Flatten all tensors to 1D before concatenation - flat_values = [t.reshape(-1) for t in amax_values] - all_values = torch.cat(flat_values) - set_amax_value = torch.max(all_values) + # Flatten all tensors to 1D before concatenation + flat_values = [t.reshape(-1) for t in amax_values] + all_values = torch.cat(flat_values) + set_amax_value = torch.max(all_values) + set_amax_value = set_amax_value.unsqueeze(0) if set_amax_value.dim() == 0 else set_amax_value for module in experts: current_amax = _get_amax_value(module) diff --git a/test/test_cpu/core/test_low_cpu_mem_options.py b/test/test_cpu/core/test_low_cpu_mem_options.py index 26eacfdc6..64cae63a1 100644 --- a/test/test_cpu/core/test_low_cpu_mem_options.py +++ b/test/test_cpu/core/test_low_cpu_mem_options.py @@ -22,6 +22,7 @@ from auto_round import AutoRound from auto_round.compressors import base as base_module +from auto_round.utils import device as device_module class TestCpuStreamOffloadBlocks: @@ -112,10 +113,10 @@ def test_stream_offload_blocks_records_blocks(self, tiny_opt_model_path, tmp_pat ) dummy_block = torch.nn.Linear(4, 4) - monkeypatch.setattr(base_module, "get_module", lambda _model, _name: dummy_block) - monkeypatch.setattr(autoround, "_init_cpu_offload_dir", lambda: str(tmp_path)) + monkeypatch.setattr(device_module, "get_module", lambda _model, _name: dummy_block) + monkeypatch.setattr(device_module, "init_cpu_offload_dir", lambda _compressor: str(tmp_path)) monkeypatch.setattr(torch, "save", lambda *args, **kwargs: None) - monkeypatch.setattr(base_module, "clear_module_weights", lambda *_args, **_kwargs: None) + monkeypatch.setattr(device_module, "clear_module_weights", lambda *_args, **_kwargs: None) autoround._stream_offload_blocks([["model.layers.0"]]) assert "model.layers.0" in autoround._offloaded_blocks diff --git a/test/test_cpu/schemes/test_scheme.py b/test/test_cpu/schemes/test_scheme.py index e2b0c15c3..dcafb3ca8 100644 --- a/test/test_cpu/schemes/test_scheme.py +++ b/test/test_cpu/schemes/test_scheme.py @@ -159,3 +159,30 @@ def test_parse_available_devices(self): assert device_list == ["cuda:0", "cuda:1", "cpu"] device_list = parse_available_devices("0,1") assert len(device_list) == 1 and "cpu" in device_list + + def test_set_scheme(self, tiny_qwen_model_path): + ar = AutoRound( + tiny_qwen_model_path, + scheme="gguf:q2_k_s", + data_type="fp", + nsamples=1, + disable_opt_rtn=True, + iters=0, + seqlen=2, + ) + ar.quantize() + + from auto_round.schemes import QuantizationScheme + + qs = QuantizationScheme.from_dict({"bits": 4, "group_size": 64}) + ar = AutoRound( + tiny_qwen_model_path, + scheme=qs, + bits=2, + data_type="int_asym_dq", + nsamples=1, + iters=0, + disable_opt_rtn=True, + seqlen=2, + ) + ar.quantize() From a041da868f1d61854e35b20ba16bd545a8b5a9b2 Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Wed, 4 Feb 2026 03:27:58 +0000 Subject: [PATCH 11/13] update code for comments Signed-off-by: lvliang-intel --- auto_round/compressors/base.py | 87 ++----- auto_round/utils/device.py | 173 ------------- auto_round/utils/model.py | 227 +++++++++++++++++- .../test_cpu/core/test_low_cpu_mem_options.py | 18 +- 4 files changed, 260 insertions(+), 245 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 93b892163..5506f28d9 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -104,21 +104,23 @@ unsupported_meta_device, ) from auto_round.utils.device import ( - cleanup_cpu_offload_dir, 
clear_memory_if_reached_threshold, - discard_offloaded_block, - estimate_block_size_gb, estimate_inputs_size_gb, estimate_model_size_gb, estimate_tensor_size_gb, get_major_device, + parse_available_devices, + set_auto_device_map_for_block_with_tuning, + set_non_auto_device_map, +) +from auto_round.utils.model import ( + cleanup_cpu_offload_dir, + discard_offloaded_block, + estimate_block_size_gb, init_cpu_offload_dir, load_offloaded_block_weights, offload_block_weights, - parse_available_devices, restore_offloaded_blocks, - set_auto_device_map_for_block_with_tuning, - set_non_auto_device_map, stream_offload_blocks, ) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block @@ -1548,7 +1550,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) # some modules may have been flushed and set to meta, so we could not move to gpu mv_module_from_gpu(block) if self.low_cpu_mem_usage: - self._offload_block_weights(block_name, block) + offload_block_weights(self, block_name, block) if block_name == block_names[-1]: clear_memory(input_ids, device_list=self.device_list) else: @@ -1567,47 +1569,6 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) # if self.is_immediate_saving: # shard_writer(self, is_finalize=True) - def _estimate_tensor_size_gb(self, tensor) -> float: - """Estimate the size of a tensor in GB.""" - return estimate_tensor_size_gb(tensor) - - def _estimate_inputs_size_gb(self, all_inputs: dict) -> float: - """Estimate the total size of calibration inputs in GB.""" - return estimate_inputs_size_gb(all_inputs) - - def _estimate_model_size_gb(self) -> float: - """Estimate the model weights size in GB.""" - return estimate_model_size_gb(self.model) - - def _estimate_block_size_gb(self, block: torch.nn.Module) -> float: - """Estimate a block's weights size in GB.""" - return estimate_block_size_gb(block) - - def _init_cpu_offload_dir(self) -> Optional[str]: - return init_cpu_offload_dir(self) - - def _offload_block_weights(self, block_name: str, block: torch.nn.Module) -> None: - offload_block_weights(self, block_name, block) - - def _stream_offload_blocks(self, all_blocks: list[list[str]]) -> None: - """Offload all block weights to disk and clear from memory.""" - stream_offload_blocks(self, all_blocks) - - def _load_offloaded_block_weights(self, block_name: str, block: torch.nn.Module) -> None: - """Load block weights from disk back into memory.""" - load_offloaded_block_weights(self, block_name, block) - - def _discard_offloaded_block(self, block_name: str) -> None: - """Discard the original offload file and re-offload quantized weights.""" - discard_offloaded_block(self, block_name) - - def _restore_offloaded_blocks(self) -> None: - """Restore all offloaded block weights back to memory.""" - restore_offloaded_blocks(self) - - def _cleanup_cpu_offload_dir(self) -> None: - cleanup_cpu_offload_dir(self) - def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]: keys = inputs.keys() input_id_str = [key for key in keys if key.startswith("hidden_state")] @@ -1799,7 +1760,7 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: clear_memory(device_list=self.device_list) # Log memory breakdown for calibration inputs if self.low_cpu_mem_usage: - inputs_size_gb = self._estimate_inputs_size_gb(all_inputs) + inputs_size_gb = estimate_inputs_size_gb(all_inputs) logger.info(f"[Memory] calibration inputs size: {inputs_size_gb:.2f} GB") 
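For orientation, a minimal sketch of what the [Memory] log above measures, assuming the estimate helpers remain importable from auto_round.utils.device as elsewhere in this series; the tensor shapes below are invented purely for illustration:

import torch
from auto_round.utils.device import estimate_inputs_size_gb, estimate_tensor_size_gb

# Invented stand-in for the cached calibration inputs: a few bf16 hidden-state
# batches plus a boolean attention mask.
fake_inputs = {
    "hidden_states": [torch.zeros(8, 512, 4096, dtype=torch.bfloat16) for _ in range(4)],
    "attention_mask": {"mask": torch.ones(8, 512, dtype=torch.bool)},
}

# Both helpers recurse through lists/dicts and sum numel * element_size / 1024**3.
print(f"hidden states: {estimate_tensor_size_gb(fake_inputs['hidden_states']):.2f} GB")  # ~0.12 GB
print(f"all inputs:    {estimate_inputs_size_gb(fake_inputs):.2f} GB")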
all_q_inputs = None if is_quantized_embedding: @@ -1815,10 +1776,10 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: logger.info("caching done") # Log memory breakdown for model weights if self.low_cpu_mem_usage: - model_size_gb = self._estimate_model_size_gb() + model_size_gb = estimate_model_size_gb(self.model) logger.info(f"[Memory] model weights size: {model_size_gb:.2f} GB") if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: - self._stream_offload_blocks(all_blocks) + stream_offload_blocks(self, all_blocks) if len(all_blocks) > 1: pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) else: @@ -1865,8 +1826,8 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: shard_writer(self, is_finalize=True) if self.low_cpu_mem_usage: - self._restore_offloaded_blocks() - self._cleanup_cpu_offload_dir() + restore_offloaded_blocks(self) + cleanup_cpu_offload_dir(self) end_time = time.time() cost_time = end_time - start_time @@ -2966,7 +2927,7 @@ def _quantize_block( and self.cpu_stream_offload_blocks and not hasattr(self, "_logged_output_size") ): - output_size = self._estimate_tensor_size_gb(output) + output_size = estimate_tensor_size_gb(output) logger.info(f"[Memory] block output cache size: {output_size:.2f} GB") self._logged_output_size = True @@ -3269,8 +3230,8 @@ def _quantize_blocks( # Log detailed memory breakdown for first block if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: - input_ids_size = self._estimate_tensor_size_gb(input_ids) - input_others_size = self._estimate_tensor_size_gb(input_others) + input_ids_size = estimate_tensor_size_gb(input_ids) + input_others_size = estimate_tensor_size_gb(input_others) logger.info( f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB" ) @@ -3293,13 +3254,13 @@ def _quantize_blocks( if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: if nblocks == 1: - self._load_offloaded_block_weights(n, get_module(model, n)) + load_offloaded_block_weights(self, n, get_module(model, n)) if i == 0: # Log only for first block - block_size = self._estimate_block_size_gb(get_module(model, n)) + block_size = estimate_block_size_gb(get_module(model, n)) logger.info(f"[Memory] loaded block weights size: {block_size:.2f} GB") else: for name in names: - self._load_offloaded_block_weights(name, get_module(model, name)) + load_offloaded_block_weights(self, name, get_module(model, name)) m.config = model.config if hasattr(model, "config") else None @@ -3313,16 +3274,16 @@ def _quantize_blocks( if self.low_cpu_mem_usage and not self.cpu_stream_offload_blocks: if nblocks == 1: - self._offload_block_weights(n, get_module(model, n)) + offload_block_weights(self, n, get_module(model, n)) else: for name in names: - self._offload_block_weights(name, get_module(model, name)) + offload_block_weights(self, name, get_module(model, name)) if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks: if nblocks == 1: - self._discard_offloaded_block(n) + discard_offloaded_block(self, n) else: for name in names: - self._discard_offloaded_block(name) + discard_offloaded_block(self, name) if hasattr(model, "config"): del m.config if self.is_immediate_packing: diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 713023c38..2b186529b 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -15,8 +15,6 @@ import gc import os import re -import shutil -import tempfile from contextlib import ContextDecorator, contextmanager 
from functools import lru_cache from itertools import combinations @@ -30,11 +28,9 @@ from auto_round.logger import logger from auto_round.utils.model import ( check_to_quantized, - clear_module_weights, get_block_names, get_layer_features, get_module, - save_module_weights, ) # Note on HPU usage: @@ -540,175 +536,6 @@ def estimate_model_size_gb(model: torch.nn.Module) -> float: return total -def estimate_block_size_gb(block: torch.nn.Module) -> float: - """Estimate a block's weights size in GB.""" - total = 0.0 - for param in block.parameters(): - if param.numel() > 0: - total += param.numel() * param.element_size() / (1024**3) - return total - - -def init_cpu_offload_dir(compressor: Any) -> Optional[str]: - if not compressor.low_cpu_mem_usage: - return None - if compressor._cpu_offload_tempdir is None: - compressor._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") - return compressor._cpu_offload_tempdir - - -def offload_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: - if not compressor.low_cpu_mem_usage: - return - offload_dir = init_cpu_offload_dir(compressor) - if offload_dir is None: - return - safe_name = block_name.replace(".", "_") - save_path = os.path.join(offload_dir, f"{safe_name}.pt") - metadata = save_module_weights(block, save_path) - compressor._offloaded_blocks[block_name] = metadata - clear_module_weights(block) - - -def stream_offload_blocks(compressor: Any, all_blocks: list[list[str]]) -> None: - """Offload all block weights to disk and clear from memory.""" - if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): - return - offload_dir = init_cpu_offload_dir(compressor) - if offload_dir is None: - return - logger.info("stream offloading block weights to disk...") - total_offloaded_gb = 0.0 - for block_names in all_blocks: - for block_name in block_names: - if block_name in compressor._offloaded_blocks: - continue - block = get_module(compressor.model, block_name) - if block is None: - continue - block_size_gb = estimate_block_size_gb(block) - total_offloaded_gb += block_size_gb - safe_name = block_name.replace(".", "_") - save_path = os.path.join(offload_dir, f"{safe_name}.pt") - # Save entire block state_dict (all submodule weights) - state_dict = {k: v.cpu() for k, v in block.state_dict().items()} - torch.save(state_dict, save_path) - compressor._offloaded_blocks[block_name] = {"save_path": save_path} - # Clear all submodule weights - for submodule in block.modules(): - if hasattr(submodule, "weight") and submodule.weight is not None: - clear_module_weights(submodule) - if hasattr(submodule, "bias") and submodule.bias is not None: - clear_module_weights(submodule) - clear_memory(device_list=compressor.device_list) - logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") - - -def load_offloaded_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: - """Load block weights from disk back into memory.""" - if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): - return - metadata = compressor._offloaded_blocks.get(block_name) - if not metadata: - return - save_path = metadata.get("save_path") - if not save_path or not os.path.exists(save_path): - logger.warning(f"Cannot load block weights: file {save_path} does not exist") - return - try: - state_dict = torch.load(save_path, map_location="cpu") - # Manually assign parameters to handle empty tensor replacement - for name, param in state_dict.items(): - 
parts = name.split(".") - target = block - for part in parts[:-1]: - target = getattr(target, part) - param_name = parts[-1] - if hasattr(target, param_name): - old_param = getattr(target, param_name) - if isinstance(old_param, torch.nn.Parameter): - setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) - else: - setattr(target, param_name, param) - except Exception as e: - logger.warning(f"Failed to load block weights from {save_path}: {e}") - - -def discard_offloaded_block(compressor: Any, block_name: str) -> None: - """Discard the original offload file and re-offload quantized weights.""" - if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): - return - metadata = compressor._offloaded_blocks.pop(block_name, None) - if not metadata: - return - save_path = metadata.get("save_path") - if save_path and os.path.exists(save_path): - try: - os.remove(save_path) - except Exception as e: - logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") - - # Re-offload the quantized block weights to disk - block = get_module(compressor.model, block_name) - if block is None: - return - offload_dir = init_cpu_offload_dir(compressor) - if offload_dir is None: - return - safe_name = block_name.replace(".", "_") - new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") - try: - state_dict = {k: v.cpu() for k, v in block.state_dict().items()} - torch.save(state_dict, new_save_path) - compressor._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} - # Clear all submodule weights - for submodule in block.modules(): - if hasattr(submodule, "weight") and submodule.weight is not None: - clear_module_weights(submodule) - if hasattr(submodule, "bias") and submodule.bias is not None: - clear_module_weights(submodule) - except Exception as e: - logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") - - -def restore_offloaded_blocks(compressor: Any) -> None: - """Restore all offloaded block weights back to memory.""" - if not compressor._offloaded_blocks: - return - for block_name, metadata in list(compressor._offloaded_blocks.items()): - try: - block = get_module(compressor.model, block_name) - save_path = metadata.get("save_path") - if not save_path or not os.path.exists(save_path): - logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") - continue - state_dict = torch.load(save_path, map_location="cpu") - # Manually assign parameters to handle empty tensor replacement - for name, param in state_dict.items(): - parts = name.split(".") - target = block - for part in parts[:-1]: - target = getattr(target, part) - param_name = parts[-1] - if hasattr(target, param_name): - old_param = getattr(target, param_name) - if isinstance(old_param, torch.nn.Parameter): - setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) - else: - setattr(target, param_name, param) - except Exception as e: - logger.warning(f"Failed to restore offloaded block {block_name}: {e}") - - -def cleanup_cpu_offload_dir(compressor: Any) -> None: - if compressor._cpu_offload_tempdir and os.path.isdir(compressor._cpu_offload_tempdir): - try: - shutil.rmtree(compressor._cpu_offload_tempdir) - except Exception as e: - logger.warning(f"Failed to cleanup cpu offload dir {compressor._cpu_offload_tempdir}: {e}") - compressor._cpu_offload_tempdir = None - - def clear_memory_if_reached_threshold(threshold=0.85, device_list=None): """Check all 
available devices and clear memory if any device is using close to the threshold. diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 487f621f8..e3d638b1d 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -15,9 +15,11 @@ import json import os import re +import shutil +import tempfile from collections import UserDict from pathlib import Path -from typing import Union +from typing import Any, Optional, Union import psutil import torch @@ -1715,3 +1717,226 @@ def is_separate_tensor(model: torch.nn.Module, tensor_name: str) -> bool: return True else: return False + + +# ===================== CPU Offload Utilities ===================== + + +def init_cpu_offload_dir(compressor: Any) -> Optional[str]: + """Initialize a temporary directory for CPU offloading. + + Args: + compressor: The compressor object containing low_cpu_mem_usage flag + and _cpu_offload_tempdir attribute. + + Returns: + Optional[str]: Path to the temporary directory, or None if not enabled. + """ + if not compressor.low_cpu_mem_usage: + return None + if compressor._cpu_offload_tempdir is None: + compressor._cpu_offload_tempdir = tempfile.mkdtemp(prefix="autoround_cpu_offload_") + return compressor._cpu_offload_tempdir + + +def offload_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + """Offload a block's weights to disk to reduce CPU RAM usage. + + Args: + compressor: The compressor object containing offload state. + block_name: Name of the block being offloaded. + block: The block module whose weights should be offloaded. + """ + if not compressor.low_cpu_mem_usage: + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + metadata = save_module_weights(block, save_path) + compressor._offloaded_blocks[block_name] = metadata + clear_module_weights(block) + + +def stream_offload_blocks(compressor: Any, all_blocks: list[list[str]]) -> None: + """Offload all block weights to disk and clear from memory. + + Args: + compressor: The compressor object containing model and offload state. + all_blocks: List of block name lists to offload. 
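+
+    Note:
+        Each block's full state_dict is written to
+        ``<offload_dir>/<block name with "." replaced by "_">.pt`` and the
+        in-memory weights are then cleared via ``clear_module_weights``.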
+ """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + logger.info("stream offloading block weights to disk...") + total_offloaded_gb = 0.0 + for block_names in all_blocks: + for block_name in block_names: + if block_name in compressor._offloaded_blocks: + continue + block = get_module(compressor.model, block_name) + if block is None: + continue + block_size_gb = estimate_block_size_gb(block) + total_offloaded_gb += block_size_gb + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.pt") + # Save entire block state_dict (all submodule weights) + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, save_path) + compressor._offloaded_blocks[block_name] = {"save_path": save_path} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + # Import clear_memory here to avoid circular imports + from auto_round.utils.device import clear_memory + + clear_memory(device_list=compressor.device_list) + logger.info(f"stream offload done, offloaded {total_offloaded_gb:.2f} GB of block weights") + + +def load_offloaded_block_weights(compressor: Any, block_name: str, block: torch.nn.Module) -> None: + """Load block weights from disk back into memory. + + Args: + compressor: The compressor object containing offload state. + block_name: Name of the block to load. + block: The block module to restore weights to. + """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + +def discard_offloaded_block(compressor: Any, block_name: str) -> None: + """Discard the original offload file and re-offload quantized weights. + + Args: + compressor: The compressor object containing model and offload state. + block_name: Name of the block to discard and re-offload. 
+ """ + if not (compressor.low_cpu_mem_usage and compressor.cpu_stream_offload_blocks): + return + metadata = compressor._offloaded_blocks.pop(block_name, None) + if not metadata: + return + save_path = metadata.get("save_path") + if save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except Exception as e: + logger.warning(f"Failed to remove offloaded block file {save_path}: {e}") + + # Re-offload the quantized block weights to disk + block = get_module(compressor.model, block_name) + if block is None: + return + offload_dir = init_cpu_offload_dir(compressor) + if offload_dir is None: + return + safe_name = block_name.replace(".", "_") + new_save_path = os.path.join(offload_dir, f"{safe_name}_quantized.pt") + try: + state_dict = {k: v.cpu() for k, v in block.state_dict().items()} + torch.save(state_dict, new_save_path) + compressor._offloaded_blocks[block_name] = {"save_path": new_save_path, "quantized": True} + # Clear all submodule weights + for submodule in block.modules(): + if hasattr(submodule, "weight") and submodule.weight is not None: + clear_module_weights(submodule) + if hasattr(submodule, "bias") and submodule.bias is not None: + clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to re-offload quantized block {block_name}: {e}") + + +def restore_offloaded_blocks(compressor: Any) -> None: + """Restore all offloaded block weights back to memory. + + Args: + compressor: The compressor object containing model and offload state. + """ + if not compressor._offloaded_blocks: + return + for block_name, metadata in list(compressor._offloaded_blocks.items()): + try: + block = get_module(compressor.model, block_name) + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot restore block {block_name}: file {save_path} does not exist") + continue + state_dict = torch.load(save_path, map_location="cpu") + # Manually assign parameters to handle empty tensor replacement + for name, param in state_dict.items(): + parts = name.split(".") + target = block + for part in parts[:-1]: + target = getattr(target, part) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + except Exception as e: + logger.warning(f"Failed to restore offloaded block {block_name}: {e}") + + +def cleanup_cpu_offload_dir(compressor: Any) -> None: + """Clean up the CPU offload temporary directory. + + Args: + compressor: The compressor object containing the tempdir path. + """ + if compressor._cpu_offload_tempdir and os.path.isdir(compressor._cpu_offload_tempdir): + try: + shutil.rmtree(compressor._cpu_offload_tempdir) + except Exception as e: + logger.warning(f"Failed to cleanup cpu offload dir {compressor._cpu_offload_tempdir}: {e}") + compressor._cpu_offload_tempdir = None + + +def estimate_block_size_gb(block: torch.nn.Module) -> float: + """Estimate a block's weights size in GB. + + Args: + block: The block module to estimate size for. + + Returns: + float: Size of the block's parameters in GB. 
+ """ + total = 0.0 + for param in block.parameters(): + if param.numel() > 0: + total += param.numel() * param.element_size() / (1024**3) + return total diff --git a/test/test_cpu/core/test_low_cpu_mem_options.py b/test/test_cpu/core/test_low_cpu_mem_options.py index 64cae63a1..721492618 100644 --- a/test/test_cpu/core/test_low_cpu_mem_options.py +++ b/test/test_cpu/core/test_low_cpu_mem_options.py @@ -23,6 +23,8 @@ from auto_round import AutoRound from auto_round.compressors import base as base_module from auto_round.utils import device as device_module +from auto_round.utils import model as model_module +from auto_round.utils.model import stream_offload_blocks class TestCpuStreamOffloadBlocks: @@ -72,7 +74,7 @@ def test_offload_requires_low_cpu_mem_usage(self, tiny_opt_model_path): assert autoround.low_cpu_mem_usage is False def test_stream_offload_blocks_skips_when_disabled(self, tiny_opt_model_path): - """Test that _stream_offload_blocks returns early when disabled.""" + """Test that stream_offload_blocks returns early when disabled.""" autoround = AutoRound( tiny_opt_model_path, bits=4, @@ -83,7 +85,7 @@ def test_stream_offload_blocks_skips_when_disabled(self, tiny_opt_model_path): nsamples=1, seqlen=32, ) - autoround._stream_offload_blocks([["model.layers.0"]]) + stream_offload_blocks(autoround, [["model.layers.0"]]) assert autoround._offloaded_blocks == {} autoround2 = AutoRound( @@ -96,11 +98,11 @@ def test_stream_offload_blocks_skips_when_disabled(self, tiny_opt_model_path): nsamples=1, seqlen=32, ) - autoround2._stream_offload_blocks([["model.layers.0"]]) + stream_offload_blocks(autoround2, [["model.layers.0"]]) assert autoround2._offloaded_blocks == {} def test_stream_offload_blocks_records_blocks(self, tiny_opt_model_path, tmp_path, monkeypatch): - """Test that _stream_offload_blocks records offloaded blocks when enabled.""" + """Test that stream_offload_blocks records offloaded blocks when enabled.""" autoround = AutoRound( tiny_opt_model_path, bits=4, @@ -113,12 +115,12 @@ def test_stream_offload_blocks_records_blocks(self, tiny_opt_model_path, tmp_pat ) dummy_block = torch.nn.Linear(4, 4) - monkeypatch.setattr(device_module, "get_module", lambda _model, _name: dummy_block) - monkeypatch.setattr(device_module, "init_cpu_offload_dir", lambda _compressor: str(tmp_path)) + monkeypatch.setattr(model_module, "get_module", lambda _model, _name: dummy_block) + monkeypatch.setattr(model_module, "init_cpu_offload_dir", lambda _compressor: str(tmp_path)) monkeypatch.setattr(torch, "save", lambda *args, **kwargs: None) - monkeypatch.setattr(device_module, "clear_module_weights", lambda *_args, **_kwargs: None) + monkeypatch.setattr(model_module, "clear_module_weights", lambda *_args, **_kwargs: None) - autoround._stream_offload_blocks([["model.layers.0"]]) + stream_offload_blocks(autoround, [["model.layers.0"]]) assert "model.layers.0" in autoround._offloaded_blocks From 5dcd064c6e97bfff3dca21a169513c34cef10ef4 Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Mon, 9 Feb 2026 06:55:15 +0000 Subject: [PATCH 12/13] support AutoScheme cpu ram optimization Signed-off-by: lvliang-intel --- auto_round/auto_scheme/delta_loss.py | 541 ++++++++++++++++-- auto_round/auto_scheme/gen_auto_scheme.py | 1 + auto_round/auto_scheme/utils.py | 13 +- .../schemes/test_auto_scheme_low_cpu_mem.py | 450 +++++++++++++++ 4 files changed, 963 insertions(+), 42 deletions(-) create mode 100644 test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py diff --git a/auto_round/auto_scheme/delta_loss.py 
b/auto_round/auto_scheme/delta_loss.py index 58b2bae3c..269ad0bf5 100644 --- a/auto_round/auto_scheme/delta_loss.py +++ b/auto_round/auto_scheme/delta_loss.py @@ -13,9 +13,15 @@ # limitations under the License. import copy +import gc +import os +import shutil +import tempfile + +from safetensors.torch import load_file, save_file from dataclasses import asdict from functools import wraps -from typing import Iterable, Union +from typing import Any, Iterable, Optional, Union import torch from accelerate import dispatch_model @@ -55,11 +61,271 @@ set_non_auto_device_map, to_device, ) +from auto_round.utils.device import MemoryMonitor from auto_round.wrapper import WrapperLinear __all__ = ["gen_layer_config"] +def _group_layers_by_block(quant_layer_names, block_names): + """Group quantization layer names by their containing block.""" + groups = {bn: [] for bn in block_names} + non_block = [] + for name in quant_layer_names: + matched = False + for bn in block_names: + if name.startswith(bn + "."): + groups[bn].append(name) + matched = True + break + if not matched: + non_block.append(name) + return groups, non_block + + +# ============================================================================ +# CPU RAM Offload Management for AutoScheme +# ============================================================================ + + +class AutoSchemeOffloadContext: + """Manages disk offload state for AutoScheme to reduce CPU RAM usage. + + Maintains two separate on-disk stores: + - *original* weights (unwrapped, saved once at the start) + - *wrapped* weights (saved per-scheme iteration for forward/backward) + + This allows block weights to stay on disk between scheme iterations, + keeping only one block in RAM at a time. + """ + + def __init__(self, low_cpu_mem_usage: bool = False): + self.low_cpu_mem_usage = low_cpu_mem_usage + # Wrapped-state storage (changes every scheme iteration) + self._offload_tempdir: Optional[str] = None + self._offloaded_blocks: dict[str, dict] = {} + # Original-state storage (saved once, read many) + self._original_dir: Optional[str] = None + self._original_blocks: dict[str, dict] = {} + + # ------------------------------------------------------------------ + # Directory helpers + # ------------------------------------------------------------------ + def init_offload_dir(self) -> Optional[str]: + """Initialize the temporary directory for wrapped-state offload.""" + if not self.low_cpu_mem_usage: + return None + if self._offload_tempdir is None: + self._offload_tempdir = tempfile.mkdtemp(prefix="autoscheme_offload_") + logger.info(f"AutoScheme CPU offload directory: {self._offload_tempdir}") + return self._offload_tempdir + + def _init_original_dir(self) -> str: + """Initialize the temporary directory for original-weight storage.""" + if self._original_dir is None: + self._original_dir = tempfile.mkdtemp(prefix="autoscheme_original_") + return self._original_dir + + # ------------------------------------------------------------------ + # Original (unwrapped) weight management — saved once at start + # ------------------------------------------------------------------ + def save_original_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Save original (unwrapped) block weights to disk. 
Skips if already saved.""" + if not self.low_cpu_mem_usage: + return + if block_name in self._original_blocks: + return + orig_dir = self._init_original_dir() + safe_name = block_name.replace(".", "_") + save_path = os.path.join(orig_dir, f"{safe_name}.safetensors") + try: + state_dict = {k: v.cpu().contiguous() for k, v in block.state_dict().items()} + save_file(state_dict, save_path) + self._original_blocks[block_name] = {"save_path": save_path} + del state_dict + except Exception as e: + logger.warning(f"Failed to save original block {block_name}: {e}") + + def _load_state_into_block(self, save_path: str, block: torch.nn.Module) -> None: + """Low-level helper: load a safetensors file into *block*.""" + state_dict = load_file(save_path, device="cpu") + for name, param in state_dict.items(): + parts = name.split(".") + target = block + try: + for part in parts[:-1]: + target = getattr(target, part) + except AttributeError: + continue # key belongs to a different module tree (e.g. wrapper vs orig) + param_name = parts[-1] + if hasattr(target, param_name): + old_param = getattr(target, param_name) + if isinstance(old_param, torch.nn.Parameter): + setattr(target, param_name, + torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + else: + setattr(target, param_name, param) + del state_dict + + def load_original_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Load original (unwrapped) weights from disk into *block*.""" + if not self.low_cpu_mem_usage: + return + metadata = self._original_blocks.get(block_name) + if not metadata: + return + save_path = metadata["save_path"] + if not os.path.exists(save_path): + logger.warning(f"Original weights not found: {save_path}") + return + try: + self._load_state_into_block(save_path, block) + except Exception as e: + logger.warning(f"Failed to load original block {block_name}: {e}") + + def save_and_clear_all_original_blocks( + self, model: torch.nn.Module, block_names: list[str] + ) -> None: + """Save all original block weights to disk and clear them from RAM.""" + if not self.low_cpu_mem_usage: + return + logger.info("AutoScheme: saving original block weights to disk...") + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.save_original_block_weights(block_name, block) + for submodule in block.modules(): + _clear_module_weights(submodule) + gc.collect() + clear_memory() + logger.info("AutoScheme: original weights saved and cleared") + + def load_all_original_blocks( + self, model: torch.nn.Module, block_names: list[str] + ) -> None: + """Load all original block weights back into RAM.""" + if not self.low_cpu_mem_usage: + return + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.load_original_block_weights(block_name, block) + + # ------------------------------------------------------------------ + # Wrapped-state management — re-saved each scheme iteration + # ------------------------------------------------------------------ + def offload_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Offload a block's (possibly wrapped) weights to disk.""" + if not self.low_cpu_mem_usage: + return + offload_dir = self.init_offload_dir() + if offload_dir is None: + return + + safe_name = block_name.replace(".", "_") + save_path = os.path.join(offload_dir, f"{safe_name}.safetensors") + + already_saved = block_name in self._offloaded_blocks + + try: + if not already_saved: + state_dict = {k: 
v.cpu().contiguous() for k, v in block.state_dict().items()} + save_file(state_dict, save_path) + self._offloaded_blocks[block_name] = {"save_path": save_path} + del state_dict + + for submodule in block.modules(): + _clear_module_weights(submodule) + except Exception as e: + logger.warning(f"Failed to offload block {block_name}: {e}") + + def load_block_weights(self, block_name: str, block: torch.nn.Module) -> None: + """Load wrapped block weights from disk back into memory.""" + if not self.low_cpu_mem_usage: + return + metadata = self._offloaded_blocks.get(block_name) + if not metadata: + return + save_path = metadata.get("save_path") + if not save_path or not os.path.exists(save_path): + logger.warning(f"Cannot load block weights: file {save_path} does not exist") + return + try: + self._load_state_into_block(save_path, block) + except Exception as e: + logger.warning(f"Failed to load block weights from {save_path}: {e}") + + def offload_all_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: + """Offload all block weights to disk.""" + if not self.low_cpu_mem_usage: + return + logger.info("AutoScheme: offloading all block weights to disk for RAM optimization...") + for block_name in block_names: + block = get_module(model, block_name) + if block is not None: + self.offload_block_weights(block_name, block) + gc.collect() + clear_memory() + logger.info("AutoScheme: block weights offload complete") + + def reset_scheme_state(self) -> None: + """Clear wrapped-state tracking so the next scheme iteration re-saves.""" + self._offloaded_blocks = {} + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + def cleanup(self) -> None: + """Clean up all temporary directories.""" + for d in (self._offload_tempdir, self._original_dir): + if d and os.path.isdir(d): + try: + shutil.rmtree(d) + except Exception as e: + logger.warning(f"Failed to cleanup dir {d}: {e}") + self._offload_tempdir = None + self._offloaded_blocks = {} + self._original_dir = None + self._original_blocks = {} + + +def _clear_module_weights(module: torch.nn.Module) -> None: + """Clear module's weight and bias to free CPU RAM. + + Note: Skips WrapperLayer modules since their weight/bias are properties + that delegate to orig_layer. Clearing the actual orig_layer is sufficient. + Caches weight.numel() as ``_cached_weight_numel`` before clearing so that + ``compute_layer_bits`` can still compute correct results with empty tensors. + """ + if module is None: + return + # Skip WrapperLayer - its weight is a property, assigning to it would create + # an instance attribute shadowing the property. We clear orig_layer directly instead. 
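+    # For modules that are cleared, the numel/shape cache below keeps
+    # compute_layer_bits working: e.g. a 4096x4096 Linear records
+    # _cached_weight_numel = 16_777_216 before its weight becomes an empty tensor.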
+ if hasattr(module, "orig_layer"): + return + with torch.no_grad(): + if hasattr(module, "weight") and module.weight is not None: + # Cache numel / shape before replacing with empty tensor + if module.weight.numel() > 0: + module._cached_weight_numel = module.weight.numel() + module._cached_weight_shape = tuple(module.weight.shape) + if isinstance(module.weight, torch.nn.Parameter): + module.weight = torch.nn.Parameter( + torch.empty(0, dtype=module.weight.dtype, device="cpu"), + requires_grad=module.weight.requires_grad + ) + else: + module.weight = torch.empty(0, dtype=module.weight.dtype, device="cpu") + if hasattr(module, "bias") and module.bias is not None: + if isinstance(module.bias, torch.nn.Parameter): + module.bias = torch.nn.Parameter( + torch.empty(0, dtype=module.bias.dtype, device="cpu"), + requires_grad=module.bias.requires_grad + ) + else: + module.bias = torch.empty(0, dtype=module.bias.dtype, device="cpu") + + class AutoSchemeWrapperLinear(WrapperLinear): def __init__( self, @@ -322,7 +588,13 @@ def __init__(self, message): last_grad_input = None -def prepare_model_low_gpu(model, block_inputs: dict = None, pbar=None, major_device="cpu"): +def prepare_model_low_gpu( + model, + block_inputs: dict = None, + pbar=None, + major_device="cpu", + offload_context: Optional[AutoSchemeOffloadContext] = None, +): block_inputs.clear() for n, m in model.named_modules(): if hasattr(m, "grad_mode"): @@ -335,6 +607,10 @@ def wrap_forward(module, module_name): @wraps(original_forward) def new_forward(*args, **kwargs): + # Load block weights from disk if using low_cpu_mem_usage + if offload_context is not None: + offload_context.load_block_weights(module_name, module) + move_module_to_tuning_device(module, major_device=major_device) # Call the original forward @@ -353,6 +629,10 @@ def new_forward(*args, **kwargs): module.to("cpu") # torch.cuda.empty_cache() + # Offload block weights back to disk if using low_cpu_mem_usage + if offload_context is not None: + offload_context.offload_block_weights(module_name, module) + # Enable gradients for the output of the last block if module.tmp_name == block_names[-1]: if isinstance(result, torch.Tensor): @@ -375,7 +655,13 @@ def new_forward(*args, **kwargs): module.forward = wrap_forward(module, block_name) -def model_forward_low_gpu(model, dataloader, major_device="cuda", pbar=None): +def model_forward_low_gpu( + model, + dataloader, + major_device="cuda", + pbar=None, + offload_context: Optional[AutoSchemeOffloadContext] = None, +): block_inputs = {} block_names = get_block_names(model)[0] @@ -392,7 +678,9 @@ def backward_pre_hook(module, grad_input): raise MyCustomError("Interrupt backward pass") for data in dataloader: - prepare_model_low_gpu(model, block_inputs, major_device=major_device, pbar=pbar) + prepare_model_low_gpu( + model, block_inputs, major_device=major_device, pbar=pbar, offload_context=offload_context + ) # Register backward hook on the last block last_block = get_module(model, block_names[-1]) @@ -427,6 +715,11 @@ def backward_pre_hook(module, grad_input): # Move the block module to GPU block_module = get_module(model, block_name) + + # Load block weights from disk if offloaded (for backward pass) + if offload_context is not None: + offload_context.load_block_weights(block_name, block_module) + for n, m in block_module.named_modules(): if hasattr(m, "grad_mode"): m.grad_mode = True @@ -468,6 +761,11 @@ def backward_pre_hook(module, grad_input): block_module.to("cpu") # clear_memory() + + # Offload block weights to disk if 
low_cpu_mem_usage is enabled (after backward) + if offload_context is not None: + offload_context.offload_block_weights(block_name, block_module) + pbar.update(1) @@ -485,9 +783,11 @@ def get_score_for_scheme( need_weight_grad=False, enable_torch_compile=False, low_gpu_mem_usage=True, + low_cpu_mem_usage=False, major_device="cpu", batch_size=1, disable_opt_rtn=True, + offload_context: Optional[AutoSchemeOffloadContext] = None, ): scores_dict = {} # Key=name,Val=[quant_total_bits, loss] for n, m in model.named_modules(): @@ -505,14 +805,14 @@ def get_score_for_scheme( has_imatrix = True break - for name in quant_layer_names: + def wrap_layer(name: str) -> None: if name in fixed_layer_scheme.keys(): - continue + return m = get_module(model, name) if not check_to_quantized(m): layer_bits, _ = compute_layer_bits(m, ignore_scale_zp_bits) scores_dict[name] = [layer_bits, 0.0] - continue + return if m.act_bits > 8 and m.super_bits is not None: m.scale_dtype = torch.float32 # TODO set this via API elif m.act_bits > 8: @@ -547,11 +847,45 @@ def get_score_for_scheme( disable_opt_rtn=disable_opt_rtn, ) set_module(model, name, new_m) + + # ------------------------------------------------------------------ + # Wrapping + forward/backward + # ------------------------------------------------------------------ if low_gpu_mem_usage: dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset, seed=42, bs=batch_size, nsamples=nsamples) - model_forward_low_gpu(model, dataloader, major_device=major_device, pbar=pbar) + if offload_context is not None and low_cpu_mem_usage: + # Block-by-block wrapping: load original weights -> wrap -> offload wrapped state + block_names = get_block_names(model)[0] + layer_groups, non_block_layers = _group_layers_by_block(quant_layer_names, block_names) + offload_context.reset_scheme_state() + + for block_name in block_names: + block = get_module(model, block_name) + offload_context.load_original_block_weights(block_name, block) + for name in layer_groups.get(block_name, []): + wrap_layer(name) + offload_context.offload_block_weights(block_name, block) + + # Wrap layers that live outside of blocks (e.g. 
lm_head) + for name in non_block_layers: + wrap_layer(name) + + gc.collect() + clear_memory() + else: + for name in quant_layer_names: + wrap_layer(name) + + model_forward_low_gpu( + model, dataloader, major_device=major_device, pbar=pbar, offload_context=offload_context + ) + + # NOTE: do NOT load all blocks back — scores are read block-by-block below else: + for name in quant_layer_names: + wrap_layer(name) + dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset, seed=42, bs=batch_size, nsamples=nsamples) for data in dataloader: data = to_device(data, model.device) @@ -565,19 +899,78 @@ def get_score_for_scheme( for n, m in model.named_parameters(): m.grad = None - scores_dict = {} - for n, m in model.named_modules(): - if hasattr(m, "mix_score"): - if m.orig_layer.act_bits <= 8: - if m.act_cnt == 0: + # ------------------------------------------------------------------ + # Score reading + unwrapping + # ------------------------------------------------------------------ + if offload_context is not None and low_cpu_mem_usage: + # Block-by-block: load wrapped state -> read scores -> unwrap -> clear + scores_dict = {} + for block_name in block_names: + block = get_module(model, block_name) + offload_context.load_block_weights(block_name, block) + + # Read scores from wrapper attributes in this block + for n, m in block.named_modules(): + full_name = f"{block_name}.{n}" if n else block_name + if hasattr(m, "mix_score"): + if m.orig_layer.act_bits <= 8: + if m.act_cnt == 0: + logger.warning_once( + f"layer {full_name} max abs activation is 0, " + "please use more data to improve the accuracy" + ) + layer_bits, _ = compute_layer_bits( + m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits + ) + scores_dict[full_name] = [layer_bits, m.mix_score] + + # Unwrap layers in this block + unwrap_pairs = [] + for n, m in block.named_modules(): + full_name = f"{block_name}.{n}" if n else block_name + if hasattr(m, "orig_layer"): + unwrap_pairs.append((full_name, m.orig_layer)) + for full_name, orig_layer in unwrap_pairs: + set_module(model, full_name, orig_layer) + + # Clear weights so this block no longer occupies RAM + block = get_module(model, block_name) + for submodule in block.modules(): + _clear_module_weights(submodule) + + # Handle non-block layers + for n, m in model.named_modules(): + if hasattr(m, "mix_score") and n not in scores_dict: + if m.orig_layer.act_bits <= 8 and m.act_cnt == 0: logger.warning_once( - "layer{n} max abs activation is 0, please use more data to improve the accuracy" + f"layer {n} max abs activation is 0, " + "please use more data to improve the accuracy" ) - layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) - scores_dict[n] = [layer_bits, m.mix_score] - for n, m in model.named_modules(): - if hasattr(m, "orig_layer"): - set_module(model, n, m.orig_layer) + layer_bits, _ = compute_layer_bits( + m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits + ) + scores_dict[n] = [layer_bits, m.mix_score] + for n, m in model.named_modules(): + if hasattr(m, "orig_layer"): + set_module(model, n, m.orig_layer) + + gc.collect() + clear_memory() + else: + scores_dict = {} + for n, m in model.named_modules(): + if hasattr(m, "mix_score"): + if m.orig_layer.act_bits <= 8: + if m.act_cnt == 0: + logger.warning_once( + f"layer {n} max abs activation is 0, " + "please use more data to improve the accuracy" + ) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) + scores_dict[n] = 
[layer_bits, m.mix_score] + for n, m in model.named_modules(): + if hasattr(m, "orig_layer"): + set_module(model, n, m.orig_layer) return scores_dict @@ -593,44 +986,72 @@ def choose_bits_per_layer_with_path(layers: dict, P: int): (layer_names, scheme) for each layer, or (None, None) if no feasible solution exists. """ - # dp: total_params -> (accumulated_loss, chosen_path) - # The path explicitly stores the selected options. - dp: dict[int, tuple[float, list]] = {0: (0.0, [])} - - for layer_name, opts in layers.items(): - new_dp: dict[int, tuple[float, list]] = {} - for cur_params, (cur_loss, cur_path) in dp.items(): - for opt in opts: + # dp: total_params -> accumulated_loss + # Use backtracking pointers instead of storing full paths to avoid + # O(layers) list copies per state transition. + dp: dict[int, float] = {0: 0.0} + + layer_list = list(layers.items()) + total_layers = len(layer_list) + logger.info(f"Starting DP for {total_layers} layers, budget P={P}") + + # history[layer_idx][params] = (prev_params, opt_idx) for path reconstruction + history: list[dict[int, tuple[int, int]]] = [] + + pbar = tqdm(range(total_layers), desc="DP bit allocation", leave=True) + for idx in pbar: + layer_name, opts = layer_list[idx] + pbar.set_postfix(dp_states=len(dp)) + + new_dp: dict[int, float] = {} + new_bt: dict[int, tuple[int, int]] = {} + + for cur_params, cur_loss in dp.items(): + for opt_idx, opt in enumerate(opts): scheme, bits_cost, loss_cost, layer_names = opt np_total = cur_params + bits_cost if np_total > P: continue new_loss = cur_loss + loss_cost - new_path = cur_path + [(layer_names, scheme)] - # Keep the path with smaller loss for the same parameter budget - if np_total not in new_dp or new_loss < new_dp[np_total][0]: - new_dp[np_total] = (new_loss, new_path) + # Keep the option with smaller loss for the same parameter budget + if np_total not in new_dp or new_loss < new_dp[np_total]: + new_dp[np_total] = new_loss + new_bt[np_total] = (cur_params, opt_idx) if not new_dp: return None, None # Pareto pruning: remove dominated (params, loss) states - items = sorted(new_dp.items(), key=lambda x: x[0]) # (params, (loss, path)) - pruned: dict[int, tuple[float, list]] = {} + items = sorted(new_dp.items(), key=lambda x: x[0]) # sort by params + pruned_dp: dict[int, float] = {} + pruned_bt: dict[int, tuple[int, int]] = {} best_loss_so_far = float("inf") - for params_val, (loss_val, path_val) in items: + for params_val, loss_val in items: if loss_val < best_loss_so_far: - pruned[params_val] = (loss_val, path_val) + pruned_dp[params_val] = loss_val + pruned_bt[params_val] = new_bt[params_val] best_loss_so_far = loss_val - dp = pruned + dp = pruned_dp + history.append(pruned_bt) # Select the solution with the minimum loss - best_params = min(dp.keys(), key=lambda k: dp[k][0]) - best_loss, best_path = dp[best_params] - return best_loss, best_path + best_params = min(dp.keys(), key=lambda k: dp[k]) + best_loss = dp[best_params] + + # Backtrack to reconstruct the chosen path + path = [] + cur_params = best_params + for layer_idx in range(total_layers - 1, -1, -1): + prev_params, opt_idx = history[layer_idx][cur_params] + scheme, bits_cost, loss_cost, layer_names = layer_list[layer_idx][1][opt_idx] + path.append((layer_names, scheme)) + cur_params = prev_params + path.reverse() + + return best_loss, path def move_module_to_tuning_device(module, major_device="cpu"): @@ -658,6 +1079,18 @@ def _gen_layer_config( major_device="cpu", device_list=None, ): + # Initialize memory tracking for AutoScheme + 
memory_monitor = MemoryMonitor() + memory_monitor.reset() + memory_monitor.update_cpu() + + # Create offload context for CPU RAM optimization + # Note: low_cpu_mem_usage only works when low_gpu_mem_usage is also enabled, + # because disk offloading requires layer-by-layer processing + offload_context = None + if auto_scheme.low_cpu_mem_usage and auto_scheme.low_gpu_mem_usage: + offload_context = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + target_bits = auto_scheme.avg_bits model.eval() @@ -762,6 +1195,11 @@ def check_bf16_scheme(scheme): cal_imatrix(model, dataloader) logger.info("finish calculating imatrix") + # Offload all original block weights to disk before the scheme loop + # so that only one block needs to be in RAM at a time during scoring. + if offload_context is not None: + offload_context.save_and_clear_all_original_blocks(model, block_name) + pbar = tqdm(total=pbar_cnt, desc="Generating AutoScheme") for index, scheme in enumerate(schemes): apply_quant_scheme( @@ -789,10 +1227,15 @@ def check_bf16_scheme(scheme): need_weight_grad=need_weight_grad, enable_torch_compile=enable_torch_compile, low_gpu_mem_usage=auto_scheme.low_gpu_mem_usage, + low_cpu_mem_usage=auto_scheme.low_cpu_mem_usage, major_device=major_device, batch_size=batch_size, disable_opt_rtn=auto_scheme.disable_opt_rtn, + offload_context=offload_context, ) + # Track peak RAM after each scheme scoring + memory_monitor.update_cpu() + new_scores = {} for share_layer in shared_layers: param_bits = 0 @@ -817,10 +1260,17 @@ def check_bf16_scheme(scheme): options_scores.append(options_total_loss) clear_memory(device_list=device_list) + # Restore original weights from disk for final bit-budget computations + if offload_context is not None: + offload_context.load_all_original_blocks(model, block_name) + total_params = 0 for n, m in model.named_modules(): if n in quant_layer_names + embedding_layers_names: - total_params += m.weight.numel() + n_param = m.weight.numel() + if n_param == 0 and hasattr(m, "_cached_weight_numel"): + n_param = m._cached_weight_numel + total_params += n_param target_params_cnt = int(total_params * target_bits) sorted_indices = sorted(range(len(options_scores)), key=lambda i: options_scores[i]) @@ -901,7 +1351,18 @@ def check_bf16_scheme(scheme): m.grad = None global last_grad_input last_grad_input = None + + # Cleanup offload context + if offload_context is not None: + offload_context.cleanup() + clear_memory(device_list=device_list) + + # Log AutoScheme memory usage + memory_monitor.update_cpu() + low_cpu_str = "enabled" if auto_scheme.low_cpu_mem_usage else "disabled" + memory_monitor.log_summary(f"AutoScheme complete (low_cpu_mem_usage={low_cpu_str})") + pbar.close() return layer_config diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py index 8572a920a..d9d65c645 100644 --- a/auto_round/auto_scheme/gen_auto_scheme.py +++ b/auto_round/auto_scheme/gen_auto_scheme.py @@ -40,6 +40,7 @@ class AutoScheme: enable_torch_compile: Optional[bool] = None disable_opt_rtn: bool = True low_gpu_mem_usage: bool = True + low_cpu_mem_usage: bool = False # Enable disk offload for CPU RAM optimization def __post_init__(self): if isinstance(self.options, str): diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py index 3c19acc42..57ac3c96a 100644 --- a/auto_round/auto_scheme/utils.py +++ b/auto_round/auto_scheme/utils.py @@ -105,7 +105,10 @@ def compute_avg_bits_for_scheme( # continue if not hasattr(module, "weight"): continue - total_params 
+= module.weight.numel() + n_param = module.weight.numel() + if n_param == 0 and hasattr(module, "_cached_weight_numel"): + n_param = module._cached_weight_numel + total_params += n_param layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) total_quantized_bits += layer_bits avg_bits = float(total_quantized_bits) / total_params @@ -133,7 +136,10 @@ def compute_avg_bits_for_model(model: torch.nn.Module, ignore_scale_zp_bits: boo continue if not hasattr(module, "weight"): continue - total_params += module.weight.numel() + n_param = module.weight.numel() + if n_param == 0 and hasattr(module, "_cached_weight_numel"): + n_param = module._cached_weight_numel + total_params += n_param layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits) total_quantized_bits += layer_bits @@ -157,6 +163,9 @@ def compute_layer_bits( """ weight = layer.weight n_param = weight.numel() + # Use cached numel when weight has been cleared to an empty tensor (low_cpu_mem_usage offload) + if n_param == 0 and hasattr(layer, "_cached_weight_numel"): + n_param = layer._cached_weight_numel weight_bits = getattr(layer, "bits", 16) group_size = getattr(layer, "group_size", 128) data_type = getattr(layer, "data_type", "int") diff --git a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py new file mode 100644 index 000000000..dcfcb411c --- /dev/null +++ b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py @@ -0,0 +1,450 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for AutoScheme CPU RAM optimization (low_cpu_mem_usage option). + +This tests the disk offload mechanism for AutoScheme that reduces CPU RAM usage +by offloading block weights to disk during gradient computation. 
+""" + +import os +import shutil + +import pytest +import torch + +from auto_round import AutoRound, AutoScheme +from auto_round.auto_scheme.delta_loss import ( + AutoSchemeOffloadContext, + _clear_module_weights, + _group_layers_by_block, +) +from auto_round.auto_scheme.utils import compute_layer_bits +from auto_round.utils import get_block_names, get_module + + +class TestAutoSchemeOffloadContext: + """Tests for AutoSchemeOffloadContext class.""" + + def test_context_init_disabled(self): + """Test that context is properly initialized when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + assert ctx.low_cpu_mem_usage is False + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_context_init_enabled(self): + """Test that context is properly initialized when enabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + assert ctx.low_cpu_mem_usage is True + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_init_offload_dir_disabled(self): + """Test that init_offload_dir returns None when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + result = ctx.init_offload_dir() + assert result is None + assert ctx._offload_tempdir is None + + def test_init_offload_dir_enabled(self): + """Test that init_offload_dir creates temp directory when enabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + result = ctx.init_offload_dir() + assert result is not None + assert os.path.isdir(result) + assert "autoscheme_offload_" in result + finally: + ctx.cleanup() + + def test_offload_block_weights_disabled(self): + """Test that offload_block_weights does nothing when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + module = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module) + assert ctx._offloaded_blocks == {} + + def test_offload_block_weights_enabled(self): + """Test that offload_block_weights saves weights to disk.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module) + + assert "test_block" in ctx._offloaded_blocks + metadata = ctx._offloaded_blocks["test_block"] + assert "save_path" in metadata + assert os.path.exists(metadata["save_path"]) + + # Verify weight was cleared + assert module.weight.numel() == 0 + finally: + ctx.cleanup() + + def test_load_block_weights_disabled(self): + """Test that load_block_weights does nothing when disabled.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=False) + module = torch.nn.Linear(4, 4) + ctx.load_block_weights("test_block", module) + + def test_offload_and_load_block_weights(self): + """Test full cycle of offloading and loading block weights.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + original_weight = module.weight.clone() + original_bias = module.bias.clone() + + # Offload + ctx.offload_block_weights("test_block", module) + assert module.weight.numel() == 0 + + # Load back + ctx.load_block_weights("test_block", module) + + # Verify weights are restored + assert module.weight.shape == original_weight.shape + assert torch.allclose(module.weight, original_weight) + assert module.bias.shape == original_bias.shape + assert torch.allclose(module.bias, original_bias) + finally: + ctx.cleanup() + + def test_offload_block_weights_idempotent(self): + """Test that offloading same block twice doesn't create duplicate files.""" 
+ ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + + ctx.offload_block_weights("test_block", module) + first_path = ctx._offloaded_blocks["test_block"]["save_path"] + + # Create new module and try to offload again + module2 = torch.nn.Linear(4, 4) + ctx.offload_block_weights("test_block", module2) + + # Should still have same path (second offload skipped) + assert ctx._offloaded_blocks["test_block"]["save_path"] == first_path + finally: + ctx.cleanup() + + def test_cleanup(self): + """Test that cleanup removes temp directory.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + ctx.init_offload_dir() + tempdir = ctx._offload_tempdir + assert os.path.isdir(tempdir) + + ctx.cleanup() + + assert not os.path.exists(tempdir) + assert ctx._offload_tempdir is None + assert ctx._offloaded_blocks == {} + + def test_save_and_load_original_block_weights(self): + """Test saving and loading original (unwrapped) block weights.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(8, 4) + original_weight = module.weight.data.clone() + + ctx.save_original_block_weights("block.0", module) + assert "block.0" in ctx._original_blocks + + # Clear and load back + _clear_module_weights(module) + assert module.weight.numel() == 0 + + ctx.load_original_block_weights("block.0", module) + assert module.weight.numel() == original_weight.numel() + assert torch.allclose(module.weight.data, original_weight) + finally: + ctx.cleanup() + + def test_save_original_idempotent(self): + """Test that saving original block twice doesn't overwrite.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + path1 = ctx._original_blocks["b"]["save_path"] + + # Modify module and save again — should be a no-op + module.weight.data.fill_(999.0) + ctx.save_original_block_weights("b", module) + path2 = ctx._original_blocks["b"]["save_path"] + assert path1 == path2 + finally: + ctx.cleanup() + + def test_reset_scheme_state(self): + """Test that reset_scheme_state clears wrapped state but keeps originals.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + ctx.offload_block_weights("b", module) + + assert len(ctx._offloaded_blocks) == 1 + assert len(ctx._original_blocks) == 1 + + ctx.reset_scheme_state() + assert len(ctx._offloaded_blocks) == 0 + assert len(ctx._original_blocks) == 1 + finally: + ctx.cleanup() + + def test_cleanup_removes_both_dirs(self): + """Test that cleanup removes both original and offload directories.""" + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + module = torch.nn.Linear(4, 4) + ctx.save_original_block_weights("b", module) + ctx.offload_block_weights("b", module) + + orig_dir = ctx._original_dir + offload_dir = ctx._offload_tempdir + assert os.path.isdir(orig_dir) + assert os.path.isdir(offload_dir) + + ctx.cleanup() + assert not os.path.exists(orig_dir) + assert not os.path.exists(offload_dir) + + +class TestClearModuleWeights: + """Tests for _clear_module_weights helper function.""" + + def test_clear_linear_weights(self): + """Test clearing weights from a Linear module.""" + module = torch.nn.Linear(16, 8) + _clear_module_weights(module) + + assert module.weight.numel() == 0 + assert module.bias.numel() == 0 + + def test_clear_caches_numel(self): + """Test that clearing weights caches the original numel.""" + module = 
torch.nn.Linear(32, 16) + expected = 32 * 16 + _clear_module_weights(module) + assert hasattr(module, "_cached_weight_numel") + assert module._cached_weight_numel == expected + + def test_compute_layer_bits_with_empty_weight(self): + """Test compute_layer_bits works after weight is cleared.""" + layer = torch.nn.Linear(64, 128) + layer.bits = 4 + layer.group_size = 32 + layer.data_type = "int" + layer.sym = True + layer.super_group_size = None + layer.super_bits = None + + bits_before, _ = compute_layer_bits(layer, False) + _clear_module_weights(layer) + bits_after, _ = compute_layer_bits(layer, False) + assert bits_before == bits_after + + def test_clear_module_none(self): + """Test clearing weights with None module doesn't crash.""" + _clear_module_weights(None) # Should not raise + + def test_clear_module_no_weights(self): + """Test clearing module without weights doesn't crash.""" + module = torch.nn.ReLU() + _clear_module_weights(module) # Should not raise + + +class TestAutoSchemeDataclassLowCpuMem: + """Tests for low_cpu_mem_usage parameter in AutoScheme dataclass.""" + + def test_auto_scheme_default_low_cpu_mem_usage(self): + """Test that low_cpu_mem_usage defaults to False.""" + scheme = AutoScheme(avg_bits=4, options="W4A16") + assert scheme.low_cpu_mem_usage is False + + def test_auto_scheme_low_cpu_mem_usage_enabled(self): + """Test that low_cpu_mem_usage can be enabled.""" + scheme = AutoScheme(avg_bits=4, options="W4A16", low_cpu_mem_usage=True) + assert scheme.low_cpu_mem_usage is True + + def test_auto_scheme_low_cpu_mem_usage_with_low_gpu_mem_usage(self): + """Test that both low_cpu_mem_usage and low_gpu_mem_usage can be set.""" + scheme = AutoScheme( + avg_bits=4, + options="W4A16", + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + assert scheme.low_cpu_mem_usage is True + assert scheme.low_gpu_mem_usage is True + + +class TestGroupLayersByBlock: + """Tests for _group_layers_by_block helper.""" + + def test_basic_grouping(self): + layers = ["model.layers.0.attn.q", "model.layers.0.attn.k", + "model.layers.1.mlp.fc1", "lm_head"] + blocks = ["model.layers.0", "model.layers.1"] + groups, non_block = _group_layers_by_block(layers, blocks) + assert groups["model.layers.0"] == ["model.layers.0.attn.q", "model.layers.0.attn.k"] + assert groups["model.layers.1"] == ["model.layers.1.mlp.fc1"] + assert non_block == ["lm_head"] + + def test_empty_layers(self): + groups, non_block = _group_layers_by_block([], ["b0"]) + assert groups == {"b0": []} + assert non_block == [] + + +class TestAutoSchemeIntegration: + """Integration tests for AutoScheme with low_cpu_mem_usage.""" + + save_dir = "./saved" + + @pytest.fixture(autouse=True, scope="class") + def setup_and_teardown_class(self): + yield + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_auto_scheme_with_low_cpu_mem_disabled(self, tiny_opt_model_path): + """Test AutoScheme works normally with low_cpu_mem_usage disabled.""" + model_name = tiny_opt_model_path + scheme = AutoScheme( + avg_bits=4, + options="W2A16,W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=False, + low_gpu_mem_usage=True, + ) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + _, layer_config = ar.quantize() + assert layer_config is not None + assert len(layer_config) > 0 + + def test_auto_scheme_with_low_cpu_mem_enabled(self, tiny_opt_model_path): + """Test AutoScheme works with low_cpu_mem_usage enabled.""" + model_name = tiny_opt_model_path + scheme = 
AutoScheme( + avg_bits=4, + options="W2A16,W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1) + _, layer_config = ar.quantize() + assert layer_config is not None + assert len(layer_config) > 0 + + def test_auto_scheme_low_cpu_mem_results_consistent(self, tiny_opt_model_path): + """Test that results are consistent with and without low_cpu_mem_usage.""" + model_name = tiny_opt_model_path + + # Without low_cpu_mem_usage + scheme1 = AutoScheme( + avg_bits=4, + options="W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=False, + low_gpu_mem_usage=True, + ) + ar1 = AutoRound(model=model_name, scheme=scheme1, iters=0, nsamples=1, seed=42) + _, layer_config1 = ar1.quantize() + + # With low_cpu_mem_usage + scheme2 = AutoScheme( + avg_bits=4, + options="W4A16", + nsamples=1, + ignore_scale_zp_bits=True, + low_cpu_mem_usage=True, + low_gpu_mem_usage=True, + ) + ar2 = AutoRound(model=model_name, scheme=scheme2, iters=0, nsamples=1, seed=42) + _, layer_config2 = ar2.quantize() + + # Layer configs should have same keys + assert set(layer_config1.keys()) == set(layer_config2.keys()) + + +class TestAutoSchemeOffloadContextWithModel: + """Tests for AutoSchemeOffloadContext with actual model blocks.""" + + def test_offload_model_block(self, tiny_opt_model_path): + """Test offloading an actual model block.""" + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(tiny_opt_model_path, torch_dtype=torch.float32) + block_names = get_block_names(model)[0] + + if len(block_names) == 0: + pytest.skip("Model has no blocks") + + block_name = block_names[0] + block = get_module(model, block_name) + + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + # Get original param count + original_params = sum(p.numel() for p in block.parameters()) + + # Offload + ctx.offload_block_weights(block_name, block) + + # Check that weights were cleared + current_params = sum(p.numel() for p in block.parameters()) + assert current_params < original_params + + # Load back + ctx.load_block_weights(block_name, block) + + # Check params are restored + restored_params = sum(p.numel() for p in block.parameters()) + assert restored_params == original_params + finally: + ctx.cleanup() + + def test_offload_all_blocks(self, tiny_opt_model_path): + """Test offloading all model blocks.""" + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(tiny_opt_model_path, torch_dtype=torch.float32) + block_names = get_block_names(model)[0] + + if len(block_names) == 0: + pytest.skip("Model has no blocks") + + ctx = AutoSchemeOffloadContext(low_cpu_mem_usage=True) + try: + # Offload all blocks + ctx.offload_all_blocks(model, block_names) + + # Verify all blocks were offloaded + for block_name in block_names: + assert block_name in ctx._offloaded_blocks + block = get_module(model, block_name) + # Check weights are cleared (should have very few params) + block_params = sum(p.numel() for p in block.parameters()) + assert block_params == 0 or block_params < 100 # Allow for some edge cases + finally: + ctx.cleanup() + From e62a7088eded67bf1516709987091eb1776183d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 06:41:39 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- auto_round/auto_scheme/delta_loss.py | 38 ++++++------------- auto_round/compressors/base.py | 2 +- .../schemes/test_auto_scheme_low_cpu_mem.py | 6 +-- 3 files changed, 14 insertions(+), 32 deletions(-) diff --git a/auto_round/auto_scheme/delta_loss.py b/auto_round/auto_scheme/delta_loss.py index 269ad0bf5..8d95828f1 100644 --- a/auto_round/auto_scheme/delta_loss.py +++ b/auto_round/auto_scheme/delta_loss.py @@ -17,14 +17,13 @@ import os import shutil import tempfile - -from safetensors.torch import load_file, save_file from dataclasses import asdict from functools import wraps from typing import Any, Iterable, Optional, Union import torch from accelerate import dispatch_model +from safetensors.torch import load_file, save_file from tqdm import tqdm from auto_round.auto_scheme.gen_auto_scheme import AutoScheme @@ -161,8 +160,7 @@ def _load_state_into_block(self, save_path: str, block: torch.nn.Module) -> None if hasattr(target, param_name): old_param = getattr(target, param_name) if isinstance(old_param, torch.nn.Parameter): - setattr(target, param_name, - torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) + setattr(target, param_name, torch.nn.Parameter(param, requires_grad=old_param.requires_grad)) else: setattr(target, param_name, param) del state_dict @@ -183,9 +181,7 @@ def load_original_block_weights(self, block_name: str, block: torch.nn.Module) - except Exception as e: logger.warning(f"Failed to load original block {block_name}: {e}") - def save_and_clear_all_original_blocks( - self, model: torch.nn.Module, block_names: list[str] - ) -> None: + def save_and_clear_all_original_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: """Save all original block weights to disk and clear them from RAM.""" if not self.low_cpu_mem_usage: return @@ -200,9 +196,7 @@ def save_and_clear_all_original_blocks( clear_memory() logger.info("AutoScheme: original weights saved and cleared") - def load_all_original_blocks( - self, model: torch.nn.Module, block_names: list[str] - ) -> None: + def load_all_original_blocks(self, model: torch.nn.Module, block_names: list[str]) -> None: """Load all original block weights back into RAM.""" if not self.low_cpu_mem_usage: return @@ -311,16 +305,14 @@ def _clear_module_weights(module: torch.nn.Module) -> None: module._cached_weight_shape = tuple(module.weight.shape) if isinstance(module.weight, torch.nn.Parameter): module.weight = torch.nn.Parameter( - torch.empty(0, dtype=module.weight.dtype, device="cpu"), - requires_grad=module.weight.requires_grad + torch.empty(0, dtype=module.weight.dtype, device="cpu"), requires_grad=module.weight.requires_grad ) else: module.weight = torch.empty(0, dtype=module.weight.dtype, device="cpu") if hasattr(module, "bias") and module.bias is not None: if isinstance(module.bias, torch.nn.Parameter): module.bias = torch.nn.Parameter( - torch.empty(0, dtype=module.bias.dtype, device="cpu"), - requires_grad=module.bias.requires_grad + torch.empty(0, dtype=module.bias.dtype, device="cpu"), requires_grad=module.bias.requires_grad ) else: module.bias = torch.empty(0, dtype=module.bias.dtype, device="cpu") @@ -877,9 +869,7 @@ def wrap_layer(name: str) -> None: for name in quant_layer_names: wrap_layer(name) - model_forward_low_gpu( - model, dataloader, major_device=major_device, pbar=pbar, offload_context=offload_context - ) + model_forward_low_gpu(model, dataloader, major_device=major_device, pbar=pbar, offload_context=offload_context) # NOTE: do NOT load all blocks back — scores 
are read block-by-block below else: @@ -919,9 +909,7 @@ def wrap_layer(name: str) -> None: f"layer {full_name} max abs activation is 0, " "please use more data to improve the accuracy" ) - layer_bits, _ = compute_layer_bits( - m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits - ) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) scores_dict[full_name] = [layer_bits, m.mix_score] # Unwrap layers in this block @@ -943,12 +931,9 @@ def wrap_layer(name: str) -> None: if hasattr(m, "mix_score") and n not in scores_dict: if m.orig_layer.act_bits <= 8 and m.act_cnt == 0: logger.warning_once( - f"layer {n} max abs activation is 0, " - "please use more data to improve the accuracy" + f"layer {n} max abs activation is 0, " "please use more data to improve the accuracy" ) - layer_bits, _ = compute_layer_bits( - m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits - ) + layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) scores_dict[n] = [layer_bits, m.mix_score] for n, m in model.named_modules(): if hasattr(m, "orig_layer"): @@ -963,8 +948,7 @@ def wrap_layer(name: str) -> None: if m.orig_layer.act_bits <= 8: if m.act_cnt == 0: logger.warning_once( - f"layer {n} max abs activation is 0, " - "please use more data to improve the accuracy" + f"layer {n} max abs activation is 0, " "please use more data to improve the accuracy" ) layer_bits, _ = compute_layer_bits(m.orig_layer, ignore_scale_zp_bits=ignore_scale_zp_bits) scores_dict[n] = [layer_bits, m.mix_score] diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 66c079542..da79de56e 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -114,6 +114,7 @@ set_auto_device_map_for_block_with_tuning, set_non_auto_device_map, ) +from auto_round.utils.distributed import setup_ddp_if_needed_ from auto_round.utils.model import ( cleanup_cpu_offload_dir, discard_offloaded_block, @@ -124,7 +125,6 @@ restore_offloaded_blocks, stream_offload_blocks, ) -from auto_round.utils.distributed import setup_ddp_if_needed_ from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block SERIALIZATION_KEYS = ( diff --git a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py index dcfcb411c..6fb72d7f9 100644 --- a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py +++ b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py @@ -88,7 +88,7 @@ def test_offload_block_weights_enabled(self): metadata = ctx._offloaded_blocks["test_block"] assert "save_path" in metadata assert os.path.exists(metadata["save_path"]) - + # Verify weight was cleared assert module.weight.numel() == 0 finally: @@ -297,8 +297,7 @@ class TestGroupLayersByBlock: """Tests for _group_layers_by_block helper.""" def test_basic_grouping(self): - layers = ["model.layers.0.attn.q", "model.layers.0.attn.k", - "model.layers.1.mlp.fc1", "lm_head"] + layers = ["model.layers.0.attn.q", "model.layers.0.attn.k", "model.layers.1.mlp.fc1", "lm_head"] blocks = ["model.layers.0", "model.layers.1"] groups, non_block = _group_layers_by_block(layers, blocks) assert groups["model.layers.0"] == ["model.layers.0.attn.q", "model.layers.0.attn.k"] @@ -447,4 +446,3 @@ def test_offload_all_blocks(self, tiny_opt_model_path): assert block_params == 0 or block_params < 100 # Allow for some edge cases finally: ctx.cleanup() -