diff --git a/README.md b/README.md
index 82f712023..8ae1c220e 100644
--- a/README.md
+++ b/README.md
@@ -119,7 +119,7 @@ pip install auto-round-hpu

 ## Model Quantization (CPU/Intel GPU/Gaudi/CUDA)

->If you encounter issues during quantization, try using pure RTN mode with iters=0, disable_opt_rtn=True. Additionally, using group_size=32 or mixed bits is recommended for better results..
+>If you encounter issues during quantization, try using pure RTN mode with iters=0, disable_opt_rtn=True. Additionally, using group_size=32 or mixed bits is recommended for better results.

 ### CLI Usage
 The full list of supported arguments is provided by calling `auto-round -h` on the terminal.
diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index 3c19acc42..aa1e185a5 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -246,7 +246,7 @@ def dispatch_model_by_all_available_devices(
         else:
             raise ValueError(f"Unsupported device {device} in device_map: {device_map}")
         new_max_memory[device] = max_memory[device]
-
+    model.tie_weights()
     device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_modules)
     model = dispatch_model(model, device_map=device_map)
     return model
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 22149ddbc..9f75b1ad9 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1633,7 +1633,7 @@ def _adjust_immediate_packing_and_saving(self):
             self.is_immediate_saving = True

         if self.low_cpu_mem_usage and not self.is_immediate_packing:
-            logger.warning(
+            logger.info(
                 "`low_cpu_mem_usage` is only supported when `immediate_packing` is True. "
                 "Setting `low_cpu_mem_usage` to False."
             )
@@ -2205,6 +2205,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                     max_memory=new_max_memory,
                     no_split_module_classes=no_split_modules,
                 )
+                self.model.tie_weights()
                 device_map = infer_auto_device_map(
                     self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
                 )
diff --git a/auto_round/eval/eval_cli.py b/auto_round/eval/eval_cli.py
index 938be51c6..c1ccf2b40 100644
--- a/auto_round/eval/eval_cli.py
+++ b/auto_round/eval/eval_cli.py
@@ -16,9 +16,11 @@
 import os
 import time

+import torch.nn
 from transformers.utils.versions import require_version

 from auto_round.utils import (
+    dispatch_model_block_wise,
     get_device_and_parallelism,
     get_device_str,
     get_model_dtype,
@@ -286,7 +288,11 @@ def eval_task_by_task(
     if batch_size is None:
         batch_size = "auto:8"

-    if not isinstance(model, str):
+    if not isinstance(model, str) and parallelism:
+        from accelerate import dispatch_model, infer_auto_device_map
+
+        device_map = infer_auto_device_map(model)
+        model = dispatch_model(model, device_map=device_map)
         parallelism = False
     is_gguf_file = False
     gguf_file = None
@@ -294,6 +300,8 @@
         model, tokenizer, is_gguf_file, gguf_file = _load_gguf_model_if_needed(model, eval_model_dtype)
         if is_gguf_file:
             parallelism = False
+    if isinstance(model, torch.nn.Module):
+        dispatch_model_block_wise(model, device_map="auto")
     # As we set visible device before, so explcits
     eval_model_dtype = get_model_dtype(eval_model_dtype)
     if mllm:
diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py
index cff49b371..27a1fe746 100644
--- a/auto_round/eval/evaluation.py
+++ b/auto_round/eval/evaluation.py
@@ -16,6 +16,7 @@
 from typing import Optional, Union

 from auto_round.logger import logger
+from auto_round.utils import dispatch_model_block_wise

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -199,13 +200,13 @@ def load_gguf_model_for_eval(eval_folder, formats, args):
     return model, tokenizer


-def prepare_model_for_eval(model, device_str, eval_model_dtype):
+def prepare_model_for_eval(model, device_map, eval_model_dtype):
     """
     Prepare model for evaluation.

     Args:
         model: Quantized model
-        device_str: Device string
+        device_map: Device string
         eval_model_dtype: Evaluation data type

     Returns:
@@ -221,9 +222,7 @@ def prepare_model_for_eval(model, device_str, eval_model_dtype):
        dispatch_model(model, model.hf_device_map)

     else:
-        # Single device model
-        device_str = detect_device(device_str)
-        model = model.to(device_str)
+        dispatch_model_block_wise(model, device_map)

     # Convert dtype
     if model.dtype != eval_model_dtype and eval_model_dtype != "auto":
@@ -427,7 +426,7 @@ def run_model_evaluation(model, tokenizer, autoround, folders, formats, device_s
             model, tokenizer = load_gguf_model_for_eval(eval_folder, formats, args)
         else:
             eval_model_dtype = get_model_dtype(args.eval_model_dtype, "auto")
-            model = prepare_model_for_eval(model, device_str, eval_model_dtype)
+            model = prepare_model_for_eval(model, args.device_map, eval_model_dtype)

         # Evaluate with model instance
         evaluate_with_model_instance(model, tokenizer, device_str, args)
diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py
index a16f441bf..abbd332e7 100644
--- a/auto_round/utils/device.py
+++ b/auto_round/utils/device.py
@@ -24,6 +24,8 @@
 import cpuinfo
 import psutil
 import torch
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils import get_balanced_memory, get_max_memory

 from auto_round.logger import logger
 from auto_round.utils.model import check_to_quantized, get_block_names, get_layer_features, get_module
@@ -273,7 +275,13 @@ def is_valid_digit(s):

 def get_device_and_parallelism(device: Union[str, torch.device, int]) -> tuple[str, bool]:
     if isinstance(device, str):
-        devices = device.replace(" ", "").split(",")
+        if device in ["cuda", "xpu", "hpu"]:
+            device = detect_device(device)
+            parallelism = False
+            return device, parallelism
+        else:
+            device = re.sub("xpu:|hpu:|cuda:", "", device)
+            devices = device.replace(" ", "").split(",")
     elif isinstance(device, int):
         devices = [str(device)]
     else:
@@ -294,8 +302,14 @@ def get_device_and_parallelism(device: Union[str, torch.device, int]) -> tuple[s
     return device, parallelism


-def set_cuda_visible_devices(device):
-    devices = device.replace(" ", "").split(",")
+def set_cuda_visible_devices(device: str):
+    if device == "cuda":
+        devices = ["0"]
+    elif device == "auto":
+        return
+    else:
+        devices = device.replace(" ", "").split(",")
+        devices = [device.split(":")[-1] for device in devices]
     if all(s.isdigit() for s in devices):
         if "CUDA_VISIBLE_DEVICES" in os.environ:
             current_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -1195,6 +1209,52 @@ def find_optimal_subset(arr, target):
     return result


+def dispatch_model_block_wise(model: torch.nn.Module, device_map: str, max_mem_ratio=0.9):
+    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+        import accelerate
+
+        accelerate.hooks.remove_hook_from_submodules(model)
+    no_split_modules = getattr(model, "_no_split_modules", [])
+    devices = parse_available_devices(device_map)
+    if len(devices) == 1:
+        model.to(devices[0])
+        return model
+
+    max_memory = get_max_memory()
+    new_max_memory = {}
+    if "cpu" not in devices:
+        devices.append("cpu")
+    for device in devices:
+        if ":" in device:
+            device = int(device.split(":")[-1])
+        elif device == "cpu":
+            device = "cpu"
+        elif isinstance(device, str):
+            device = 0
+        else:
+            raise ValueError(f"Unsupported device {device} in device_map: {device_map}")
+        # Use 90% of the reported max memory to leave headroom for activations,
+        # temporary tensors, other processes, and allocator fragmentation, reducing
+        # the chance of runtime OOM while still utilizing most available memory.
+        new_max_memory[device] = max_memory[device] * max_mem_ratio
+    new_max_memory = get_balanced_memory(
+        model,
+        max_memory=new_max_memory,
+        no_split_module_classes=no_split_modules,
+    )
+    model.tie_weights()
+    device_map = infer_auto_device_map(model, max_memory=new_max_memory, no_split_module_classes=no_split_modules)
+    if len(devices) > 1 and "cpu" in device_map.values():
+        logger.warning(
+            "Some layers are offloaded to cpu, which may severely impact calibration speed."
+            " Please consider using more cards."
+        )
+
+    model = dispatch_model(model, device_map=device_map)
+
+    return model
+
+
 def set_avg_auto_device_map(model: torch.nn.Module, device_map):
     block_name_list = get_block_names(model)
     device_list = parse_available_devices(device_map)
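A minimal usage sketch of the new dispatch_model_block_wise helper follows; it is not part of the patch. The checkpoint name is a placeholder, and the generation call only demonstrates that a dispatched model still runs end to end.

# Sketch only: assumes this patch is applied; the model name below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round.utils import dispatch_model_block_wise

model_name = "facebook/opt-125m"  # placeholder causal LM for illustration
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# "auto" lets the helper split blocks over every detected device, keeping ~10%
# memory headroom per card (max_mem_ratio=0.9) and warning if layers spill to CPU.
model = dispatch_model_block_wise(model, device_map="auto")

inputs = tokenizer("AutoRound is", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))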