2 changes: 0 additions & 2 deletions auto_round/compressors/base.py
@@ -486,8 +486,6 @@ def __init__(
        logger.info(f"using {self.model.dtype} for quantization tuning")

        # Some helpers
-       if "hpu" in str(self.device):
-           self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear")
        self.batch_dim = None
        self.infer_bs_coeff = 1
1 change: 1 addition & 0 deletions auto_round/modeling/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from .fp8_quant import *
from .hpu_patch import *
241 changes: 241 additions & 0 deletions auto_round/modeling/finegrained_fp8_patch.py
@@ -0,0 +1,241 @@
# Copyright (c) 2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py
from transformers.core_model_loading import ConversionOps
from transformers.quantizers.quantizers_utils import should_convert_module
from transformers.utils import is_kernels_available, is_torch_accelerator_available, is_torch_available, logging

if is_torch_available():
    import torch
    import torch.nn as nn

    # import triton
    # import triton.language as tl
    from torch.nn import functional as F


logger = logging.get_logger(__name__)


_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max


class FP8Linear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype=torch.float8_e4m3fn,
        block_size: tuple[int, int] | None = None,
        activation_scheme="dynamic",
    ):
        super().__init__(in_features, out_features)

        # If block size is None, it means that we are doing per-tensor quantization
        self.block_size = block_size
        self.activation_scheme = activation_scheme

        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))

        if self.block_size is None:
            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        else:
            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
            self.weight_scale_inv = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
            )

        if self.activation_scheme == "static":
            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)


def _ceil_div(a, b):
    return (a + b - 1) // b


def replace_with_fp8_linear(
    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            Names of the modules to not convert. In practice we keep the `lm_head`
            in full precision for numerical stability reasons.
        quantization_config (`FbgemmFp8Config`):
            The quantization config object that contains the quantization parameters.
        pre_quantized (`bool`, defaults to `False`):
            Whether the model is pre-quantized or not.
    """

    if quantization_config.dequantize:
        return model

    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        # we need this to correctly materialize the weights during quantization
        module_kwargs = {} if pre_quantized else {"dtype": None}
        new_module = None
        with torch.device("meta"):
            if isinstance(module, nn.Linear):
                new_module = FP8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    activation_scheme=quantization_config.activation_scheme,
                    block_size=quantization_config.weight_block_size,
                    **module_kwargs,
                )
        if new_module is not None:
            model.set_submodule(module_name, new_module)
            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )
    return model


class Fp8Quantize(ConversionOps):
    """
    A quantization operation that creates two tensors, weight and scale, out of a weight.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
        # Unpack single key/value (value may be wrapped in a list)
        target_keys, value = tuple(input_dict.items())[0]
        value = value[0]

        # Resolve block size (support dict-like or attr-like quant_config)
        block_size = None
        if self.hf_quantizer.quantization_config is not None:
            if isinstance(self.hf_quantizer.quantization_config, dict):
                block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
            else:
                block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
        if block_size is None:
            block_size = (value.shape[-2], value.shape[-1])

        block_m, block_n = block_size
        rows, cols = value.shape[-2], value.shape[-1]

        # Enforce exact tiling: the weight must split evenly into (block_m, block_n) blocks
        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes"
                f" ({block_m}, {block_n}) for {target_keys}"
            )

        # Leading dims can be empty (2D) or include num_experts/... (3D+)
        leading_shape = value.shape[:-2]
        rows_tiles = rows // block_m
        cols_tiles = cols // block_n

        original_shape = value.shape
        value_fp32 = value.to(torch.float32)

        # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
        reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)

        # Per-tile max-abs over the block dims
        # dims: block_m is at -3, block_n is at -1 after the reshape
        max_abs = reshaped.abs().amax(dim=(-3, -1))
        safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

        # Tile scale (we store the inverse scale, matching FP8Linear's `weight_scale_inv`)
        scales = _FP8_MAX / safe_max_abs
        scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable

        # Broadcast scales back over the block dims and quantize
        # max_abs/scales shape: (..., rows_tiles, cols_tiles)
        scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1)
        scaled = reshaped * scales_broadcast

        quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)

        quantized = quantized.reshape(original_shape)

        inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles)
        if target_keys.endswith("weight"):
            scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
        else:
            scale_key = target_keys + "_scale_inv"

        # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
        return {
            target_keys: quantized,
            scale_key: inv_scales,
        }


class Fp8Dequantize(ConversionOps):
    """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        full_layer_name: str | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        if len(input_dict) < 2:
            # case where we only got weights, need to check for "weight$"
            return {full_layer_name: input_dict["weight$"]}

        quantized = input_dict["weight$"][0]
        scales = input_dict["weight_scale_inv"][0]

        rows, cols = quantized.shape[-2:]
        block_size = self.hf_quantizer.quantization_config.weight_block_size
        if block_size is None:
            block_size = (quantized.shape[-2], quantized.shape[-1])

        block_m, block_n = block_size

        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
            )
        quantized = quantized.to(scales.dtype)
        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
        expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
        dequantized = reshaped * expanded_scales

        return {
            full_layer_name: dequantized.reshape(quantized.shape),
        }
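
As a hedged aside (not part of the diff): the block-wise scheme that Fp8Quantize and Fp8Dequantize implement above can be checked in isolation with a small round trip. This is a minimal sketch with illustrative tensor names, assuming a PyTorch build that supports torch.float8_e4m3fn.

import torch

_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max

# One (128, 128) tile: scale by FP8_MAX / max|w|, clamp, cast to fp8,
# then reconstruct with the stored inverse scale (weight_scale_inv).
block = torch.randn(128, 128)
scale = _FP8_MAX / block.abs().amax()
quantized = torch.clamp(block * scale, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
weight_scale_inv = (1.0 / scale).to(torch.float32)      # what Fp8Quantize stores per tile
reconstructed = quantized.to(torch.float32) * weight_scale_inv
print(torch.allclose(block, reconstructed, atol=1e-1))  # approximate round trip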
35 changes: 35 additions & 0 deletions auto_round/modeling/hpu_patch.py
@@ -0,0 +1,35 @@
# Copyright (C) 2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from auto_round.logger import logger


def patch_finegrained_fp8():
    """Use importlib to replace transformers.integrations.finegrained_fp8 with auto-round's HPU-compatible version."""
    try:
        from auto_round.utils import is_hpex_available

        if not is_hpex_available():
            return  # No patching needed on non-HPU devices

        import importlib
        import sys

        # Import auto-round's HPU-compatible finegrained_fp8_patch module
        finegrained_fp8_patch = importlib.import_module("auto_round.modeling.finegrained_fp8_patch")

        # Replace transformers.integrations.finegrained_fp8 in sys.modules
        sys.modules["transformers.integrations.finegrained_fp8"] = finegrained_fp8_patch

        logger.info(
            "✓ Replaced transformers.integrations.finegrained_fp8 with auto_round.modeling.finegrained_fp8_patch"
        )

    except Exception as e:
        logger.warning(f"Failed to patch finegrained_fp8: {e}")


# Apply patch on import if HPU is available
patch_finegrained_fp8()
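
A hedged usage sketch (not part of the diff), based on the __init__.py change above: importing auto_round.modeling applies the patch as a side effect, after which the transformers import path resolves to the patched, triton-free module. The patch only takes effect where is_hpex_available() returns True.

# Illustrative only; effective only on HPU machines.
import auto_round.modeling  # runs patch_finegrained_fp8() via `from .hpu_patch import *`
from transformers.integrations.finegrained_fp8 import FP8Linear  # now the patched FP8Linear above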
66 changes: 66 additions & 0 deletions auto_round/utils/device.py
@@ -15,6 +15,7 @@
import gc
import os
import re
import sys
from contextlib import ContextDecorator, contextmanager
from functools import lru_cache
from itertools import combinations
@@ -339,6 +340,71 @@ def __exit__(self, exc_type, exc, exc_tb):
        return False


Copilot AI commented (Feb 12, 2026):
This is a class but is named like a function (lower_snake_case). For clarity and consistency, consider either renaming it to a CapWords class name (e.g., FakeCudaForHpu) or converting it into a @contextmanager function named fake_cuda_for_hpu.
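
A minimal sketch of that suggestion (the reviewer's alternative, not the PR's code), reusing the is_hpex_available helper that this module already imports:

from contextlib import contextmanager

import torch

from auto_round.utils import is_hpex_available


@contextmanager
def fake_cuda_for_hpu():
    """Temporarily report CUDA as available while running on HPU."""
    if not is_hpex_available():
        yield
        return
    orig_is_available = torch.cuda.is_available
    torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        torch.cuda.is_available = orig_is_available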
class fake_cuda_for_hpu(ContextDecorator):
    """Context manager/decorator to fake CUDA availability for HPU devices."""

    def __init__(self):
        self._orig_is_available = None

    def __enter__(self):
        if is_hpex_available():
            self._orig_is_available = torch.cuda.is_available
Copilot AI commented on lines +350 to +351 (Feb 12, 2026):
This mutates a global function (torch.cuda.is_available) process-wide, which can cause surprising behavior if other threads/tasks call CUDA checks while this context is active. If possible, prefer a safer patching approach (e.g., unittest.mock.patch scoped to the smallest block) and keep the patched window as short as possible.
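
A small sketch of the narrower-scope alternative the comment describes (illustrative only; run_with_fake_cuda and fn are hypothetical names), using unittest.mock.patch so the monkey-patch is confined to a single call and always restored:

from unittest import mock


def run_with_fake_cuda(fn, *args, **kwargs):
    # Patch torch.cuda.is_available only for the duration of this call.
    with mock.patch("torch.cuda.is_available", return_value=True):
        return fn(*args, **kwargs)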
            torch.cuda.is_available = lambda: True
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        if is_hpex_available() and hasattr(self, "_orig_is_available"):
            torch.cuda.is_available = self._orig_is_available
            del self._orig_is_available
        return False


class fake_triton_for_hpu(ContextDecorator):
    """Context manager/decorator to fake triton availability for HPU devices."""

    def __init__(self):
        self._orig_triton = None
        self._orig_triton_language = None
        self._had_triton = False
        self._had_triton_language = False

    def __enter__(self):
        if is_hpex_available():
            # Save original state
            self._had_triton = "triton" in sys.modules
            self._had_triton_language = "triton.language" in sys.modules

            if self._had_triton:
                self._orig_triton = sys.modules["triton"]
            if self._had_triton_language:
                self._orig_triton_language = sys.modules["triton.language"]

            # Create and inject fake triton module
            class FakeTriton:
                def __getattr__(self, name):
                    return None

            fake_triton = FakeTriton()
            fake_triton.jit = lambda func: func  # Make triton.jit a no-op decorator
            sys.modules["triton"] = fake_triton
            sys.modules["triton.language"] = FakeTriton()
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        if is_hpex_available():
            # Restore original state
            if self._had_triton and self._orig_triton is not None:
                sys.modules["triton"] = self._orig_triton
            elif not self._had_triton and "triton" in sys.modules:
                del sys.modules["triton"]

            if self._had_triton_language and self._orig_triton_language is not None:
                sys.modules["triton.language"] = self._orig_triton_language
            elif not self._had_triton_language and "triton.language" in sys.modules:
                del sys.modules["triton.language"]
        return False


def get_packing_device(device: str | torch.device | None = "auto") -> torch.device:
    """
    Selects the packing device.