auto_round/utils/model.py (68 changes: 44 additions & 24 deletions)

@@ -248,6 +248,27 @@ def _check_accelerate_version():
    )


+def _is_mxfp4_model(model_path: str) -> bool:
+    """Check if the model is quantized with MXFP4."""
+    supported_model_types = ["gpt_oss"]
+    try:
+        from transformers import AutoConfig
+
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        quant_config = getattr(config, "quantization_config", None)
+        if quant_config is None:
+            return False
+        quant_method = (
+            quant_config.get("quant_method", "")
+            if isinstance(quant_config, dict)
+            else getattr(quant_config, "quant_method", "")
+        )
+        model_type = getattr(config, "model_type", "")
+        return quant_method == "mxfp4" and model_type in supported_model_types
+    except Exception:

Contributor: Why use try/except here?

Contributor Author: Agreed, this should be changed to a more efficient approach.

+        return False
+
+
def llm_load_model(
    pretrained_model_name_or_path: str,
    platform: str = "hf",
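
Following up on the try/except comment above, one possible tightening is sketched here. It is illustrative only and not part of the PR; the narrowed exception types are an assumption about what AutoConfig raises for a missing or unreadable config.

def _is_mxfp4_model(model_path: str) -> bool:
    """Return True if the checkpoint declares an MXFP4 quantization_config (illustrative sketch)."""
    from transformers import AutoConfig

    try:
        # Guard only the config fetch; a missing or unreadable config means "not MXFP4",
        # while unexpected errors are allowed to surface instead of being swallowed.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    except (OSError, ValueError):
        return False
    quant_config = getattr(config, "quantization_config", None)
    quant_method = (
        quant_config.get("quant_method", "")
        if isinstance(quant_config, dict)
        else getattr(quant_config, "quant_method", "")
    )
    return quant_method == "mxfp4" and getattr(config, "model_type", "") == "gpt_oss"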

@@ -289,42 +310,41 @@ def llm_load_model(
    if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code:
        logger.warning("trust_remote_code is enabled by default, please ensure its correctness.")

+    # Build common kwargs for from_pretrained
+    load_kwargs = {
+        "torch_dtype": torch_dtype,
+        "trust_remote_code": trust_remote_code,
+        "device_map": "auto" if use_auto_mapping else None,
+    }
+
+    # Check if model is MXFP4 quantized and needs dequantization
+    # Only set quantization_config when explicitly needed, to avoid overriding model's built-in config
+    if _is_mxfp4_model(pretrained_model_name_or_path):

Contributor: I have a small concern that this check might slow down the Auto-round initialization. Could you please double-check it? Thanks!

Contributor Author: Agreed, thanks!
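
One way to keep the check cheap, sketched as an illustration rather than a final design: read config.json directly and cache the result instead of instantiating an AutoConfig on every call. The helper name is hypothetical, and remote repo IDs are not handled here; they would still need AutoConfig or a hub download.

import json
import os
from functools import lru_cache


@lru_cache(maxsize=None)
def _read_quant_info(model_path: str) -> tuple:
    """Return (quant_method, model_type) read straight from a local config.json, or ("", "")."""
    config_file = os.path.join(model_path, "config.json")
    if not os.path.isfile(config_file):
        return "", ""
    with open(config_file, encoding="utf-8") as f:
        cfg = json.load(f)
    quant_cfg = cfg.get("quantization_config") or {}
    return quant_cfg.get("quant_method", ""), cfg.get("model_type", "")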

+        try:

Contributor: In my opinion, I would prefer a version check here instead of try/except. Using too many try/except blocks might prevent some bugs from being exposed.

+            from transformers import Mxfp4Config
+
+            load_kwargs["quantization_config"] = Mxfp4Config(dequantized=True)
+            logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.")
+        except ImportError:
+            logger.warning("Mxfp4Config not available in current transformers version, loading without dequantization.")
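
A version-gated variant of the block above, along the lines of the comment on `try:`. The 4.55.0 minimum is an assumption that should be verified against the transformers release that introduced Mxfp4Config, and `load_kwargs`/`logger` are the names from the surrounding diff rather than new definitions.

from packaging import version
import transformers

# Assumed minimum version; verify against the release that introduced Mxfp4Config.
_MIN_MXFP4_TRANSFORMERS = version.parse("4.55.0")

if version.parse(transformers.__version__) >= _MIN_MXFP4_TRANSFORMERS:
    from transformers import Mxfp4Config

    load_kwargs["quantization_config"] = Mxfp4Config(dequantized=True)
    logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.")
else:
    logger.warning(
        f"transformers {transformers.__version__} has no Mxfp4Config, loading without dequantization."
    )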

    if _use_hpu_compile_mode():
-        model = model_cls.from_pretrained(
-            pretrained_model_name_or_path,
-            torch_dtype=torch_dtype,
-            trust_remote_code=trust_remote_code,
-            device_map="auto" if use_auto_mapping else None,
-        )
+        model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)

Contributor: We currently don't have enough test coverage for HPU, so please make any changes carefully. If possible, adding more UTs would be really helpful!

    else:
        try:
-            model = model_cls.from_pretrained(
-                pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                trust_remote_code=trust_remote_code,
-                device_map="auto" if use_auto_mapping else None,
-            )
+            model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
        except ValueError as e:
            if "FP8 quantized" in str(e):
                with override_cuda_device_capability():
-                    model = model_cls.from_pretrained(
-                        pretrained_model_name_or_path,
-                        torch_dtype=torch_dtype,
-                        trust_remote_code=trust_remote_code,
-                        device_map="auto" if use_auto_mapping else None,
-                    )
+                    model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
                logger.warning("the support for fp8 model as input is experimental, please use with caution.")
            else:
                raise

        except OSError as e:
            logger.warning(f"fail to load {pretrained_model_name_or_path}, set trust_remote_code to False and retry.")
-            model = model_cls.from_pretrained(
-                pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                trust_remote_code=False,
-                device_map="auto" if use_auto_mapping else None,
-            )
+            load_kwargs["trust_remote_code"] = False
+            model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)

    model = model.eval()
    check_and_mark_quantized_module(model)
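
Following up on the test-coverage comment above, a minimal unit test for the new helper might look like the sketch below. The test names and the mocked configs are illustrative, and the HPU branch itself is not exercised.

from types import SimpleNamespace
from unittest.mock import patch

from auto_round.utils.model import _is_mxfp4_model


def test_is_mxfp4_model_detects_gpt_oss():
    fake_config = SimpleNamespace(
        quantization_config={"quant_method": "mxfp4"},
        model_type="gpt_oss",
    )
    # Patch the classmethod so no config is downloaded during the test.
    with patch("transformers.AutoConfig.from_pretrained", return_value=fake_config):
        assert _is_mxfp4_model("openai/gpt-oss-20b") is True


def test_is_mxfp4_model_rejects_non_mxfp4():
    fake_config = SimpleNamespace(quantization_config=None, model_type="gpt_oss")
    with patch("transformers.AutoConfig.from_pretrained", return_value=fake_config):
        assert _is_mxfp4_model("some/bf16-model") is False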

test/test_cpu/models/test_moe_model.py (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@

from ...helpers import get_model_path, transformers_version

-gpt_oss_name_or_path = get_model_path("unsloth/gpt-oss-20b-BF16")
+gpt_oss_name_or_path = get_model_path("openai/gpt-oss-20b")
llama4_name_or_path = get_model_path("meta-llama/Llama-4-Scout-17B-16E-Instruct")
qwen3_vl_moe_name_or_path = get_model_path("Qwen/Qwen3-VL-30B-A3B-Instruct")
# local path for debug

test/test_cuda/models/test_moe_model.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@
@pytest.fixture
def setup_gpt_oss():
    """Fixture to set up the GPT-OSS model and tokenizer."""
-    model_name = "/models/gpt-oss-20b-BF16"
+    model_name = "openai/gpt-oss-20b"

Contributor: This path is currently used to load the BF16 gpt-oss model, so please keep it as is. You can add a new path specifically for the MXFP4 model.
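
A sketch of the split suggested in the comment above; the MXFP4 path is a hypothetical placeholder rather than an existing location.

# Keep the existing local BF16 checkpoint for the current fixture ...
gpt_oss_bf16_name_or_path = "/models/gpt-oss-20b-BF16"
# ... and add a separate entry for the MXFP4 checkpoint (hypothetical local mirror of openai/gpt-oss-20b).
gpt_oss_mxfp4_name_or_path = "/models/gpt-oss-20b-MXFP4"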

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.num_hidden_layers = 1  # Reduce layers for testing