Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
74d4f4e
fix hadamard transform weight dtype, using float64 as default.
lkk12014402 Apr 7, 2026
aa06e43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
928b155
float32 may be enough for the Hadamard transform.
lkk12014402 Apr 7, 2026
c67b95d
in-place weight when auto-round tuning.
lkk12014402 Apr 8, 2026
4700eb2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
43ff2c6
support nvfp4.
lkk12014402 Apr 9, 2026
36d314d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
c558eff
fix import issue.
lkk12014402 Apr 9, 2026
037aad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
6d69b0e
fix typo.
lkk12014402 Apr 9, 2026
95436ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
65092ff
enhance the function `normalize_hadamard_config`
lkk12014402 Apr 9, 2026
506c595
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2026
47c5aa2
support more scheme format and remove useless custom quantlinear for …
lkk12014402 Apr 13, 2026
ac10fb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
8aef5d6
fix multi-card issue when using random_hadamard.
lkk12014402 Apr 13, 2026
e2370a6
Merge branch 'main' into fix_hadamard_transform_dtype
lkk12014402 Apr 13, 2026
5cd659a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
54273c3
provide a notification when applying transform
lkk12014402 Apr 14, 2026
a67ccb7
remove duplicate args.
lkk12014402 Apr 14, 2026
847914a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2026
80d5e9e
Merge branch 'main' into fix_hadamard_transform_dtype
lkk12014402 Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
_handle_special_schemes,
get_gguf_scheme,
preset_name_to_scheme,
scheme_to_preset_name,
)
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import get_predefined_ignore_layers, update_module
Expand Down Expand Up @@ -560,16 +561,12 @@ def __init__(

# apply hadamard transform
if hadamard_config:
Comment thread
lkk12014402 marked this conversation as resolved.
logger.info("Applying Hadamard transform to the model.")
from auto_round.experimental.transform.apply import apply_hadamard_transform
from auto_round.experimental.utils import check_supported_schemes, normalize_hadamard_config
from auto_round.experimental.utils import normalize_hadamard_config

check_supported_schemes(self.scheme)

self.model = apply_hadamard_transform(
self.model, hadamard_config, need_calibration=True if self.iters > 0 else False
)

self.hadamard_config = normalize_hadamard_config(hadamard_config)
self.hadamard_config = normalize_hadamard_config(hadamard_config, self.data_type)
self.model = apply_hadamard_transform(self.model, self.hadamard_config, data_type=self.data_type)

def _gen_auto_scheme(self) -> dict[str, dict]:
if self.mllm:
Expand Down
3 changes: 2 additions & 1 deletion auto_round/experimental/qmodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from auto_round.experimental.qmodules.mx import (
MXFP4QuantLinear,
MXFP8QuantLinear,
MXINT4QuantLinear,
HadamardMXFP4QuantLinear,
)

from auto_round.experimental.qmodules.nvfp4 import NVFP4QuantLinear
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear
32 changes: 13 additions & 19 deletions auto_round/experimental/qmodules/mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ def __init__(
)
self.register_buffer("weight_scale", init_weight_scale)

hadamard_config = getattr(config, "hadamard_config", None)
# TODO: remove the limit: hadamard_config["hadamard_type"] == "random_hadamard"
if hadamard_config is not None and hadamard_config["hadamard_type"] == "random_hadamard":
self.enable_transform = True
self.register_buffer(
"hadamard_matrix",
torch.empty(
self.group_size,
self.group_size,
dtype=self.dtype,
),
)

def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
"""
Initialize weights. This method should be overridden by subclasses.
Expand Down Expand Up @@ -238,25 +251,6 @@ def from_original(cls, config: Optional[QuantizationScheme], original_layer: tor
return qdq_linear


class HadamardMXFP4QuantLinear(MXFP4QuantLinear):
"""
Quantized linear layer using the MXFP4 quantization scheme.
"""

def __init__(self, *args, **kwargs):
self.weight_name = "weight_packed"
super().__init__(*args, **kwargs)
self.enable_transform = True
self.register_buffer(
"hadamard_matrix",
torch.empty(
self.group_size,
self.group_size,
dtype=self.dtype,
),
)


class MXFP8QuantLinear(MXQuantLinearBase):
"""
Quantized linear layer using the MXFP8 quantization scheme.
Expand Down
13 changes: 13 additions & 0 deletions auto_round/experimental/qmodules/nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ def __init__(
),
)

hadamard_config = getattr(config, "hadamard_config", None)
# TODO: remove the limit: hadamard_config["hadamard_type"] == "random_hadamard"
if hadamard_config is not None and hadamard_config["hadamard_type"] == "random_hadamard":
self.enable_transform = True
self.register_buffer(
"hadamard_matrix",
torch.empty(
self.group_size,
self.group_size,
dtype=self.dtype,
),
)

@staticmethod
def _convert_global_scale_to_float32(state_dict: dict[str, torch.Tensor], name: str):
if name not in state_dict or state_dict[name].dtype == torch.float32:
Expand Down
89 changes: 42 additions & 47 deletions auto_round/experimental/transform/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import tqdm

from auto_round.experimental.qmodules.mx import MXQuantLinearBase
from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.experimental.transform.hadamard_config import HadamardConfig
from auto_round.experimental.transform.hadamards import build_hadamard_transform
from auto_round.experimental.utils import is_triton_kernel_available, normalize_hadamard_config
Expand All @@ -15,10 +15,10 @@
def apply_hadamard_transform(
model: torch.nn.Module,
config: str | dict | HadamardConfig | None,
need_calibration: bool = False,
location: str = "weight",
use_tqdm=True,
desc=None,
data_type="mx_fp",
):
"""
Apply a transform configuration to a model.
Expand All @@ -29,7 +29,6 @@ def apply_hadamard_transform(
:param model: Model to which the transform configuration will be applied.
:param config: Transform configuration to apply. Supported values are:
* ``str``: A named/preset transform configuration. In this case,
``scheme`` is typically required so that the preset can be
resolved to a concrete quantization/transform configuration.
* ``dict``: A raw configuration mapping that will be normalized
(via :func:`normalize_hadamard_config`) and then passed to
Expand All @@ -39,35 +38,31 @@ def apply_hadamard_transform(
normalization.
* ``None``: Uses the default behavior of
:func:`_normalize_hadamard_config` (for example, inferring a
configuration from ``scheme`` or other project defaults), if
configuration from ``data_type`` or other project defaults), if
supported.
:param scheme: Optional quantization/transform scheme identifier used
when ``config`` is a ``str`` (and, if supported, when it is
``None``) to determine which concrete configuration to build.
Ignored when ``config`` is already a ``dict`` or
:class:`TransformConfig`.
:param data_type: quantization data type.
:param use_tqdm: If ``True``, wrap the per-module application in a
tqdm progress bar.
:param desc: Optional description string to show in the tqdm progress
bar. If ``None``, a description will be derived from
``config.transform_type``.
"""

config = normalize_hadamard_config(config)
config = normalize_hadamard_config(config, data_type)
if not isinstance(config, HadamardConfig):
config = HadamardConfig(**config)

modules_config = [
(name, module, config)
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear) or isinstance(module, MXQuantLinearBase)
if isinstance(module, torch.nn.Linear) or isinstance(module, QModuleBase)
]

desc = f"Applying {config.hadamard_type} transforms" if desc is None else desc
for name, module, config in tqdm.tqdm(modules_config, desc=desc, disable=(not use_tqdm)):
if "lm_head" in name:
continue
_apply_to_module(model, module, config, need_calibration, location)
_apply_to_module(model, module, config, location, data_type)

# attach config to model for compression/serialization
setattr(model, "hadamard_config", config)
Expand All @@ -79,8 +74,8 @@ def _apply_to_module(
model: torch.nn.Module,
module: torch.nn.Module,
config: HadamardConfig,
need_calibration: bool = False,
location: str = "weight",
data_type: str = "mx_fp",
):
"""
Create transforms and apply them to the module
Expand All @@ -96,33 +91,36 @@ def _apply_to_module(

# activation needs transpose
input_hadamard_transform = build_hadamard_transform(
**config.dict(),
**config.model_dump(),
location="input",
inverse=True,
device="cpu",
precision=module.dtype,
precision=module.dtype, # for online activation, the transform dtype maybe bfloat16/float16.
)

if config.hadamard_type != "random_hadamard":
hadamard_weight = input_hadamard_transform.weight
else:
hadamard_weight = None

if is_triton_kernel_available():
if is_triton_kernel_available(data_type):
from auto_round.experimental.transform.triton.mxfp4 import mxfp4_forward_kernel_wrapper

def input_hook(self, args):
input = args[0]
# transform(input)
orig_shape = input.shape
orig_dtype = input.dtype
x_flat = input.contiguous().flatten(end_dim=-2)
qdq_input, _ = mxfp4_forward_kernel_wrapper(
x_flat,
(
hadamard_weight if hadamard_weight is not None else self.hadamard_matrix.T
hadamard_weight.to(orig_dtype)
if hadamard_weight is not None
else self.hadamard_matrix.T.to(orig_dtype)
), # this matrix from w_transform, needs transpose
)
return qdq_input.reshape(orig_shape)
return qdq_input.reshape(orig_shape).to(orig_dtype)

# for fused transform + quantization kernel
module.pre_dequantized_input = True
Expand All @@ -135,13 +133,20 @@ def input_hook(self, args):
input = args[0]

ori_shape = input.shape
orig_dtype = input.dtype

if hadamard_weight is not None:
input = input.view(-1, hadamard_weight.shape[0])
return _multihead_matmul(input, hadamard_weight.to(input.device)).view(ori_shape)
return (
(_multihead_matmul(input, hadamard_weight.to(input.device).to(orig_dtype)))
.view(ori_shape)
.to(orig_dtype)
)
else:
Comment thread
lkk12014402 marked this conversation as resolved.
input = input.view(-1, self.hadamard_matrix.shape[0])
return _multihead_matmul(input, self.hadamard_matrix.T).view(ori_shape)
return (
(_multihead_matmul(input, self.hadamard_matrix.T.to(orig_dtype))).view(ori_shape).to(orig_dtype)
)

# for fused transform + quantization kernel
module.pre_dequantized_input = False
Expand All @@ -153,45 +158,35 @@ def input_hook(self, args):
assert hasattr(module, "weight")

weight_hadamard_transform = build_hadamard_transform(
**config.dict(),
**config.model_dump(),
location="weight",
device=module.weight.device,
precision=module.weight.dtype,
)

# need save random hadamard matrix needed when inference
if config.hadamard_type == "random_hadamard":
module.register_module(config.hadamard_type, weight_hadamard_transform)
# for saving transform weight
from auto_round.experimental.transform.patch_modules import patch_quantlinear

patch_quantlinear(config.hadamard_type)
patch_quantlinear(weight_hadamard_transform)

if need_calibration:
# for training, the weight changes with every forward pass
# for autoround tuning: patch wrapper linear qdq_weight func
from auto_round.experimental.transform.patch_modules import (
patch_wrapperlinear_to_apply_transform,
patch_wrapperwalayer_forward_to_apply_transform,
)

input_hadamard_transform = build_hadamard_transform(
**config.dict(),
location="input",
inverse=True,
device=module.weight.device,
precision=module.weight.dtype,
)
# for autoround tuning: weight not tuning
# for rtn: weight transformed before saving
from auto_round.experimental.transform.patch_modules import (
patch_wrapperlinear_to_apply_transform,
patch_wrapperwalayer_forward_to_apply_transform,
)

patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform)
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform)
input_hadamard_transform = build_hadamard_transform(
**config.model_dump(),
location="input",
inverse=True,
device=module.weight.device,
precision=module.weight.dtype, # for online activation, the transform dtype maybe bfloat16/float16.
)

else:
# transform is no longer needed (unfusing is not supported)
# delattr(module, transform_name)
# fuse transform into weight
with torch.no_grad():
getattr(module, "weight").copy_(weight_hadamard_transform(module.weight).to(module.weight.device))
patch_wrapperlinear_to_apply_transform(weight_hadamard_transform, input_hadamard_transform)
patch_wrapperwalayer_forward_to_apply_transform(input_hadamard_transform)

else:
# TODO: apply transform to output/q/k
Expand Down
36 changes: 31 additions & 5 deletions auto_round/experimental/transform/hadamards.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,25 @@ def __init__(
self,
block_size: int = 32,
device: torch.device = None,
precision: torch.dtype = None,
precision: torch.dtype = torch.float32,
Comment thread
lkk12014402 marked this conversation as resolved.
location: str = "weight",
module_type: type[torch.nn.Module] = torch.nn.Linear,
inverse: bool = False,
):
"""Initialize a Hadamard transform module.

Args:
block_size: Size of each Hadamard block. The input tensor is reshaped
to ``(-1, block_size)`` before applying the transform.
device: Device on which to create the Hadamard matrix.
precision: Data type used for the Hadamard matrix weights, using float32 as default.
location: Target location used by ``apply_transform_weight`` when
applying the transform.
module_type: Module type associated with the transform application,
typically ``torch.nn.Linear``.
inverse: Whether to build the inverse form of the transform.
"""

super().__init__()
self.size = block_size
self.scale = 1 / math.sqrt(self.size)
Expand All @@ -51,7 +65,7 @@ def _create_weight(
self,
size: int,
device: torch.device = None,
precision: torch.dtype = None,
precision: torch.dtype = torch.float32,
) -> torch.nn.Parameter:
data = deterministic_hadamard_matrix(size, precision, device) * self.scale
# TODO: implement SpinQuant, which rotation matrix is learnable
Expand All @@ -78,18 +92,30 @@ def forward(self, x: torch.Tensor):
class RandomHadamardTransform(HadamardTransform):
def __init__(
self,
*args,
block_size: int = 32,
device: torch.device = None,
precision: torch.dtype = None,
location: str = "weight",
module_type: type[torch.nn.Module] = torch.nn.Linear,
inverse: bool = False,
seed: int | None = None,
generator: torch.Generator | None = None,
**kwargs,
):
if generator is not None:
self.generator = generator
else:
self.generator = torch.Generator()
if seed is not None:
self.generator.manual_seed(seed)
super().__init__(*args, **kwargs)

super().__init__(
block_size=block_size,
device=device,
precision=precision,
location=location,
module_type=module_type,
inverse=inverse,
)

def _create_weight(
self,
Expand Down
Loading
Loading