Support load FP8 model on HPU #1449
One line added to an existing package module (file path not preserved in the extracted diff):

```diff
@@ -13,3 +13,4 @@
 # limitations under the License.

 from .fp8_quant import *
+from .hpu_patch import *
```
New file (241 added lines): an HPU-compatible copy of transformers' `finegrained_fp8` integration, which the patch module below loads as `auto_round.modeling.finegrained_fp8_patch`:

```python
# Copyright (c) 2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py
from transformers.core_model_loading import ConversionOps
from transformers.quantizers.quantizers_utils import should_convert_module
from transformers.utils import is_kernels_available, is_torch_accelerator_available, is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn

    # import triton
    # import triton.language as tl
    from torch.nn import functional as F


logger = logging.get_logger(__name__)


_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max


class FP8Linear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype=torch.float8_e4m3fn,
        block_size: tuple[int, int] | None = None,
        activation_scheme="dynamic",
    ):
        super().__init__(in_features, out_features)

        # If block size is None, it means that we are doing per-tensor quantization
        self.block_size = block_size
        self.activation_scheme = activation_scheme

        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))

        if self.block_size is None:
            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        else:
            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
            self.weight_scale_inv = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
            )

        if self.activation_scheme == "static":
            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)


def _ceil_div(a, b):
    return (a + b - 1) // b


def replace_with_fp8_linear(
    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            Names of the modules to not convert. In practice we keep the `lm_head`
            in full precision for numerical stability reasons.
        quantization_config (`FbgemmFp8Config`):
            The quantization config object that contains the quantization parameters.
        pre_quantized (`bool`, defaults to `False`):
            Whether the model is pre-quantized or not.
    """

    if quantization_config.dequantize:
        return model

    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        # we need this to correctly materialize the weights during quantization
        module_kwargs = {} if pre_quantized else {"dtype": None}
        new_module = None
        with torch.device("meta"):
            if isinstance(module, nn.Linear):
                new_module = FP8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    activation_scheme=quantization_config.activation_scheme,
                    block_size=quantization_config.weight_block_size,
                    **module_kwargs,
                )
        if new_module is not None:
            model.set_submodule(module_name, new_module)
            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )
    return model


class Fp8Quantize(ConversionOps):
    """
    A quantization operation that creates two tensors, weight and scale, out of a weight.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
        # Unpack single key/value (value may be wrapped in a list)
        target_keys, value = tuple(input_dict.items())[0]
        value = value[0]

        # Resolve block size (support dict-like or attr-like quant_config)
        block_size = None
        if self.hf_quantizer.quantization_config is not None:
            if isinstance(self.hf_quantizer.quantization_config, dict):
                block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
            else:
                block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
        if block_size is None:
            block_size = (value.shape[-2], value.shape[-1])

        block_m, block_n = block_size
        rows, cols = value.shape[-2], value.shape[-1]

        # Enforce exact tiling
        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes"
                f" ({block_m}, {block_n}) for {target_keys}."
            )

        # Leading dims can be empty (2D) or include num_experts/... (3D+)
        leading_shape = value.shape[:-2]
        rows_tiles = rows // block_m
        cols_tiles = cols // block_n

        original_shape = value.shape
        value_fp32 = value.to(torch.float32)

        # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
        reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)

        # Per-tile max-abs over the block dims
        # dims: block_m is at -3, block_n is at -1 after the reshape
        max_abs = reshaped.abs().amax(dim=(-3, -1))
        safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

        # Tile scale (the inverse scale is stored, matching FP8Linear's weight_scale_inv)
        scales = _FP8_MAX / safe_max_abs
        scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable

        # Broadcast scales back over the block dims and quantize
        # max_abs/scales shape: (..., rows_tiles, cols_tiles)
        scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1)
        scaled = reshaped * scales_broadcast

        quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)

        quantized = quantized.reshape(original_shape)

        inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles)
        if target_keys.endswith("weight"):
            scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
        else:
            scale_key = target_keys + "_scale_inv"

        # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
        return {
            target_keys: quantized,
            scale_key: inv_scales,
        }


class Fp8Dequantize(ConversionOps):
    """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        full_layer_name: str | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        if len(input_dict) < 2:
            # case where we only got weights, need to check for "weight$"
            return {full_layer_name: input_dict["weight$"]}

        quantized = input_dict["weight$"][0]
        scales = input_dict["weight_scale_inv"][0]

        rows, cols = quantized.shape[-2:]
        block_size = self.hf_quantizer.quantization_config.weight_block_size
        if block_size is None:
            block_size = (quantized.shape[-2], quantized.shape[-1])

        block_m, block_n = block_size

        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
            )
        quantized = quantized.to(scales.dtype)
        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
        expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
        dequantized = reshaped * expanded_scales

        return {
            full_layer_name: dequantized.reshape(quantized.shape),
        }
```
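To make the block-wise scheme concrete, here is a minimal, self-contained round-trip sketch (not part of the PR) that applies the same math as `Fp8Quantize`/`Fp8Dequantize` to a toy 4x4 weight with a hypothetical (2, 2) block size; the shapes and the printed tolerance are illustrative only.

```python
import torch

_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MAX = torch.finfo(_FP8_DTYPE).max

weight = torch.randn(4, 4)
block_m, block_n = 2, 2  # hypothetical block size; real checkpoints typically use (128, 128)

# Split into (rows_tiles, block_m, cols_tiles, block_n) tiles, as in Fp8Quantize.convert
tiles = weight.reshape(4 // block_m, block_m, 4 // block_n, block_n)

# One scale per tile: map the tile's max-abs value onto the FP8 range
max_abs = tiles.abs().amax(dim=(-3, -1))  # shape (rows_tiles, cols_tiles) = (2, 2)
scales = _FP8_MAX / max_abs
quantized = torch.clamp(tiles * scales.unsqueeze(-1).unsqueeze(-3), -_FP8_MAX, _FP8_MAX).to(_FP8_DTYPE)

# The checkpoint stores the inverse scale (weight_scale_inv); dequantize with it
weight_scale_inv = 1.0 / scales
dequantized = quantized.to(torch.float32) * weight_scale_inv.unsqueeze(-1).unsqueeze(-3)

max_err = (dequantized.reshape(4, 4) - weight).abs().max()
print(f"max reconstruction error: {max_err:.4f}")  # small relative to |weight|, as expected for e4m3
```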
New file (35 added lines): the `hpu_patch` module that swaps in the HPU-compatible integration at import time:

```python
# Copyright (C) 2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from auto_round.logger import logger


def patch_finegrained_fp8():
    """Use importlib to replace transformers.integrations.finegrained_fp8 with auto-round's HPU-compatible version."""
    try:
        from auto_round.utils import is_hpex_available

        if not is_hpex_available():
            return  # No patching needed on non-HPU devices

        import importlib
        import sys

        # Import auto-round's HPU-compatible finegrained_fp8_patch module
        finegrained_fp8_patch = importlib.import_module("auto_round.modeling.finegrained_fp8_patch")

        # Replace transformers.integrations.finegrained_fp8 in sys.modules
        sys.modules["transformers.integrations.finegrained_fp8"] = finegrained_fp8_patch

        logger.info(
            "✓ Replaced transformers.integrations.finegrained_fp8 with auto_round.modeling.finegrained_fp8_patch"
        )

    except Exception as e:
        logger.warning(f"Failed to patch finegrained_fp8: {e}")


# Apply patch on import if HPU is available
patch_finegrained_fp8()
```
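A hedged usage sketch of the intended flow (not from the PR): once the package `__init__` change in the first hunk is imported, `patch_finegrained_fp8()` has already swapped the transformers integration, so a later FP8 checkpoint load goes through the HPU-compatible `FP8Linear`. The model path and the exact import entry point below are assumptions.

```python
import auto_round  # assumption: pulls in hpu_patch via the __init__ change above, applying the patch on HPU

from transformers import AutoModelForCausalLM

# Placeholder checkpoint path: any model with fine-grained (block-wise) FP8 weights
model = AutoModelForCausalLM.from_pretrained("path/to/fp8-quantized-model")
# Device placement and dtype arguments are elided; they follow the usual transformers workflow.
```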
Modified utilities module (two hunks): adds `import sys` and two HPU compatibility helpers before `get_packing_device`:

```diff
@@ -15,6 +15,7 @@
 import gc
 import os
 import re
+import sys
 from contextlib import ContextDecorator, contextmanager
 from functools import lru_cache
 from itertools import combinations
@@ -339,6 +340,71 @@ def __exit__(self, exc_type, exc, exc_tb):
         return False


+class fake_cuda_for_hpu(ContextDecorator):
+    """Context manager/decorator to fake CUDA availability for HPU devices."""
+
+    def __init__(self):
+        self._orig_is_available = None
+
+    def __enter__(self):
+        if is_hpex_available():
+            self._orig_is_available = torch.cuda.is_available
+            torch.cuda.is_available = lambda: True
+        return self
+
+    def __exit__(self, exc_type, exc, exc_tb):
+        if is_hpex_available() and hasattr(self, "_orig_is_available"):
+            torch.cuda.is_available = self._orig_is_available
+            del self._orig_is_available
+        return False
+
+
+class fake_triton_for_hpu(ContextDecorator):
+    """Context manager/decorator to fake triton availability for HPU devices."""
+
+    def __init__(self):
+        self._orig_triton = None
+        self._orig_triton_language = None
+        self._had_triton = False
+        self._had_triton_language = False
+
+    def __enter__(self):
+        if is_hpex_available():
+            # Save original state
+            self._had_triton = "triton" in sys.modules
+            self._had_triton_language = "triton.language" in sys.modules
+
+            if self._had_triton:
+                self._orig_triton = sys.modules["triton"]
+            if self._had_triton_language:
+                self._orig_triton_language = sys.modules["triton.language"]
+
+            # Create and inject fake triton module
+            class FakeTriton:
+                def __getattr__(self, name):
+                    return None
+
+            fake_triton = FakeTriton()
+            fake_triton.jit = lambda func: func  # Make triton.jit a no-op decorator
+            sys.modules["triton"] = fake_triton
+            sys.modules["triton.language"] = FakeTriton()
+        return self
+
+    def __exit__(self, exc_type, exc, exc_tb):
+        if is_hpex_available():
+            # Restore original state
+            if self._had_triton and self._orig_triton is not None:
+                sys.modules["triton"] = self._orig_triton
+            elif not self._had_triton and "triton" in sys.modules:
+                del sys.modules["triton"]
+
+            if self._had_triton_language and self._orig_triton_language is not None:
+                sys.modules["triton.language"] = self._orig_triton_language
+            elif not self._had_triton_language and "triton.language" in sys.modules:
+                del sys.modules["triton.language"]
+        return False
+
+
 def get_packing_device(device: str | torch.device | None = "auto") -> torch.device:
     """
     Selects the packing device.
```
Review comment (on the `fake_cuda_for_hpu` class): This is a class but is named like a function (lower_snake_case). For clarity and consistency, consider either renaming it to a CapWords class name (e.g., `FakeCudaForHpu`) or converting it into a `@contextmanager` function named `fake_cuda_for_hpu`.
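For reference, a sketch of the reviewer's `@contextmanager` alternative (illustrative only, not part of the PR); it assumes the same `is_hpex_available` helper and keeps the existing function-style name.

```python
from contextlib import contextmanager

import torch

from auto_round.utils import is_hpex_available  # assumption: inside utils.py itself this is already in scope


@contextmanager
def fake_cuda_for_hpu():
    """Temporarily report CUDA as available while running on HPU; no-op elsewhere."""
    if not is_hpex_available():
        yield
        return
    orig_is_available = torch.cuda.is_available
    torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        # Always restore the real check, even if the wrapped code raises
        torch.cuda.is_available = orig_is_available
```

Because objects produced by `contextmanager` also work as decorators, this variant preserves the decorator usage that `ContextDecorator` provided in the class version.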