525 changes: 485 additions & 40 deletions auto_round/auto_scheme/delta_loss.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions auto_round/auto_scheme/gen_auto_scheme.py
@@ -40,6 +40,7 @@ class AutoScheme:
enable_torch_compile: Optional[bool] = None
disable_opt_rtn: bool = True
low_gpu_mem_usage: bool = True
low_cpu_mem_usage: bool = False # Enable disk offload for CPU RAM optimization

def __post_init__(self):
if isinstance(self.options, str):
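A minimal usage sketch for the new flag (illustrative, not part of this diff): it assumes AutoScheme is importable from the module path shown above and that it takes an average-bit target plus an option string; the exact required constructor arguments may differ.

from auto_round.auto_scheme.gen_auto_scheme import AutoScheme  # module path taken from this diff

# Hypothetical values; only the flag names (low_gpu_mem_usage, low_cpu_mem_usage) come from this PR.
scheme = AutoScheme(
    avg_bits=4.0,            # assumption: an average-bits target field exists elsewhere in the dataclass
    options="W4A16",         # assumption: a single option string is accepted (see __post_init__ above)
    low_gpu_mem_usage=True,
    low_cpu_mem_usage=True,  # new flag: enable disk offload to reduce CPU RAM usage
)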
13 changes: 11 additions & 2 deletions auto_round/auto_scheme/utils.py
@@ -105,7 +105,10 @@ def compute_avg_bits_for_scheme(
# continue
if not hasattr(module, "weight"):
continue
total_params += module.weight.numel()
n_param = module.weight.numel()
if n_param == 0 and hasattr(module, "_cached_weight_numel"):
n_param = module._cached_weight_numel
total_params += n_param
layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits)
total_quantized_bits += layer_bits
avg_bits = float(total_quantized_bits) / total_params
@@ -133,7 +136,10 @@ def compute_avg_bits_for_model(model: torch.nn.Module, ignore_scale_zp_bits: boo
continue
if not hasattr(module, "weight"):
continue
total_params += module.weight.numel()
n_param = module.weight.numel()
if n_param == 0 and hasattr(module, "_cached_weight_numel"):
n_param = module._cached_weight_numel
total_params += n_param
layer_bits, _ = compute_layer_bits(module, ignore_scale_zp_bits)
total_quantized_bits += layer_bits

@@ -157,6 +163,9 @@ def compute_layer_bits(
"""
weight = layer.weight
n_param = weight.numel()
# Use cached numel when weight has been cleared to an empty tensor (low_cpu_mem_usage offload)
if n_param == 0 and hasattr(layer, "_cached_weight_numel"):
n_param = layer._cached_weight_numel
weight_bits = getattr(layer, "bits", 16)
group_size = getattr(layer, "group_size", 128)
data_type = getattr(layer, "data_type", "int")
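For context, a self-contained sketch of the fallback these three changes rely on: when a weight has been offloaded and replaced with an empty tensor, its original element count is kept on the module so bit accounting still works. Only the `_cached_weight_numel` attribute name comes from this PR; the offload helper below is an illustrative stand-in.

import torch

def offload_weight_keep_numel(module: torch.nn.Module) -> None:
    # Remember the original element count, then free the weight storage.
    module._cached_weight_numel = module.weight.numel()
    module.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)

def effective_numel(module: torch.nn.Module) -> int:
    n = module.weight.numel()
    if n == 0 and hasattr(module, "_cached_weight_numel"):
        n = module._cached_weight_numel
    return n

layer = torch.nn.Linear(1024, 1024, bias=False)
offload_weight_keep_numel(layer)
assert effective_numel(layer) == 1024 * 1024  # average-bit accounting is unaffected by the offload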
165 changes: 156 additions & 9 deletions auto_round/compressors/base.py
@@ -14,6 +14,7 @@

import copy
import os
import re
import sys
import time
import traceback
@@ -94,6 +95,7 @@
is_moe_model,
is_quantized_input_module,
llm_load_model,
load_module_weights,
memory_monitor,
mv_module_from_gpu,
set_amax_for_all_moe_layers,
@@ -104,12 +106,25 @@
)
from auto_round.utils.device import (
clear_memory_if_reached_threshold,
estimate_inputs_size_gb,
estimate_model_size_gb,
estimate_tensor_size_gb,
get_major_device,
parse_available_devices,
set_auto_device_map_for_block_with_tuning,
set_non_auto_device_map,
)
from auto_round.utils.distributed import setup_ddp_if_needed_
from auto_round.utils.model import (
cleanup_cpu_offload_dir,
discard_offloaded_block,
estimate_block_size_gb,
init_cpu_offload_dir,
load_offloaded_block_weights,
offload_block_weights,
restore_offloaded_blocks,
stream_offload_blocks,
)
from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block

SERIALIZATION_KEYS = (
@@ -356,6 +371,10 @@ def __init__(
self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
self.scale_dtype = convert_dtype_str2torch(scale_dtype)
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_stream_offload_blocks = kwargs.pop("cpu_stream_offload_blocks", False)
self.cpu_stream_loss = kwargs.pop("cpu_stream_loss", False)
self._cpu_offload_tempdir = None
self._offloaded_blocks = {}

if kwargs:
logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
@@ -952,7 +971,6 @@ def quantize_and_save(
output_dir, format=self.formats, inplace=inplace, return_folders=True, **kwargs
)
memory_monitor.log_summary()

return model, folders

def _get_save_folder_name(self, format: OutputFormat) -> str:
@@ -1534,6 +1552,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if not self.is_immediate_saving:
# some modules may have been flushed and set to meta, so we could not move to gpu
mv_module_from_gpu(block)
if self.low_cpu_mem_usage:
offload_block_weights(self, block_name, block)
if block_name == block_names[-1]:
clear_memory(input_ids, device_list=self.device_list)
else:
@@ -1700,6 +1720,9 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:

self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed)

if self.low_cpu_mem_usage:
self._offloaded_blocks = {}

def _should_disable_inplace_due_to_layers_outside_block() -> bool:
return self.has_qlayer_outside_block and (self.iters != 0 or (self.iters == 0 and not self.disable_opt_rtn))

@@ -1737,9 +1760,14 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
)
else:
logger.info("start to cache block inputs")

all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
is_quantized_embedding = self._quantize_embedding_layer()
clear_memory(device_list=self.device_list)
# Log memory breakdown for calibration inputs
if self.low_cpu_mem_usage:
inputs_size_gb = estimate_inputs_size_gb(all_inputs)
logger.info(f"[Memory] calibration inputs size: {inputs_size_gb:.2f} GB")
all_q_inputs = None
if is_quantized_embedding:
all_inputs = copy.deepcopy(self.inputs)
Expand All @@ -1752,6 +1780,12 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model) # self.model.hf_device_map has not been changed
logger.info("caching done")
# Log memory breakdown for model weights
if self.low_cpu_mem_usage:
model_size_gb = estimate_model_size_gb(self.model)
logger.info(f"[Memory] model weights size: {model_size_gb:.2f} GB")
if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks:
stream_offload_blocks(self, all_blocks)
if len(all_blocks) > 1:
pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks))
else:
@@ -1797,6 +1831,10 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
if self.is_immediate_saving:
shard_writer(self, is_finalize=True)

if self.low_cpu_mem_usage:
restore_offloaded_blocks(self)
cleanup_cpu_offload_dir(self)

end_time = time.time()
cost_time = end_time - start_time
logger.info(f"quantization tuning time {cost_time}")
@@ -1877,6 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
memory_monitor.update()
memory_monitor.log_summary()
return

q_layer_inputs = None
enable_quanted_input = self.enable_quanted_input
has_gguf = False
@@ -2706,6 +2745,29 @@ def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) ->
current_output = torch.cat(current_output, dim=self.batch_dim)
return current_output

def _get_current_output_stream(
self,
block: torch.nn.Module,
input_ids: list[torch.Tensor],
input_others: dict,
indices: list[int],
device: str,
cache_device: str = "cpu",
) -> torch.Tensor:
current_input_ids, current_input_others = self._sampling_inputs(
input_ids,
input_others,
indices,
seqlen=self.seqlen,
batch_dim=self.batch_dim,
share_cache_keys=self.shared_cache_keys,
)
with torch.no_grad():
output = self.block_forward(
block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device
)
return output.to(cache_device)

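The streaming variant exists because the cached reference outputs can dominate CPU RAM. A rough back-of-the-envelope estimate with assumed calibration settings (none of these numbers come from this PR):

# Illustrative numbers only: 512 samples, 2048 tokens, hidden size 4096, bf16 activations.
nsamples, seqlen, hidden, bytes_per_elem = 512, 2048, 4096, 2
cached_targets_gb = nsamples * seqlen * hidden * bytes_per_elem / 1024**3
print(f"~{cached_targets_gb:.1f} GB of reference outputs held for the whole block if cached")  # ~8.0 GB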
def _get_current_q_output(
self,
block: torch.nn.Module,
@@ -2841,24 +2903,55 @@ def _quantize_block(
hook = AlignDevicesHook(m.tuning_device, io_same_device=True)
add_hook_to_module(m, hook, True)

stream_loss = self.cpu_stream_loss and self.nblocks == 1
if self.cpu_stream_loss and self.nblocks != 1:
logger.warning("cpu_stream_loss only supports nblocks=1; falling back to cached outputs.")
stream_loss = False

if q_input is None:
hook_handles = self._register_act_max_hook(block)

output = self._get_block_outputs(
block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device
)
if stream_loss:
output = None
self._get_block_outputs(
block,
input_ids,
input_others,
self.batch_size * self.infer_bs_coeff,
device,
self.cache_device,
save_output=False,
)
else:
output = self._get_block_outputs(
block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device
)

# Log output cache size for first block
if (
self.low_cpu_mem_usage
and self.cpu_stream_offload_blocks
and not hasattr(self, "_logged_output_size")
):
output_size = estimate_tensor_size_gb(output)
logger.info(f"[Memory] block output cache size: {output_size:.2f} GB")
self._logged_output_size = True

for handle in hook_handles:
handle.remove()
else:
output = self._get_block_outputs(
block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device
)
if stream_loss:
output = None
# Skip pre-computation in stream_loss mode: targets will be computed on the fly with frozen_block
else:
output = self._get_block_outputs(
block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device
)
hook_handles = self._register_act_max_hook(block)
if hook_handles:
self._get_block_outputs(
block,
q_input,
q_input if q_input is not None else input_ids,
input_others,
self.batch_size * self.infer_bs_coeff,
device,
Expand All @@ -2876,6 +2969,13 @@ def _quantize_block(
clear_memory(device_list=self.device_list)
input_ids = q_input

frozen_block = None
if stream_loss:
frozen_block = copy.deepcopy(block).to(device)
frozen_block.eval()
for p in frozen_block.parameters():
p.requires_grad_(False)

quantized_layer_names, unquantized_layer_names = self.wrapper_block(
block,
self.enable_minmax_tuning,
@@ -2972,7 +3072,12 @@

for tmp_step in range(self.gradient_accumulate_steps):
indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
current_output = self._get_current_output(output, indices)
if stream_loss:
current_output = self._get_current_output_stream(
frozen_block, input_ids, input_others, indices, loss_device, cache_device=loss_device
)
else:
current_output = self._get_current_output(output, indices)
current_output = to_device(current_output, loss_device)
output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
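A self-contained sketch of what this branch does: keep a frozen full-precision copy of the block and recompute the reference output for each sampled micro-batch instead of reading it from a pre-computed cache. The toy block, data, and loss below are illustrative; only the frozen-copy pattern mirrors the diff.

import copy

import torch

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())  # stand-in for one transformer block
frozen_block = copy.deepcopy(block).eval()
for p in frozen_block.parameters():
    p.requires_grad_(False)

mse_loss = torch.nn.MSELoss()
x = torch.randn(8, 64)                # one sampled micro-batch of cached block inputs
with torch.no_grad():
    target = frozen_block(x)          # reference output computed on the fly (streamed, never cached)
loss = mse_loss(block(x), target)     # block(x) stands in for the wrapped/quantized forward pass
loss.backward()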
@@ -3055,6 +3160,16 @@

return q_outputs, output
else:
# When stream_loss is enabled, output is None, so we need to compute it for the next block
if stream_loss and output is None:
output = self._get_block_outputs(
block,
input_ids,
input_others,
self.batch_size * self.infer_bs_coeff,
device,
cache_device=self.cache_device,
)
if len(self.device_list) > 1 and auto_offload:
accelerate.hooks.remove_hook_from_submodules(block)
if auto_offload:
@@ -3119,6 +3234,14 @@ def _quantize_blocks(

input_ids, input_others = self._preprocess_block_inputs(inputs)

# Log detailed memory breakdown for first block
if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks:
input_ids_size = estimate_tensor_size_gb(input_ids)
input_others_size = estimate_tensor_size_gb(input_others)
logger.info(
f"[Memory] input_ids size: {input_ids_size:.2f} GB, input_others size: {input_others_size:.2f} GB"
)

if pbar is None:
pbar = tqdm(range(0, len(block_names), nblocks))

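The GB figures logged above come from size-estimation helpers added in auto_round.utils.device; a rough, self-contained stand-in that walks nested tensor containers the way the logging expects (the real implementation may differ):

import torch

def estimate_tensor_size_gb_sketch(obj) -> float:
    """Recursively sum tensor storage, in GiB, over tensors, lists/tuples, and dicts."""
    if isinstance(obj, torch.Tensor):
        return obj.numel() * obj.element_size() / 1024**3
    if isinstance(obj, (list, tuple)):
        return sum(estimate_tensor_size_gb_sketch(x) for x in obj)
    if isinstance(obj, dict):
        return sum(estimate_tensor_size_gb_sketch(v) for v in obj.values())
    return 0.0

toy_input_ids = [torch.zeros(4, 2048, 4096, dtype=torch.bfloat16) for _ in range(2)]  # toy batches
print(f"[Memory] input_ids size: {estimate_tensor_size_gb_sketch(toy_input_ids):.2f} GB")  # 0.12 GB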
@@ -3135,14 +3258,38 @@
modules = [get_module(model, n) for n in names]
m = WrapperMultiblock(modules)

if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks:
if nblocks == 1:
load_offloaded_block_weights(self, n, get_module(model, n))
if i == 0: # Log only for first block
block_size = estimate_block_size_gb(get_module(model, n))
logger.info(f"[Memory] loaded block weights size: {block_size:.2f} GB")
else:
for name in names:
load_offloaded_block_weights(self, name, get_module(model, name))

m.config = model.config if hasattr(model, "config") else None

q_input, input_ids = self._quantize_block(
m,
input_ids,
input_others,
q_input=q_input,
device=device,
)

if self.low_cpu_mem_usage and not self.cpu_stream_offload_blocks:
if nblocks == 1:
offload_block_weights(self, n, get_module(model, n))
else:
for name in names:
offload_block_weights(self, name, get_module(model, name))
if self.low_cpu_mem_usage and self.cpu_stream_offload_blocks:
if nblocks == 1:
discard_offloaded_block(self, n)
else:
for name in names:
discard_offloaded_block(self, name)
if hasattr(model, "config"):
del m.config
if self.is_immediate_packing: