diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 5877bd389..48aaf1bd6 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -486,8 +486,6 @@ def __init__( logger.info(f"using {self.model.dtype} for quantization tuning") # Some helpers - if "hpu" in str(self.device): - self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") self.batch_dim = None self.infer_bs_coeff = 1 diff --git a/auto_round/modeling/__init__.py b/auto_round/modeling/__init__.py index d1bc25269..144e78ca3 100644 --- a/auto_round/modeling/__init__.py +++ b/auto_round/modeling/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .fp8_quant import * +from .hpu_patch import * diff --git a/auto_round/modeling/finegrained_fp8_patch.py b/auto_round/modeling/finegrained_fp8_patch.py new file mode 100644 index 000000000..06d85835b --- /dev/null +++ b/auto_round/modeling/finegrained_fp8_patch.py @@ -0,0 +1,241 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py +from transformers.core_model_loading import ConversionOps +from transformers.quantizers.quantizers_utils import should_convert_module +from transformers.utils import is_kernels_available, is_torch_accelerator_available, is_torch_available, logging + +if is_torch_available(): + import torch + import torch.nn as nn + + # import triton + # import triton.language as tl + from torch.nn import functional as F + + +logger = logging.get_logger(__name__) + + +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MIN = torch.finfo(_FP8_DTYPE).min +_FP8_MAX = torch.finfo(_FP8_DTYPE).max + + +class FP8Linear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + dtype=torch.float8_e4m3fn, + block_size: tuple[int, int] | None = None, + activation_scheme="dynamic", + ): + super().__init__(in_features, out_features) + + # If block size is None, it means that we are doing per-tensor quantization + self.block_size = block_size + self.activation_scheme = activation_scheme + + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + + if self.block_size is None: + self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) + else: + scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0] + scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1] + self.weight_scale_inv = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) + ) + + if self.activation_scheme == "static": + self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) + + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features)) + else: + self.register_parameter("bias", None) + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +def replace_with_fp8_linear( + model, 
modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
+):
+    """
+    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
+            Names of the modules to not convert. In practice we keep the `lm_head`
+            in full precision for numerical stability reasons.
+        quantization_config (`FineGrainedFP8Config`):
+            The quantization config object that contains the quantization parameters.
+        pre_quantized (`bool`, defaults to `False`):
+            Whether the model is pre-quantized or not.
+    """
+
+    if quantization_config.dequantize:
+        return model
+
+    has_been_replaced = False
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        # we need this to correctly materialize the weights during quantization
+        module_kwargs = {} if pre_quantized else {"dtype": None}
+        new_module = None
+        with torch.device("meta"):
+            if isinstance(module, nn.Linear):
+                new_module = FP8Linear(
+                    in_features=module.in_features,
+                    out_features=module.out_features,
+                    bias=module.bias is not None,
+                    activation_scheme=quantization_config.activation_scheme,
+                    block_size=quantization_config.weight_block_size,
+                    **module_kwargs,
+                )
+        if new_module is not None:
+            model.set_submodule(module_name, new_module)
+            has_been_replaced = True
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model using fp8 but no linear modules were found in your model."
+            " Please double check your model architecture."
+        )
+    return model
+
+
+class Fp8Quantize(ConversionOps):
+    """
+    A quantization operation that creates two tensors, a weight and a scale, out of a single weight.
+    """
+
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(self, input_dict: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+        # Unpack single key/value (value may be wrapped in a list)
+        target_keys, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        # Resolve block size (support dict-like or attr-like quant_config)
+        block_size = None
+        if self.hf_quantizer.quantization_config is not None:
+            if isinstance(self.hf_quantizer.quantization_config, dict):
+                block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
+            else:
+                block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
+        if block_size is None:
+            block_size = (value.shape[-2], value.shape[-1])
+
+        block_m, block_n = block_size
+        rows, cols = value.shape[-2], value.shape[-1]
+
+        # Enforce exact tiling of the weight matrix
+        if rows % block_m != 0 or cols % block_n != 0:
+            raise ValueError(
+                (
+                    f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes"
+                    f" ({block_m}, {block_n}) for {target_keys}."
+                )
+            )
+
+        # Leading dims can be empty (2D) or include num_experts/... 
(3D+)
+        leading_shape = value.shape[:-2]
+        rows_tiles = rows // block_m
+        cols_tiles = cols // block_n
+
+        original_shape = value.shape
+        value_fp32 = value.to(torch.float32)
+
+        # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
+        reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
+
+        # Per-tile max-abs over the block dims
+        # dims: block_m is at -3, block_n is at -1 after the reshape
+        max_abs = reshaped.abs().amax(dim=(-3, -1))
+        safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
+
+        # Tile scale (we store the inverse scale, matching FP8Linear.weight_scale_inv)
+        scales = _FP8_MAX / safe_max_abs
+        scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable
+
+        # Broadcast scales back over the block dims and quantize
+        # max_abs/scales shape: (..., rows_tiles, cols_tiles)
+        scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1)
+        scaled = reshaped * scales_broadcast
+
+        quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
+
+        quantized = quantized.reshape(original_shape)
+
+        inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles)
+        if target_keys.endswith("weight"):
+            scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
+        else:
+            scale_key = target_keys + "_scale_inv"
+
+        # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
+        return {
+            target_keys: quantized,
+            scale_key: inv_scales,
+        }
+
+
+class Fp8Dequantize(ConversionOps):
+    """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
+
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        full_layer_name: str | None = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        if len(input_dict) < 2:
+            # Case where we only received the weight; it is stored under the "weight$" key
+            return {full_layer_name: input_dict["weight$"]}
+
+        quantized = input_dict["weight$"][0]
+        scales = input_dict["weight_scale_inv"][0]
+
+        rows, cols = quantized.shape[-2:]
+        block_size = self.hf_quantizer.quantization_config.weight_block_size
+        if block_size is None:
+            block_size = (quantized.shape[-2], quantized.shape[-1])
+
+        block_m, block_n = block_size
+
+        if rows % block_m != 0 or cols % block_n != 0:
+            raise ValueError(
+                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})." 
+            )
+        quantized = quantized.to(scales.dtype)
+        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
+        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
+        expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
+        dequantized = reshaped * expanded_scales
+
+        return {
+            full_layer_name: dequantized.reshape(quantized.shape),
+        }
diff --git a/auto_round/modeling/hpu_patch.py b/auto_round/modeling/hpu_patch.py
new file mode 100644
index 000000000..521caec4b
--- /dev/null
+++ b/auto_round/modeling/hpu_patch.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2026 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from auto_round.logger import logger
+
+
+def patch_finegrained_fp8():
+    """Use importlib to replace transformers.integrations.finegrained_fp8 with auto-round's HPU-compatible version."""
+    try:
+        from auto_round.utils import is_hpex_available
+
+        if not is_hpex_available():
+            return  # No patching needed on non-HPU devices
+
+        import importlib
+        import sys
+
+        # Import auto-round's HPU-compatible finegrained_fp8_patch module
+        finegrained_fp8_patch = importlib.import_module("auto_round.modeling.finegrained_fp8_patch")
+
+        # Replace transformers.integrations.finegrained_fp8 in sys.modules
+        sys.modules["transformers.integrations.finegrained_fp8"] = finegrained_fp8_patch
+
+        logger.info(
+            "✓ Replaced transformers.integrations.finegrained_fp8 with auto_round.modeling.finegrained_fp8_patch"
+        )
+
+    except Exception as e:
+        logger.warning(f"Failed to patch finegrained_fp8: {e}")
+
+
+# Apply patch on import if HPU is available
+patch_finegrained_fp8()
diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py
index a16f441bf..e9a11923d 100644
--- a/auto_round/utils/device.py
+++ b/auto_round/utils/device.py
@@ -15,6 +15,7 @@
 import gc
 import os
 import re
+import sys
 from contextlib import ContextDecorator, contextmanager
 from functools import lru_cache
 from itertools import combinations
@@ -339,6 +340,71 @@ def __exit__(self, exc_type, exc, exc_tb):
         return False
 
 
+class fake_cuda_for_hpu(ContextDecorator):
+    """Context manager/decorator to fake CUDA availability for HPU devices."""
+
+    def __init__(self):
+        self._orig_is_available = None
+
+    def __enter__(self):
+        if is_hpex_available():
+            self._orig_is_available = torch.cuda.is_available
+            torch.cuda.is_available = lambda: True
+        return self
+
+    def __exit__(self, exc_type, exc, exc_tb):
+        if is_hpex_available() and hasattr(self, "_orig_is_available"):
+            torch.cuda.is_available = self._orig_is_available
+            del self._orig_is_available
+        return False
+
+
+class fake_triton_for_hpu(ContextDecorator):
+    """Context manager/decorator to fake triton availability for HPU devices."""
+
+    def __init__(self):
+        self._orig_triton = None
+        self._orig_triton_language = None
+        self._had_triton = False
+        self._had_triton_language = False
+
+    def __enter__(self):
+        if is_hpex_available():
+            # Save original state
+            self._had_triton = "triton" in sys.modules
+            self._had_triton_language = "triton.language" in sys.modules
+
+            if self._had_triton:
+                self._orig_triton = sys.modules["triton"]
+            if self._had_triton_language:
+                self._orig_triton_language = sys.modules["triton.language"]
+
+            # Create and inject fake triton module
+            class FakeTriton:
+                def __getattr__(self, name):
+                    return None
+
+            fake_triton = FakeTriton()
+            fake_triton.jit = lambda func: func  # Make triton.jit a no-op decorator
+            sys.modules["triton"] = fake_triton
+            sys.modules["triton.language"] = FakeTriton()
+        
return self + + def __exit__(self, exc_type, exc, exc_tb): + if is_hpex_available(): + # Restore original state + if self._had_triton and self._orig_triton is not None: + sys.modules["triton"] = self._orig_triton + elif not self._had_triton and "triton" in sys.modules: + del sys.modules["triton"] + + if self._had_triton_language and self._orig_triton_language is not None: + sys.modules["triton.language"] = self._orig_triton_language + elif not self._had_triton_language and "triton.language" in sys.modules: + del sys.modules["triton.language"] + return False + + def get_packing_device(device: str | torch.device | None = "auto") -> torch.device: """ Selects the packing device. diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 2536c0a93..1442512c9 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -272,7 +272,10 @@ def llm_load_model( from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from auto_round.utils.device import ( _use_hpu_compile_mode, + fake_cuda_for_hpu, + fake_triton_for_hpu, get_device_and_parallelism, + is_hpex_available, override_cuda_device_capability, ) @@ -289,13 +292,15 @@ def llm_load_model( if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code: logger.warning("trust_remote_code is enabled by default, please ensure its correctness.") - if _use_hpu_compile_mode(): - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - ) + if is_hpex_available(): + # For loading FP8 model on HPU + with fake_cuda_for_hpu(), fake_triton_for_hpu(), override_cuda_device_capability(): + model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + device_map="auto" if use_auto_mapping else None, + ) else: try: model = model_cls.from_pretrained( diff --git a/test/test_hpu/test_quant_fp8.py b/test/test_hpu/test_quant_fp8.py new file mode 100644 index 000000000..eaaa61741 --- /dev/null +++ b/test/test_hpu/test_quant_fp8.py @@ -0,0 +1,35 @@ +import os +import shutil + +import pytest +import torch + +from auto_round import AutoRound + +MODEL_LIST = ( + "Qwen/Qwen3-0.6B-FP8", + "Qwen/Qwen3-0.6B", +) + + +class TestAutoRound: + save_dir = "./saved" + + def check_nan_inf_in_tensor(self, tensor, name=""): + return torch.isnan(tensor).any() or torch.isinf(tensor).any() + + @pytest.mark.parametrize("model_name", MODEL_LIST) + def test_small_model_rtn_generation(self, model_name): + ar = AutoRound(model_name, iters=0, scheme="FP8_STATIC", nsamples=16) + model, folder = ar.quantize_and_save(output_dir=self.save_dir, format="llm_compressor") + # all linears except lm_head should be quantized to FP8 + fp8_linear_count = 0 + for name, module in model.named_modules(): + if "FP8QLinear" in type(module).__name__: + assert module.weight.dtype == torch.float8_e4m3fn, f"{name} is not in FP8" + assert not self.check_nan_inf_in_tensor( + module.weight.to(torch.float32) + ), f"{name} has NaN or Inf in weights" + fp8_linear_count += 1 + assert fp8_linear_count > 0, "No FP8 linear layer found in the quantized model" + shutil.rmtree(self.save_dir, ignore_errors=True)
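
A minimal standalone sketch (not part of the patch; the function names and the 128x128 default block size are illustrative only) of the block-wise FP8 round trip that Fp8Quantize and Fp8Dequantize implement, assuming a 2D weight whose dimensions divide evenly by the block size:

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MIN = torch.finfo(FP8_DTYPE).min
FP8_MAX = torch.finfo(FP8_DTYPE).max


def quantize_block_fp8(weight: torch.Tensor, block: tuple[int, int] = (128, 128)):
    """Quantize a 2D weight tile-by-tile; returns (fp8 weight, weight_scale_inv)."""
    bm, bn = block
    rows, cols = weight.shape
    # View the weight as a (rows//bm, bm, cols//bn, bn) grid of tiles
    tiles = weight.float().reshape(rows // bm, bm, cols // bn, bn)
    amax = tiles.abs().amax(dim=(-3, -1))  # per-tile max |w|
    safe_amax = torch.where(amax > 0, amax, torch.ones_like(amax))
    scale = FP8_MAX / safe_amax  # shape (rows//bm, cols//bn)
    q = (tiles * scale[:, None, :, None]).clamp(FP8_MIN, FP8_MAX).to(FP8_DTYPE)
    return q.reshape(rows, cols), 1.0 / scale  # store the inverse scale


def dequantize_block_fp8(qweight, scale_inv, block=(128, 128)):
    """Inverse of quantize_block_fp8: rescale each tile back to float32."""
    bm, bn = block
    rows, cols = qweight.shape
    tiles = qweight.float().reshape(rows // bm, bm, cols // bn, bn)
    return (tiles * scale_inv[:, None, :, None]).reshape(rows, cols)


w = torch.randn(256, 512)
qw, s_inv = quantize_block_fp8(w)
err = (w - dequantize_block_fp8(qw, s_inv)).abs().max()
print(qw.dtype, s_inv.shape, err)  # torch.float8_e4m3fn, torch.Size([2, 4]), small reconstruction error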