From f831f9e75d69bcb28937152ae3b359a85a6c8e23 Mon Sep 17 00:00:00 2001 From: lkk12014402 Date: Tue, 27 Jan 2026 05:17:20 +0000 Subject: [PATCH 1/3] support hadamard transform for mxfp4 with rtn or autoround method. Signed-off-by: lkk12014402 --- auto_round/autoround.py | 3 + auto_round/compressors/base.py | 7 + auto_round/experimental/qmodules/mx.py | 27 ++- auto_round/experimental/triton/mxfp4.py | 178 ++++++++++++++++++ .../export_to_nvfp_mxfp.py | 5 + auto_round/inference/convert_model.py | 3 + auto_round/schemes.py | 1 + auto_round/transforms/transforms.py | 47 +++++ auto_round/wrapper.py | 29 +++ 9 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 auto_round/experimental/triton/mxfp4.py create mode 100644 auto_round/transforms/transforms.py diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 33ca1ccd7..baa3705a7 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -87,6 +87,7 @@ def __new__( enable_alg_ext: bool = None, disable_opt_rtn: bool | None = None, low_cpu_mem_usage: bool = True, + transform_config: dict = {}, **kwargs, ) -> BaseCompressor: """Initialize AutoRound with quantization and tuning configuration. @@ -114,6 +115,7 @@ def __new__( disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quatnziation with lower accuracy. Defaults to None. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False. + transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}. bits (int, optional): Weight quantization bits. Defaults to 4. group_size (int, optional): Weight quantization group size. Defaults to 128. @@ -204,6 +206,7 @@ def __new__( enable_torch_compile=enable_torch_compile, seed=seed, low_cpu_mem_usage=low_cpu_mem_usage, + transform_config=transform_config, **kwargs, ) return ar diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index b1f46cc1e..17cac2ea4 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -142,6 +142,7 @@ "super_bits", "super_group_size", "to_quant_block_names", + "transform_config", ) @@ -192,6 +193,7 @@ def __init__( disable_opt_rtn: bool | None = None, seed: int = 42, low_cpu_mem_usage: bool = True, + transform_config: dict = {}, **kwargs, ): """Initialize AutoRound with quantization and tuning configuration. @@ -213,6 +215,7 @@ def __init__( minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`. low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False. + transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}. iters (int, optional): Optimization iterations. Defaults to 200. seqlen (int, optional): Calibration sequence length. Defaults to 2048. nsamples (int, optional): Number of calibration samples. Defaults to 128. 
@@ -483,6 +486,8 @@ def __init__( except (ImportError, ModuleNotFoundError): logger.error("algorithm extension import error, fallback to default mode") + self.transform_config = transform_config + def _gen_auto_scheme( self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device] ) -> dict[str, dict]: @@ -1147,6 +1152,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, disable_opt_rtn=disable_opt_rtn, + transform_config=self.transform_config, ) m = m.unwrapper({}) except torch.OutOfMemoryError: @@ -1162,6 +1168,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T enable_norm_bias_tuning=False, enable_round_tuning=False, enable_torch_compile=self.enable_torch_compile, + transform_config=self.transform_config, ) m = m.unwrapper({}) except Exception as e: diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index f06076173..57ecb9b17 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -94,6 +94,19 @@ def __init__( ) self.register_buffer("weight_scale", init_weight_scale) + # Rotation matrices buffers + self.enable_transform = False + if self.config.transform_config is not None: + self.enable_transform = True + self.register_buffer( + "forward_hadamard_matrix", + torch.empty( + self.group_size, + self.group_size, + dtype=dtype, + ), + ) + def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor: """ Initialize weights. This method should be overridden by subclasses. @@ -145,7 +158,19 @@ def _get_float_scale(cls, scale_e8m0: torch.Tensor) -> torch.Tensor: @torch.inference_mode() def forward(self, input: torch.Tensor) -> torch.Tensor: - qdq_input = self.qdq_input(input) + + if self.enable_transform: + from ..triton.mxfp4 import mxfp4_forward_kernel_wrapper + orig_shape = input.shape + x_flat = input.contiguous().flatten(end_dim=-2) + qdq_input, _ = mxfp4_forward_kernel_wrapper( + x_flat, + self.forward_hadamard_matrix, + ) + qdq_input = qdq_input.reshape(orig_shape) + else: + qdq_input = self.qdq_input(input) + qdq_weight = self.dequant_weight_online() qdq_weight = qdq_weight.to(qdq_input.dtype) out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) diff --git a/auto_round/experimental/triton/mxfp4.py b/auto_round/experimental/triton/mxfp4.py new file mode 100644 index 000000000..65648b43f --- /dev/null +++ b/auto_round/experimental/triton/mxfp4.py @@ -0,0 +1,178 @@ +from random import randint + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32 * 32}), + triton.Config({"BLOCK_SIZE": 64 * 32}), + triton.Config({"BLOCK_SIZE": 128 * 32}), + triton.Config({"BLOCK_SIZE": 256 * 32}), + triton.Config({"BLOCK_SIZE": 512 * 32}), + ], + key=[], +) +@triton.jit +def mxfp4_forward_kernel( + x_ptr, + hadamard_matrix_ptr, + output_ptr, + clip_mask_ptr, + n_elements: tl.constexpr, + hadamard_dim: tl.constexpr, + group_size: tl.constexpr, + gaussian_scale: tl.constexpr, + quest: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim) + hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape( + hadamard_dim, hadamard_dim + ) + + # load x + pid = tl.program_id(0) + start_idx = pid * BLOCK_SIZE + offsets = start_idx + tl.arange(0, BLOCK_SIZE) + mask = offsets < 
n_elements + x_flat = tl.load(x_ptr + offsets, mask=mask) + + # hadamard transform + x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim)) + x_had = tl.dot(x, hadamard_matrix) + + # group + x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size)) + + # scale + # quest=True: per-group Gaussian-based scale = gaussian_scale * std + # quest=False: per-group max-abs-based scale, adjusted to FP4 range + if quest: + mean_squared = ( + tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size + ) + mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size + std = tl.sqrt(mean_squared - mean * mean) + scales = gaussian_scale * std + 1e-8 + shared_exps = tl.exp2(tl.floor(tl.log2(scales))) + x_had_scaled = x_had_grouped / shared_exps + else: + scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True) + shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) / (3 / 4) + x_had_scaled = x_had_grouped / shared_exps + + # quantize + # Map abs(x) to FP4 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6} + x_had_scaled_abs = tl.abs(x_had_scaled) + x_had_scaled_sign = tl.where( + x_had_scaled > 0, + 1, + -1, + ) + + x_fp4 = ( + tl.where( + x_had_scaled_abs > 5, + 6, + tl.where( + x_had_scaled_abs > 3.5, + 4, + tl.where( + x_had_scaled_abs > 2.5, + 3, + tl.where( + x_had_scaled_abs > 1.75, + 2, + tl.where( + x_had_scaled_abs > 1.25, + 1.5, + tl.where( + x_had_scaled_abs > 0.75, + 1, + tl.where( + x_had_scaled_abs > 0.25, + 0.5, + 0, + ), + ), + ), + ), + ), + ), + ) + * x_had_scaled_sign + ) + if clip_mask_ptr is not None: + tl.store( + clip_mask_ptr + offsets, + tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)), + mask=mask, + ) + + # dequantize + x_dequantized = x_fp4 * shared_exps + + # Reshape back to flat form for storage + x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,)) + + # store + tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask) + + +@torch.compiler.disable() +def mxfp4_forward_kernel_wrapper( + x, + hadamard_matrix, + return_clip_mask=False, + quest=False, + gaussian_scale=3 / 4, +): + """ + Apply Hadamard transform + group-wise FP4 quantize/dequantize on x. + + Note: + The output is still in the Hadamard-transformed space (no inverse Hadamard is applied). 
+ """ + # Pick a device — we require CUDA + device = x.device + if not device.type == "cuda": + # Either move to cuda or raise, depending on your design + device = torch.device("cuda") + x = x.to(device) + + # Ensure hadamard_matrix is on the same CUDA device + if hadamard_matrix.device != device: + hadamard_matrix = hadamard_matrix.to(device) + + # Make sure inputs are contiguous + x = x.contiguous() + hadamard_matrix = hadamard_matrix.contiguous() + + # Create output tensors on CUDA + output = torch.empty_like(x, device=device) + if return_clip_mask: + clip_mask = torch.empty_like(x, dtype=torch.bool, device=device).contiguous() + else: + clip_mask = None + + # Get total number of elements and calculate grid for launching the kernel + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + # Launch kernel – no need for `with torch.device(...)` + mxfp4_forward_kernel[grid]( + x_ptr=x, + hadamard_matrix_ptr=hadamard_matrix, + output_ptr=output, + clip_mask_ptr=clip_mask, + n_elements=n_elements, + hadamard_dim=hadamard_matrix.shape[-1], + group_size=32, + gaussian_scale=gaussian_scale, + quest=quest, + ) + + return output, clip_mask diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index 7265941d3..d06f540b0 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -113,6 +113,11 @@ def pack_layer(name, model, backend, device=None): ## no zeros to handle, as mxfp/nvfp do not support asym quantization # zero = layer.zp qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device) + + transform_matrix = getattr(layer, "forward_hadamard_matrix", None) + if transform_matrix is not None: + qlayer.register_buffer("forward_hadamard_matrix", transform_matrix) + qlayer.to(orig_device) set_module(model, name, qlayer) # Note: release weight and bias explicitly, in case they are referenced elsewhere diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 3daa3c822..6d11ac50a 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -209,6 +209,8 @@ def get_layer_config(model, quantization_config): act_data_type = getattr(quantization_config, "act_data_type", None) act_dynamic = getattr(quantization_config, "act_dynamic", False) + transform_config = getattr(quantization_config, "transform_config", None) + default_quant_scheme = QuantizationScheme( bits=bits, group_size=group_size, @@ -219,6 +221,7 @@ def get_layer_config(model, quantization_config): act_sym=act_sym, act_data_type=act_data_type, act_dynamic=act_dynamic, + transform_config=transform_config, ) # Determine the quantization block list diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 93110a833..0f34e6df2 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -34,6 +34,7 @@ class QuantizationScheme: act_dynamic: Optional[bool] = None super_bits: Optional[int] = None super_group_size: Optional[int] = None + transform_config: Optional[dict] = None @classmethod def from_dict(cls, config: dict): diff --git a/auto_round/transforms/transforms.py b/auto_round/transforms/transforms.py new file mode 100644 index 000000000..6e02bb6c0 --- /dev/null +++ b/auto_round/transforms/transforms.py @@ -0,0 +1,47 @@ +import math +import torch +import torch.nn as nn +from fast_hadamard_transform import 
hadamard_transform + +import inspect +from typing import Any, Callable, Dict + +def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: + fn_or_method_keys = inspect.signature(fn_or_method).parameters.keys() + return {k: v for k, v in kwarg_dict.items() if k in fn_or_method_keys} + +class IdentityTransform(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + return x + + def remove_parametrizations(self) -> None: + pass + +class HadamardTransform(nn.Module): + + def __init__(self, group_size: int = 32): + super().__init__() + self.group_size = group_size + self.scale = 1 / math.sqrt(self.group_size) + + def forward(self, x: torch.Tensor): + # Hadamard transform is it own inverse + x_shape = x.shape + return hadamard_transform(x.view(-1, self.group_size), scale=self.scale).view(x_shape) + + def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype = None): + return hadamard_transform(torch.eye(self.group_size, device=device, dtype=dtype), scale=1 / math.sqrt(self.group_size)) + + +TRANSFORMS = { + "identity": IdentityTransform, + "hadamard": HadamardTransform, +} + +def build_transform(transform_class: str, **transform_kwargs): + transform = TRANSFORMS[transform_class] + return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs)) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 24836d85b..6909e37ab 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -98,6 +98,7 @@ def __init__( enable_round_tuning=True, enable_torch_compile=False, disable_opt_rtn=True, + transform_config={}, **kwargs, ): """Initializes the WrapperLinear module. @@ -141,6 +142,13 @@ def __init__( else: self.orig_forward = self.linear_forward if type(self.orig_layer) == torch.nn.Linear else self.conv1d_forward + self.enable_transform = False + if transform_config: + from .transforms.transforms import build_transform + self.in_transform = build_transform(**transform_config) + self.enable_transform = True + self.transform_config = transform_config + def _init_tuning_params_and_quant_func(self): """Initializes tuning parameters and quantization functions. @@ -235,6 +243,9 @@ def _qdq_weight(self, value, min_scale, max_scale): quant_kwargs["super_bits"] = self.orig_layer.super_bits quant_kwargs["super_group_size"] = self.orig_layer.super_group_size + if self.enable_transform: + weight = self.in_transform(weight) + weight_q, scale, zp = self.weight_quant_func( weight.to(self.device), bits=self.orig_layer.bits, @@ -267,6 +278,9 @@ def _qdq_act(self, x, act_max_scale, act_max=None): Returns: tuple: Quantized activation, scale, and zero point. 
""" + # apply rotate + if self.enable_transform: + x = self.in_transform(x) act_max_scale.data.clamp_(0, 1.0) x, scale, zp = self.act_quant_func( x, @@ -323,6 +337,7 @@ def unwrapper(self, best_params): qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) # if hasattr(self.orig_layer, "imatrix"): # self.orig_layer.imatrix = None + orig_dtype = self.orig_layer.weight.dtype self.orig_layer.weight.data.copy_(qdq_weight) self.orig_layer.weight.grad = None @@ -347,6 +362,12 @@ def _set_dict_attr(attr_dict, attr_name): else: self.orig_layer.scale = scale.view(-1).to("cpu") + # for saving transform matrix + if self.enable_transform: + transform_matrix = self.in_transform.get_transform_matrix(self.device, orig_dtype).cpu() + self.orig_layer.transform_config = self.transform_config + self.orig_layer.forward_hadamard_matrix = transform_matrix + if zp is not None: if isinstance(zp, dict): _set_dict_attr(zp, "zp") @@ -502,8 +523,16 @@ def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"): self.act_quant_func = compile_func(self.act_quant_func, self.device) self.extra_repr_org = orig_layer.extra_repr + self.enable_transform = False + if getattr(self.orig_layer, "transform_config", False): + from .transforms.transforms import build_transform + self.in_transform = build_transform(**self.orig_layer.transform_config) + self.enable_transform = True + def forward(self, x): act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + if self.enable_transform: + x = self.in_transform(x) x, _, _ = self.orig_layer.act_quant_func( x, bits=self.orig_layer.act_bits, From 795bb357f7b687ee3ba19b326c426018b2acec35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 05:21:13 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/autoround.py | 2 +- auto_round/compressors/base.py | 2 +- auto_round/experimental/qmodules/mx.py | 1 + auto_round/experimental/triton/mxfp4.py | 22 +++++++++++++++------ auto_round/transforms/transforms.py | 26 ++++++++++++++++++++++--- auto_round/wrapper.py | 2 ++ 6 files changed, 44 insertions(+), 11 deletions(-) diff --git a/auto_round/autoround.py b/auto_round/autoround.py index baa3705a7..069528e5e 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -115,7 +115,7 @@ def __new__( disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quatnziation with lower accuracy. Defaults to None. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False. - transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}. + transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}. bits (int, optional): Weight quantization bits. Defaults to 4. group_size (int, optional): Weight quantization group size. Defaults to 128. diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 17cac2ea4..d7d0f657a 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -215,7 +215,7 @@ def __init__( minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`. low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False. 
- transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}. + transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}. iters (int, optional): Optimization iterations. Defaults to 200. seqlen (int, optional): Calibration sequence length. Defaults to 2048. nsamples (int, optional): Number of calibration samples. Defaults to 128. diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index 57ecb9b17..af0b0b48b 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -161,6 +161,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.enable_transform: from ..triton.mxfp4 import mxfp4_forward_kernel_wrapper + orig_shape = input.shape x_flat = input.contiguous().flatten(end_dim=-2) qdq_input, _ = mxfp4_forward_kernel_wrapper( diff --git a/auto_round/experimental/triton/mxfp4.py b/auto_round/experimental/triton/mxfp4.py index 65648b43f..706290881 100644 --- a/auto_round/experimental/triton/mxfp4.py +++ b/auto_round/experimental/triton/mxfp4.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from random import randint import torch @@ -29,9 +43,7 @@ def mxfp4_forward_kernel( BLOCK_SIZE: tl.constexpr, ): offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim) - hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape( - hadamard_dim, hadamard_dim - ) + hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim) # load x pid = tl.program_id(0) @@ -51,9 +63,7 @@ def mxfp4_forward_kernel( # quest=True: per-group Gaussian-based scale = gaussian_scale * std # quest=False: per-group max-abs-based scale, adjusted to FP4 range if quest: - mean_squared = ( - tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size - ) + mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size std = tl.sqrt(mean_squared - mean * mean) scales = gaussian_scale * std + 1e-8 diff --git a/auto_round/transforms/transforms.py b/auto_round/transforms/transforms.py index 6e02bb6c0..be8fe66ad 100644 --- a/auto_round/transforms/transforms.py +++ b/auto_round/transforms/transforms.py @@ -1,15 +1,31 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect import math +from typing import Any, Callable, Dict + import torch import torch.nn as nn from fast_hadamard_transform import hadamard_transform -import inspect -from typing import Any, Callable, Dict def filter_kwarg_dict(fn_or_method: Callable, kwarg_dict: Dict[str, Any]) -> Dict[str, Any]: fn_or_method_keys = inspect.signature(fn_or_method).parameters.keys() return {k: v for k, v in kwarg_dict.items() if k in fn_or_method_keys} + class IdentityTransform(nn.Module): def __init__(self, *args, **kwargs): @@ -21,6 +37,7 @@ def forward(self, x: torch.Tensor): def remove_parametrizations(self) -> None: pass + class HadamardTransform(nn.Module): def __init__(self, group_size: int = 32): @@ -34,7 +51,9 @@ def forward(self, x: torch.Tensor): return hadamard_transform(x.view(-1, self.group_size), scale=self.scale).view(x_shape) def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype = None): - return hadamard_transform(torch.eye(self.group_size, device=device, dtype=dtype), scale=1 / math.sqrt(self.group_size)) + return hadamard_transform( + torch.eye(self.group_size, device=device, dtype=dtype), scale=1 / math.sqrt(self.group_size) + ) TRANSFORMS = { @@ -42,6 +61,7 @@ def get_transform_matrix(self, device: torch.device = None, dtype: torch.dtype = "hadamard": HadamardTransform, } + def build_transform(transform_class: str, **transform_kwargs): transform = TRANSFORMS[transform_class] return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs)) diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 6909e37ab..9670df83c 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -145,6 +145,7 @@ def __init__( self.enable_transform = False if transform_config: from .transforms.transforms import build_transform + self.in_transform = build_transform(**transform_config) self.enable_transform = True self.transform_config = transform_config @@ -526,6 +527,7 @@ def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"): self.enable_transform = False if getattr(self.orig_layer, "transform_config", False): from .transforms.transforms import build_transform + self.in_transform = build_transform(**self.orig_layer.transform_config) self.enable_transform = True From d171df6da91efbe3c1f376f1854a826afd03c357 Mon Sep 17 00:00:00 2001 From: lkk12014402 Date: Sun, 1 Feb 2026 10:10:16 +0000 Subject: [PATCH 3/3] fix bugs. Signed-off-by: lkk12014402 --- auto_round/compressors/base.py | 4 +++- auto_round/experimental/qmodules/mx.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index d7d0f657a..c49579cdf 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -215,7 +215,7 @@ def __init__( minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`. low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False. low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False. - transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}. + transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}. iters (int, optional): Optimization iterations. Defaults to 200. seqlen (int, optional): Calibration sequence length. Defaults to 2048. nsamples (int, optional): Number of calibration samples. Defaults to 128. 
@@ -2824,7 +2824,9 @@ def _quantize_block( self.enable_norm_bias_tuning, enable_torch_compile=self.enable_torch_compile, device=device, + transform_config=self.transform_config, ) + if is_nv_fp(self.data_type): # enable qkv and moe structure global_scale fuse from auto_round.data_type.utils import update_fused_layer_global_scales diff --git a/auto_round/experimental/qmodules/mx.py b/auto_round/experimental/qmodules/mx.py index af0b0b48b..d4a309eb1 100644 --- a/auto_round/experimental/qmodules/mx.py +++ b/auto_round/experimental/qmodules/mx.py @@ -96,7 +96,7 @@ def __init__( # Rotation matrices buffers self.enable_transform = False - if self.config.transform_config is not None: + if self.config.transform_config: self.enable_transform = True self.register_buffer( "forward_hadamard_matrix", @@ -161,7 +161,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.enable_transform: from ..triton.mxfp4 import mxfp4_forward_kernel_wrapper - orig_shape = input.shape x_flat = input.contiguous().flatten(end_dim=-2) qdq_input, _ = mxfp4_forward_kernel_wrapper(
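
A usage sketch (not part of the committed patch): the snippet below illustrates how the pieces added in this series fit together, using only the APIs introduced here (`build_transform`, the new `transform_config` argument, and the `forward_hadamard_matrix` buffer saved at export). The model id, the "MXFP4" scheme string, and the output path are illustrative assumptions; running it needs a CUDA device and the `fast_hadamard_transform` package.

import torch

from auto_round import AutoRound
from auto_round.transforms.transforms import build_transform

# 1) The transform object that WrapperLinear builds internally from transform_config.
#    fast_hadamard_transform runs on CUDA tensors, so keep x on the GPU.
had = build_transform(transform_class="hadamard", group_size=32)
x = torch.randn(4, 64, device="cuda", dtype=torch.float16)
x_rot = had(x)  # group-wise Hadamard rotation applied to weights/activations before quant
# The 32x32 matrix that unwrapper() stores on the layer as `forward_hadamard_matrix`
# and pack_layer() copies into the exported qlayer:
H = had.get_transform_matrix(device=torch.device("cuda"), dtype=torch.float16)

# 2) End-to-end quantization with the rotation enabled; group_size should match the
#    MXFP4 quantization group size (32). Model id, scheme name and output directory
#    below are placeholders, not taken from this patch.
ar = AutoRound(
    "facebook/opt-125m",  # assumed model id
    scheme="MXFP4",       # assumed preset name for the MXFP4 scheme
    iters=0,              # RTN path; tuning (iters > 0) goes through the same wrapper
    transform_config={"transform_class": "hadamard", "group_size": 32},
)
ar.quantize_and_save("./opt-125m-mxfp4-hadamard")  # assumed output directory

At load time, get_layer_config() reads transform_config back from the checkpoint's quantization config, and the MXFP4 qmodule then quantizes the rotated activations with the Triton mxfp4_forward_kernel_wrapper instead of qdq_input.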