3 changes: 3 additions & 0 deletions auto_round/export/export_to_autoround/export_to_fp8.py
@@ -155,6 +155,7 @@ def save_quantized_as_autoround(
     tokenizer: Callable = None,
     layer_config: dict = None,
     inplace: bool = True,
+    backend: str = None,
     device: Union[str, torch.device] = "cpu",
     serialization_dict: dict = None,
     **kwargs,
@@ -165,6 +166,8 @@ def save_quantized_as_autoround(
     quantization_config = serialization_dict
     quantization_config["block_name_to_quantize"] = quantization_config.pop("to_quant_block_names", None)
     quantization_config["quant_method"] = "auto-round"
+    if backend:
+        quantization_config["packing_format"] = backend
     if "e5m2" in serialization_dict.get("data_type", "fp8"):
         quantization_config["fmt"] = "e5m2"
     else:
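Note: the new `backend` argument lets `save_quantized_as_autoround` record a packing format in the exported quantization config. A minimal sketch of that branch in isolation, assuming nothing beyond the keys shown in the diff (the helper name `apply_packing_format` and the example dict are illustrative, not part of the repository):

# Minimal sketch of the packing_format branch; apply_packing_format and the
# example dict are illustrative, only the keys mirror the diff above.
def apply_packing_format(quantization_config: dict, backend: str = None) -> dict:
    quantization_config["quant_method"] = "auto-round"
    if backend:  # only recorded when the caller supplies a backend string
        quantization_config["packing_format"] = backend
    return quantization_config

cfg = apply_packing_format({"data_type": "fp8"}, backend="auto_round:fp8_static")
assert cfg["packing_format"] == "auto_round:fp8_static"
assert "packing_format" not in apply_packing_format({"data_type": "fp8"})

When `backend` is left at None the `packing_format` key is simply not written, so existing FP8 exports keep their previous config layout.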
2 changes: 1 addition & 1 deletion auto_round/formats.py
@@ -1086,7 +1086,7 @@ def save_quantized(
     elif serialization_dict.get("data_type", "int") == "fp" and serialization_dict.get("bits", 16) == 8:
         from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

-        backend = "auto_round"
+        backend = "auto_round:fp8_static" if serialization_dict.get("act_bits", 16) == 8 else None
         export_func = save_quantized_as_autoround
     else:
         from auto_round.export.export_to_autoround.export import save_quantized_as_autoround
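Note: with this change only FP8 checkpoints that also use static 8-bit activations get tagged with the `auto_round:fp8_static` packing format; weight-only FP8 passes `backend=None`, so no `packing_format` key is written. A hedged restatement of the selection logic as a standalone function (the name `select_fp8_backend` is illustrative, not a function in the repository):

# Hedged restatement of the backend selection above; select_fp8_backend is an
# illustrative name, not part of formats.py.
def select_fp8_backend(serialization_dict: dict):
    # 8-bit "fp" weights: static 8-bit activations get an explicit packing
    # format, weight-only FP8 leaves the backend unset.
    if serialization_dict.get("data_type", "int") == "fp" and serialization_dict.get("bits", 16) == 8:
        return "auto_round:fp8_static" if serialization_dict.get("act_bits", 16) == 8 else None
    return None

assert select_fp8_backend({"data_type": "fp", "bits": 8, "act_bits": 8}) == "auto_round:fp8_static"
assert select_fp8_backend({"data_type": "fp", "bits": 8}) is None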
27 changes: 19 additions & 8 deletions test/test_cuda/schemes/test_scheme.py
@@ -2,6 +2,7 @@
 import shutil

 import pytest
+import transformers
 from packaging import version

 from auto_round import AutoRound
@@ -36,20 +37,26 @@ def test_gguf(self, tiny_qwen_model_path):
     def test_w4a16(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="W4A16", nsamples=1, iters=1)
         assert ar.bits == 4
-        ar.quantize()
+        ar.quantize_and_save()
+        model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
+        assert model is not None, "Model loading failed after quantization with W4A16 scheme"

     def test_w2a16(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="W2A16", nsamples=1, iters=1)
         assert ar.bits == 2
-        ar.quantize()
+        ar.quantize_and_save()
+        model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
+        assert model is not None, "Model loading failed after quantization with W2A16 scheme"

     def test_mxfp4(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="MXFP4_RCEIL", nsamples=1, iters=1)
         assert ar.bits == 4
         assert ar.act_bits == 4
         assert ar.data_type == "mx_fp"
         assert ar.act_data_type == "mx_fp_rceil"
-        ar.quantize()
+        ar.quantize_and_save()
+        model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
+        assert model is not None, "Model loading failed after quantization with MXFP4 scheme"

     def test_fp8_static(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="FP8_STATIC", nsamples=1, iters=1)
@@ -59,21 +66,23 @@ def test_fp8_static(self, tiny_opt_model_path):
         assert ar.act_data_type == "fp"
         assert ar.group_size == -1
         assert ar.act_dynamic is False
-        ar.quantize()
+        ar.quantize_and_save()
+        model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
+        assert model is not None, "Model loading failed after quantization with FP8_STATIC scheme"

     ## RTN tests
     def test_w2a16_rtn(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="W2A16", nsamples=1, iters=0)
         assert ar.bits == 2
-        ar.quantize()
+        ar.quantize_and_save()

     def test_mxfp4_rtn(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="MXFP4", nsamples=1, iters=0)
         assert ar.bits == 4
         assert ar.act_bits == 4
         assert ar.data_type == "mx_fp"
         assert ar.act_data_type == "mx_fp"
-        ar.quantize()
+        ar.quantize_and_save()

     def test_fp8_static_rtn(self, tiny_opt_model_path):
         ar = AutoRound(tiny_opt_model_path, scheme="FP8_STATIC", nsamples=1, iters=0)
@@ -83,7 +92,7 @@ def test_fp8_static_rtn(self, tiny_opt_model_path):
         assert ar.act_data_type == "fp"
         assert ar.group_size == -1
         assert ar.act_dynamic is False
-        ar.quantize()
+        ar.quantize_and_save()

     def test_scheme_in_layer_config(self):
         model_path = get_model_path("facebook/opt-125m")
@@ -94,7 +103,9 @@ def test_scheme_in_layer_config(self):
         }
         ar = AutoRound(model_path, scheme="W3A16", nsamples=1, iters=1, layer_config=layer_config)

-        ar.quantize()
+        ar.quantize_and_save()
+        model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
+        assert model is not None, "Model loading failed after quantization with layer-specific schemes"
         for n, m in ar.model.named_modules():
             if n == "model.decoder.layers.2.self_attn.q_proj":
                 assert m.bits == 2
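Note: the updated tests now exercise the full export-and-reload path instead of stopping after quantization. A sketch of that flow outside pytest, assuming a local facebook/opt-125m checkpoint and the default tmp_autoround output directory the tests read back from:

# End-to-end sketch mirroring the updated tests; "facebook/opt-125m" and the
# tmp_autoround directory come from the test file, everything else is limited
# to calls shown in the diff.
import shutil

import transformers

from auto_round import AutoRound

ar = AutoRound("facebook/opt-125m", scheme="FP8_STATIC", nsamples=1, iters=1)
ar.quantize_and_save()  # quantize and write the checkpoint (the tests reload it from tmp_autoround)

# Reload the exported checkpoint to confirm it deserializes cleanly.
model = transformers.AutoModelForCausalLM.from_pretrained("tmp_autoround", trust_remote_code=True)
assert model is not None, "Model loading failed after FP8_STATIC quantization"

shutil.rmtree("tmp_autoround", ignore_errors=True)  # remove the temporary export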