diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 33ca1ccd7..069528e5e 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -87,6 +87,7 @@ def __new__( enable_alg_ext: bool = None, disable_opt_rtn: bool | None = None, low_cpu_mem_usage: bool = True, + transform_config: dict = {}, **kwargs, ) -> BaseCompressor: """Initialize AutoRound with quantization and tuning configuration. @@ -114,6 +115,7 @@ def __new__( disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization with lower accuracy. Defaults to None. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to True. + transform_config (dict, optional): Transform (e.g. Hadamard rotation) configuration, e.g. {"transform_class": "hadamard"}. Defaults to {}. bits (int, optional): Weight quantization bits. Defaults to 4. group_size (int, optional): Weight quantization group size. Defaults to 128. @@ -204,6 +206,7 @@ def __new__( enable_torch_compile=enable_torch_compile, seed=seed, low_cpu_mem_usage=low_cpu_mem_usage, + transform_config=transform_config, **kwargs, ) return ar diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index b1f46cc1e..c49579cdf 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -142,6 +142,7 @@ "super_bits", "super_group_size", "to_quant_block_names", + "transform_config", ) @@ -192,6 +193,7 @@ def __init__( disable_opt_rtn: bool | None = None, seed: int = 42, low_cpu_mem_usage: bool = True, + transform_config: dict = {}, **kwargs, ): """Initialize AutoRound with quantization and tuning configuration. @@ -213,6 +215,7 @@ def __init__( minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`. low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to True. + transform_config (dict, optional): Transform (e.g. Hadamard rotation) configuration, e.g. {"transform_class": "hadamard"}. Defaults to {}. iters (int, optional): Optimization iterations. Defaults to 200. seqlen (int, optional): Calibration sequence length. Defaults to 2048. nsamples (int, optional): Number of calibration samples. Defaults to 128.
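A minimal usage sketch of the new transform_config argument, not part of this diff: the model, tokenizer, output path, and the "MXFP4" scheme string are placeholders; only the transform_config value follows the docstring above.

import torch
from auto_round import AutoRound

# Hypothetical example: enable the Hadamard rotation path during quantization.
# `model` and `tokenizer` are an already-loaded transformers model/tokenizer (placeholders),
# and the scheme name is assumed to be one of the MX-FP4 presets exposed by the library.
ar = AutoRound(
    model,
    tokenizer,
    scheme="MXFP4",
    transform_config={"transform_class": "hadamard"},
)
ar.quantize_and_save("./quantized_model")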
@@ -483,6 +486,8 @@ def __init__( except (ImportError, ModuleNotFoundError): logger.error("algorithm extension import error, fallback to default mode") + self.transform_config = transform_config + def _gen_auto_scheme( self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device] ) -> dict[str, dict]: @@ -1147,6 +1152,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, disable_opt_rtn=disable_opt_rtn, + transform_config=self.transform_config, ) m = m.unwrapper({}) except torch.OutOfMemoryError: @@ -1162,6 +1168,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T enable_norm_bias_tuning=False, enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, + transform_config=self.transform_config, ) m = m.unwrapper({}) except Exception as e: @@ -2817,7 +2824,9 @@ def _quantize_block( self.enable_norm_bias_tuning, enable_torch_compile=self.enable_torch_compile, device=device, + transform_config=self.transform_config, ) + if is_nv_fp(self.data_type): # enable qkv and moe structure global_scale fuse from auto_round.data_type.utils import update_fused_layer_global_scales diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index f06076173..d4a309eb1 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -94,6 +94,19 @@ def __init__( ) self.register_buffer("weight_scale", init_weight_scale) + # Rotation matrices buffers + self.enable_transform = False + if self.config.transform_config: + self.enable_transform = True + self.register_buffer( + "forward_hadamard_matrix", + torch.empty( + self.group_size, + self.group_size, + dtype=dtype, + ), + ) + def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor: """ Initialize weights. This method should be overridden by subclasses. @@ -145,7 +158,19 @@ def _get_float_scale(cls, scale_e8m0: torch.Tensor) -> torch.Tensor: @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: - qdq_input = self.qdq_input(input) + + if self.enable_transform: + from ..triton.mxfp4 import mxfp4_forward_kernel_wrapper + orig_shape = input.shape + x_flat = input.contiguous().flatten(end_dim=-2) + qdq_input, _ = mxfp4_forward_kernel_wrapper( + x_flat, + self.forward_hadamard_matrix, + ) + qdq_input = qdq_input.reshape(orig_shape) + else: + qdq_input = self.qdq_input(input) + qdq_weight = self.dequant_weight_online() qdq_weight = qdq_weight.to(qdq_input.dtype) out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) diff --git a/auto_round/experimental/triton/mxfp4.py b/auto_round/experimental/triton/mxfp4.py new file mode 100644 index 000000000..706290881 --- /dev/null +++ b/auto_round/experimental/triton/mxfp4.py @@ -0,0 +1,188 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
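For context before the kernel code: per 32-element group the kernel applies a Hadamard rotation, derives a shared scale, snaps values to the FP4 (E2M1) magnitude set {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and dequantizes. Below is a rough PyTorch sketch of the non-quest scaling branch; it is illustrative only, omits the Hadamard step, and the function name is invented.

import torch

def mxfp4_qdq_reference(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Representable FP4 (E2M1) magnitudes.
    levels = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=x.dtype)
    groups = x.reshape(-1, group_size)  # assumes numel is divisible by group_size
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)  # clamp avoids log2(0)
    # Shared per-group scale derived from the max magnitude, mirroring
    # exp2(floor(log2(amax)) - 2) / (3 / 4) in the non-quest branch.
    shared = torch.exp2(torch.floor(torch.log2(amax)) - 2) / (3 / 4)
    scaled = groups / shared
    # Round |x| to the nearest FP4 level; the kernel's nested tl.where uses the same midpoints.
    idx = (scaled.abs().unsqueeze(-1) - levels).abs().argmin(dim=-1)
    q = levels[idx] * scaled.sign()
    return (q * shared).reshape(x.shape)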
+ +from random import randint + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32 * 32}), + triton.Config({"BLOCK_SIZE": 64 * 32}), + triton.Config({"BLOCK_SIZE": 128 * 32}), + triton.Config({"BLOCK_SIZE": 256 * 32}), + triton.Config({"BLOCK_SIZE": 512 * 32}), + ], + key=[], +) +@triton.jit +def mxfp4_forward_kernel( + x_ptr, + hadamard_matrix_ptr, + output_ptr, + clip_mask_ptr, + n_elements: tl.constexpr, + hadamard_dim: tl.constexpr, + group_size: tl.constexpr, + gaussian_scale: tl.constexpr, + quest: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim) + hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim) + + # load x + pid = tl.program_id(0) + start_idx = pid * BLOCK_SIZE + offsets = start_idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x_flat = tl.load(x_ptr + offsets, mask=mask) + + # hadamard transform + x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim)) + x_had = tl.dot(x, hadamard_matrix) + + # group + x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size)) + + # scale + # quest=True: per-group Gaussian-based scale = gaussian_scale * std + # quest=False: per-group max-abs-based scale, adjusted to FP4 range + if quest: + mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size + mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size + std = tl.sqrt(mean_squared - mean * mean) + scales = gaussian_scale * std + 1e-8 + shared_exps = tl.exp2(tl.floor(tl.log2(scales))) + x_had_scaled = x_had_grouped / shared_exps + else: + scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True) + shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) / (3 / 4) + x_had_scaled = x_had_grouped / shared_exps + + # quantize + # Map abs(x) to FP4 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6} + x_had_scaled_abs = tl.abs(x_had_scaled) + x_had_scaled_sign = tl.where( + x_had_scaled > 0, + 1, + -1, + ) + + x_fp4 = ( + tl.where( + x_had_scaled_abs > 5, + 6, + tl.where( + x_had_scaled_abs > 3.5, + 4, + tl.where( + x_had_scaled_abs > 2.5, + 3, + tl.where( + x_had_scaled_abs > 1.75, + 2, + tl.where( + x_had_scaled_abs > 1.25, + 1.5, + tl.where( + x_had_scaled_abs > 0.75, + 1, + tl.where( + x_had_scaled_abs > 0.25, + 0.5, + 0, + ), + ), + ), + ), + ), + ), + ) + * x_had_scaled_sign + ) + if clip_mask_ptr is not None: + tl.store( + clip_mask_ptr + offsets, + tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)), + mask=mask, + ) + + # dequantize + x_dequantized = x_fp4 * shared_exps + + # Reshape back to flat form for storage + x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,)) + + # store + tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask) + + +@torch.compiler.disable() +def mxfp4_forward_kernel_wrapper( + x, + hadamard_matrix, + return_clip_mask=False, + quest=False, + gaussian_scale=3 / 4, +): + """ + Apply Hadamard transform + group-wise FP4 quantize/dequantize on x. + + Note: + The output is still in the Hadamard-transformed space (no inverse Hadamard is applied). 
+ """ + # Pick a device — we require CUDA + device = x.device + if not device.type == "cuda": + # Either move to cuda or raise, depending on your design + device = torch.device("cuda") + x = x.to(device) + + # Ensure hadamard_matrix is on the same CUDA device + if hadamard_matrix.device != device: + hadamard_matrix = hadamard_matrix.to(device) + + # Make sure inputs are contiguous + x = x.contiguous() + hadamard_matrix = hadamard_matrix.contiguous() + + # Create output tensors on CUDA + output = torch.empty_like(x, device=device) + if return_clip_mask: + clip_mask = torch.empty_like(x, dtype=torch.bool, device=device).contiguous() + else: + clip_mask = None + + # Get total number of elements and calculate grid for launching the kernel + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + # Launch kernel – no need for `with torch.device(...)` + mxfp4_forward_kernel[grid]( + x_ptr=x, + hadamard_matrix_ptr=hadamard_matrix, + output_ptr=output, + clip_mask_ptr=clip_mask, + n_elements=n_elements, + hadamard_dim=hadamard_matrix.shape[-1], + group_size=32, + gaussian_scale=gaussian_scale, + quest=quest, + ) + + return output, clip_mask diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index 7265941d3..d06f540b0 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -113,6 +113,11 @@ def pack_layer(name, model, backend, device=None): ## no zeros to handle, as mxfp/nvfp do not support asym quantization # zero = layer.zp qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device) + + transform_matrix = getattr(layer, "forward_hadamard_matrix", None) + if transform_matrix is not None: + qlayer.register_buffer("forward_hadamard_matrix", transform_matrix) + qlayer.to(orig_device) set_module(model, name, qlayer) # Note: release weight and bias explicitly, in case they are referenced elsewhere diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 3daa3c822..6d11ac50a 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -209,6 +209,8 @@ def get_layer_config(model, quantization_config): act_data_type = getattr(quantization_config, "act_data_type", None) act_dynamic = getattr(quantization_config, "act_dynamic", False) + transform_config = getattr(quantization_config, "transform_config", None) + default_quant_scheme = QuantizationScheme( bits=bits, group_size=group_size, @@ -219,6 +221,7 @@ def get_layer_config(model, quantization_config): act_sym=act_sym, act_data_type=act_data_type, act_dynamic=act_dynamic, + transform_config=transform_config, ) # Determine the quantization block list diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 93110a833..0f34e6df2 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -34,6 +34,7 @@ class QuantizationScheme: act_dynamic: Optional[bool] = None super_bits: Optional[int] = None super_group_size: Optional[int] = None + transform_config: Optional[dict] = None @classmethod def from_dict(cls, config: dict): diff --git a/auto_round/transforms/transforms.py b/auto_round/transforms/transforms.py new file mode 100644 index 000000000..be8fe66ad --- /dev/null +++ b/auto_round/transforms/transforms.py @@ -0,0 +1,67 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict + +import torch +import torch.nn as nn +from fast_hadamard_transform import hadamard_transform + + +def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: + fn_or_method_keys = inspect.signature(fn_or_method).parameters.keys() + return {k: v for k, v in kwarg_dict.items() if k in fn_or_method_keys} + + +class IdentityTransform(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + return x + + def remove_parametrizations(self) -> None: + pass + + +class HadamardTransform(nn.Module): + + def __init__(self, group_size: int = 32): + super().__init__() + self.group_size = group_size + self.scale = 1 / math.sqrt(self.group_size) + + def forward(self, x: torch.Tensor): + # Hadamard transform is it own inverse + x_shape = x.shape + return hadamard_transform(x.view(-1, self.group_size), scale=self.scale).view(x_shape) + + def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype = None): + return hadamard_transform( + torch.eye(self.group_size, device=device, dtype=dtype), scale=1 / math.sqrt(self.group_size) + ) + + +TRANSFORMS = { + "identity": IdentityTransform, + "hadamard": HadamardTransform, +} + + +def build_transform(transform_class: str, **transform_kwargs): + transform = TRANSFORMS[transform_class] + return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs)) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 24836d85b..9670df83c 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -98,6 +98,7 @@ def __init__( enable_round_tuning=True, enable_torch_compile=False, disable_opt_rtn=True, + transform_config={}, **kwargs, ): """Initializes the WrapperLinear module. @@ -141,6 +142,14 @@ def __init__( else: self.orig_forward = self.linear_forward if type(self.orig_layer) == torch.nn.Linear else self.conv1d_forward + self.enable_transform = False + if transform_config: + from .transforms.transforms import build_transform + + self.in_transform = build_transform(**transform_config) + self.enable_transform = True + self.transform_config = transform_config + def _init_tuning_params_and_quant_func(self): """Initializes tuning parameters and quantization functions. @@ -235,6 +244,9 @@ def _qdq_weight(self, value, min_scale, max_scale): quant_kwargs["super_bits"] = self.orig_layer.super_bits quant_kwargs["super_group_size"] = self.orig_layer.super_group_size + if self.enable_transform: + weight = self.in_transform(weight) + weight_q, scale, zp = self.weight_quant_func( weight.to(self.device), bits=self.orig_layer.bits, @@ -267,6 +279,9 @@ def _qdq_act(self, x, act_max_scale, act_max=None): Returns: tuple: Quantized activation, scale, and zero point. 
""" + # apply rotate + if self.enable_transform: + x = self.in_transform(x) act_max_scale.data.clamp_(0, 1.0) x, scale, zp = self.act_quant_func( x, @@ -323,6 +338,7 @@ def unwrapper(self, best_params): qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) # if hasattr(self.orig_layer, "imatrix"): # self.orig_layer.imatrix = None + orig_dtype = self.orig_layer.weight.dtype self.orig_layer.weight.data.copy_(qdq_weight) self.orig_layer.weight.grad = None @@ -347,6 +363,12 @@ def _set_dict_attr(attr_dict, attr_name): else: self.orig_layer.scale = scale.view(-1).to("cpu") + # for saving transform matrix + if self.enable_transform: + transform_matrix = self.in_transform.get_transform_matrix(self.device, orig_dtype).cpu() + self.orig_layer.transform_config = self.transform_config + self.orig_layer.forward_hadamard_matrix = transform_matrix + if zp is not None: if isinstance(zp, dict): _set_dict_attr(zp, "zp") @@ -502,8 +524,17 @@ def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"): self.act_quant_func = compile_func(self.act_quant_func, self.device) self.extra_repr_org = orig_layer.extra_repr + self.enable_transform = False + if getattr(self.orig_layer, "transform_config", False): + from .transforms.transforms import build_transform + + self.in_transform = build_transform(**self.orig_layer.transform_config) + self.enable_transform = True + def forward(self, x): act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + if self.enable_transform: + x = self.in_transform(x) x, _, _ = self.orig_layer.act_quant_func( x, bits=self.orig_layer.act_bits,