3 changes: 3 additions & 0 deletions auto_round/autoround.py
@@ -87,6 +87,7 @@ def __new__(
enable_alg_ext: bool = None,
disable_opt_rtn: bool | None = None,
low_cpu_mem_usage: bool = True,
transform_config: dict = {},
**kwargs,
) -> BaseCompressor:
"""Initialize AutoRound with quantization and tuning configuration.
@@ -114,6 +115,7 @@ def __new__(
disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quantization
with lower accuracy. Defaults to None.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
Copilot AI (Jan 27, 2026):
Corrected spelling of 'matirx' to 'matrix'.
Suggested change:
- transform_config (dict, optional): transform matirx config like hadamard, like {"transform_class": "hadamard"}.
+ transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.

Contributor:
Please mark it as an experimental feature, and clarify the limitation.
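For context, a minimal usage sketch of the new argument (not part of the PR; model and tokenizer stand for an already-loaded model and tokenizer, the scheme name is an assumption, and per the comment above the feature should be treated as experimental):

# Hypothetical usage of the experimental transform_config argument (sketch only).
from auto_round import AutoRound

ar = AutoRound(
    model,                                             # an already-loaded transformers model
    tokenizer,
    scheme="MXFP4",                                    # assumed scheme name; the transform targets MXFP4 (group size 32)
    transform_config={"transform_class": "hadamard"},  # new argument added by this PR
)
ar.quantize()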


bits (int, optional): Weight quantization bits. Defaults to 4.
group_size (int, optional): Weight quantization group size. Defaults to 128.
@@ -204,6 +206,7 @@ def __new__(
enable_torch_compile=enable_torch_compile,
seed=seed,
low_cpu_mem_usage=low_cpu_mem_usage,
transform_config=transform_config,
**kwargs,
)
return ar
9 changes: 9 additions & 0 deletions auto_round/compressors/base.py
@@ -142,6 +142,7 @@
"super_bits",
"super_group_size",
"to_quant_block_names",
"transform_config",
)


@@ -192,6 +193,7 @@ def __init__(
disable_opt_rtn: bool | None = None,
seed: int = 42,
low_cpu_mem_usage: bool = True,
transform_config: dict = {},
Contributor:
It's better not to name it transform_config, as it may be confusing with Transformers.

**kwargs,
):
"""Initialize AutoRound with quantization and tuning configuration.
@@ -213,6 +215,7 @@ def __init__(
minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
transform_config (dict, optional): transform matrix config like hadamard, like {"transform_class": "hadamard"}.
iters (int, optional): Optimization iterations. Defaults to 200.
seqlen (int, optional): Calibration sequence length. Defaults to 2048.
nsamples (int, optional): Number of calibration samples. Defaults to 128.
@@ -483,6 +486,8 @@ def __init__(
except (ImportError, ModuleNotFoundError):
logger.error("algorithm extension import error, fallback to default mode")

self.transform_config = transform_config

def _gen_auto_scheme(
self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
) -> dict[str, dict]:
@@ -1147,6 +1152,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
enable_round_tuning=False,
enable_torch_compile=self.enable_torch_compile,
disable_opt_rtn=disable_opt_rtn,
transform_config=self.transform_config,
)
m = m.unwrapper({})
except torch.OutOfMemoryError:
@@ -1162,6 +1168,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
enable_norm_bias_tuning=False,
enable_round_tuning=False,
enable_torch_compile=self.enable_torch_compile,
transform_config=self.transform_config,
)
m = m.unwrapper({})
except Exception as e:
@@ -2817,7 +2824,9 @@ def _quantize_block(
self.enable_norm_bias_tuning,
enable_torch_compile=self.enable_torch_compile,
device=device,
transform_config=self.transform_config,
)

if is_nv_fp(self.data_type): # enable qkv and moe structure global_scale fuse
from auto_round.data_type.utils import update_fused_layer_global_scales

27 changes: 26 additions & 1 deletion auto_round/experimental/qmodules/mx.py
@@ -94,6 +94,19 @@ def __init__(
)
self.register_buffer("weight_scale", init_weight_scale)

# Rotation matrices buffers
self.enable_transform = False
if self.config.transform_config:
self.enable_transform = True
self.register_buffer(
"forward_hadamard_matrix",
torch.empty(
self.group_size,
self.group_size,
dtype=dtype,
),
)

def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
"""
Initialize weights. This method should be overridden by subclasses.
@@ -145,7 +158,19 @@ def _get_float_scale(cls, scale_e8m0: torch.Tensor) -> torch.Tensor:

@torch.inference_mode()
def forward(self, input: torch.Tensor) -> torch.Tensor:
qdq_input = self.qdq_input(input)

if self.enable_transform:
from ..triton.mxfp4 import mxfp4_forward_kernel_wrapper
orig_shape = input.shape
x_flat = input.contiguous().flatten(end_dim=-2)
qdq_input, _ = mxfp4_forward_kernel_wrapper(
x_flat,
self.forward_hadamard_matrix,
)
qdq_input = qdq_input.reshape(orig_shape)
else:
qdq_input = self.qdq_input(input)

qdq_weight = self.dequant_weight_online()
qdq_weight = qdq_weight.to(qdq_input.dtype)
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
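For reference, one plausible way to build a forward_hadamard_matrix with the registered (group_size, group_size) shape. This is a sketch only: it assumes a power-of-two group size and orthonormal 1/sqrt(n) scaling, and the PR does not show how the matrix is actually generated.

# Sketch only: a normalized Hadamard matrix matching the buffer registered in __init__.
# Assumes group_size is a power of two; the matrix actually loaded from a checkpoint may differ.
import torch
from scipy.linalg import hadamard

def make_forward_hadamard(group_size: int, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    h = torch.from_numpy(hadamard(group_size)).to(dtype)  # entries are +1/-1
    return h / group_size ** 0.5                          # 1/sqrt(n) scaling keeps the transform orthonormal

forward_hadamard_matrix = make_forward_hadamard(32)       # MXFP4 group size is 32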
188 changes: 188 additions & 0 deletions auto_round/experimental/triton/mxfp4.py
@@ -0,0 +1,188 @@
# Copyright (c) 2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from random import randint

import torch
import triton
import triton.language as tl


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 32 * 32}),
triton.Config({"BLOCK_SIZE": 64 * 32}),
triton.Config({"BLOCK_SIZE": 128 * 32}),
triton.Config({"BLOCK_SIZE": 256 * 32}),
triton.Config({"BLOCK_SIZE": 512 * 32}),
],
key=[],
)
@triton.jit
Contributor:
I assume xpu does not support this, but it's not a big issue for now.

Contributor (author):
xpu supports this kernel, but I haven't tested the performance.

def mxfp4_forward_kernel(
Contributor:
Please add the source if the code is copied from another repo; better to add their license at the beginning of this file.

x_ptr,
hadamard_matrix_ptr,
output_ptr,
clip_mask_ptr,
n_elements: tl.constexpr,
hadamard_dim: tl.constexpr,
group_size: tl.constexpr,
gaussian_scale: tl.constexpr,
quest: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)

# load x
pid = tl.program_id(0)
start_idx = pid * BLOCK_SIZE
offsets = start_idx + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x_flat = tl.load(x_ptr + offsets, mask=mask)

# hadamard transform
x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
x_had = tl.dot(x, hadamard_matrix)

# group
x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

# scale
# quest=True: per-group Gaussian-based scale = gaussian_scale * std
# quest=False: per-group max-abs-based scale, adjusted to FP4 range
if quest:
mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size
mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size
std = tl.sqrt(mean_squared - mean * mean)
scales = gaussian_scale * std + 1e-8
shared_exps = tl.exp2(tl.floor(tl.log2(scales)))
x_had_scaled = x_had_grouped / shared_exps
else:
scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True)
shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) / (3 / 4)
x_had_scaled = x_had_grouped / shared_exps

# quantize
# Map abs(x) to FP4 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6}
x_had_scaled_abs = tl.abs(x_had_scaled)
x_had_scaled_sign = tl.where(
x_had_scaled > 0,
1,
-1,
)

x_fp4 = (
tl.where(
x_had_scaled_abs > 5,
6,
tl.where(
x_had_scaled_abs > 3.5,
4,
tl.where(
x_had_scaled_abs > 2.5,
3,
tl.where(
x_had_scaled_abs > 1.75,
2,
tl.where(
x_had_scaled_abs > 1.25,
1.5,
tl.where(
x_had_scaled_abs > 0.75,
1,
tl.where(
x_had_scaled_abs > 0.25,
0.5,
0,
),
),
),
),
),
),
)
* x_had_scaled_sign
)
if clip_mask_ptr is not None:
tl.store(
clip_mask_ptr + offsets,
tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)),
mask=mask,
)

# dequantize
x_dequantized = x_fp4 * shared_exps

# Reshape back to flat form for storage
x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))

# store
tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)


@torch.compiler.disable()
def mxfp4_forward_kernel_wrapper(
x,
hadamard_matrix,
return_clip_mask=False,
quest=False,
gaussian_scale=3 / 4,
):
"""
Apply Hadamard transform + group-wise FP4 quantize/dequantize on x.

Note:
The output is still in the Hadamard-transformed space (no inverse Hadamard is applied).
"""
# Pick a device — we require CUDA
device = x.device
if not device.type == "cuda":
# Either move to cuda or raise, depending on your design
device = torch.device("cuda")
x = x.to(device)

# Ensure hadamard_matrix is on the same CUDA device
if hadamard_matrix.device != device:
hadamard_matrix = hadamard_matrix.to(device)

# Make sure inputs are contiguous
x = x.contiguous()
hadamard_matrix = hadamard_matrix.contiguous()

# Create output tensors on CUDA
Comment on lines +163 to +164
Copilot AI (Jan 27, 2026):
Using lambda for the grid calculation can cause issues with serialization and debugging. Consider using a regular function definition instead.
Suggested change:
def grid(meta):
    return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
output = torch.empty_like(x, device=device)
if return_clip_mask:
clip_mask = torch.empty_like(x, dtype=torch.bool, device=device).contiguous()
else:
clip_mask = None

# Get total number of elements and calculate grid for launching the kernel
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

# Launch kernel – no need for `with torch.device(...)`
mxfp4_forward_kernel[grid](
x_ptr=x,
hadamard_matrix_ptr=hadamard_matrix,
output_ptr=output,
clip_mask_ptr=clip_mask,
n_elements=n_elements,
hadamard_dim=hadamard_matrix.shape[-1],
group_size=32,
gaussian_scale=gaussian_scale,
quest=quest,
)

return output, clip_mask
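To make the kernel's math easier to follow, below is a plain PyTorch sketch of the same group-wise logic for the quest=False (max-abs) path. It is a readability aid only, not part of the PR, and like the wrapper it leaves the output in the Hadamard-transformed space.

# Non-Triton reference sketch of mxfp4_forward_kernel's quest=False path (readability aid only).
import torch

FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_qdq_reference(x: torch.Tensor, hadamard: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    hd = hadamard.shape[-1]
    # Hadamard transform applied per block of hd elements, then regrouped for scaling
    x_had = (x.reshape(-1, hd).float() @ hadamard.float()).reshape(-1, group_size)
    # Shared power-of-two exponent per group, adjusted so the group max lands in the FP4 range (top level 6)
    scales = x_had.abs().amax(dim=-1, keepdim=True)
    shared_exps = torch.exp2(torch.floor(torch.log2(scales)) - 2) / (3 / 4)
    x_scaled = x_had / shared_exps
    # Round each magnitude to the nearest FP4 level, keeping the sign (matches the kernel's threshold cascade)
    idx = (x_scaled.abs().unsqueeze(-1) - FP4_LEVELS).abs().argmin(dim=-1)
    x_fp4 = FP4_LEVELS[idx] * x_scaled.sign()
    # Dequantize; the result stays in the Hadamard-transformed space, as the wrapper docstring notes
    return (x_fp4 * shared_exps).reshape(x.shape)

On CUDA tensors, mxfp4_forward_kernel_wrapper(x, hadamard)[0] would be expected to agree closely with this reference, up to the usual floating-point differences.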
5 changes: 5 additions & 0 deletions auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -113,6 +113,11 @@ def pack_layer(name, model, backend, device=None):
## no zeros to handle, as mxfp/nvfp do not support asym quantization
# zero = layer.zp
qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device)

transform_matrix = getattr(layer, "forward_hadamard_matrix", None)
if transform_matrix is not None:
qlayer.register_buffer("forward_hadamard_matrix", transform_matrix)

qlayer.to(orig_device)
set_module(model, name, qlayer)
# Note: release weight and bias explicitly, in case they are referenced elsewhere
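For illustration, since the matrix is attached with register_buffer, it is serialized with the packed layer. A small sketch continuing from the code above (the shape comment is an expectation, not a captured run):

# Sketch: the exported qlayer now carries the Hadamard matrix in its state dict,
# so it is saved with the checkpoint and available again at inference time.
state = qlayer.state_dict()
if "forward_hadamard_matrix" in state:
    print(state["forward_hadamard_matrix"].shape)  # expected (group_size, group_size), e.g. (32, 32)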
3 changes: 3 additions & 0 deletions auto_round/inference/convert_model.py
@@ -209,6 +209,8 @@ def get_layer_config(model, quantization_config):
act_data_type = getattr(quantization_config, "act_data_type", None)
act_dynamic = getattr(quantization_config, "act_dynamic", False)

transform_config = getattr(quantization_config, "transform_config", None)

default_quant_scheme = QuantizationScheme(
bits=bits,
group_size=group_size,
@@ -219,6 +221,7 @@
act_sym=act_sym,
act_data_type=act_data_type,
act_dynamic=act_dynamic,
transform_config=transform_config,
)

# Determine the quantization block list
1 change: 1 addition & 0 deletions auto_round/schemes.py
@@ -34,6 +34,7 @@ class QuantizationScheme:
act_dynamic: Optional[bool] = None
super_bits: Optional[int] = None
super_group_size: Optional[int] = None
transform_config: Optional[dict] = None

@classmethod
def from_dict(cls, config: dict):
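For illustration, the new optional field rides along when a scheme is built from a dict. A sketch using only fields visible in this diff (the printed value is the expected result, not a captured run):

# Sketch: transform_config behaves like any other optional QuantizationScheme field.
from auto_round.schemes import QuantizationScheme

scheme = QuantizationScheme.from_dict(
    {
        "bits": 4,
        "group_size": 32,
        "transform_config": {"transform_class": "hadamard"},  # new field added by this PR
    }
)
print(scheme.transform_config)  # expected: {'transform_class': 'hadamard'}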