From 81399284a648223f75e02370701257486212a1c4 Mon Sep 17 00:00:00 2001
From: Xin He
Date: Wed, 4 Feb 2026 08:04:05 +0000
Subject: [PATCH 1/2] support gpt-oss mxfp4 directly loading

---
 auto_round/utils/model.py               | 36 +++++++++++++++++++++++++
 test/test_cpu/models/test_moe_model.py  |  2 +-
 test/test_cuda/models/test_moe_model.py |  2 +-
 3 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index fe416cd97..8a6bdcea8 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -248,6 +248,27 @@ def _check_accelerate_version():
     )
 
 
+def _is_mxfp4_model(model_path: str) -> bool:
+    """Check if the model is quantized with MXFP4."""
+    supported_model_types = ["gpt_oss"]
+    try:
+        from transformers import AutoConfig
+
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        quant_config = getattr(config, "quantization_config", None)
+        if quant_config is None:
+            return False
+        quant_method = (
+            quant_config.get("quant_method", "")
+            if isinstance(quant_config, dict)
+            else getattr(quant_config, "quant_method", "")
+        )
+        model_type = getattr(config, "model_type", "")
+        return quant_method == "mxfp4" and model_type in supported_model_types
+    except Exception:
+        return False
+
+
 def llm_load_model(
     pretrained_model_name_or_path: str,
     platform: str = "hf",
@@ -289,12 +310,24 @@ def llm_load_model(
     if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code:
         logger.warning("trust_remote_code is enabled by default, please ensure its correctness.")
 
+    # Check if model is MXFP4 quantized and needs dequantization
+    quantization_config = None
+    if _is_mxfp4_model(pretrained_model_name_or_path):
+        try:
+            from transformers import Mxfp4Config
+
+            quantization_config = Mxfp4Config(dequantize=True)
+            logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantize=True) for loading.")
+        except ImportError:
+            logger.warning("Mxfp4Config not available in current transformers version, loading without dequantization.")
+
     if _use_hpu_compile_mode():
         model = model_cls.from_pretrained(
             pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
             device_map="auto" if use_auto_mapping else None,
+            quantization_config=quantization_config,
         )
     else:
         try:
@@ -303,6 +336,7 @@
                 torch_dtype=torch_dtype,
                 trust_remote_code=trust_remote_code,
                 device_map="auto" if use_auto_mapping else None,
+                quantization_config=quantization_config,
             )
         except ValueError as e:
             if "FP8 quantized" in str(e):
@@ -312,6 +346,7 @@
                     torch_dtype=torch_dtype,
                     trust_remote_code=trust_remote_code,
                     device_map="auto" if use_auto_mapping else None,
+                    quantization_config=quantization_config,
                 )
                 logger.warning("the support for fp8 model as input is experimental, please use with caution.")
             else:
                 raise
@@ -324,6 +359,7 @@
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 device_map="auto" if use_auto_mapping else None,
+                quantization_config=quantization_config,
             )
 
     model = model.eval()
diff --git a/test/test_cpu/models/test_moe_model.py b/test/test_cpu/models/test_moe_model.py
index 1125ab56d..ffadaf4f9 100644
--- a/test/test_cpu/models/test_moe_model.py
+++ b/test/test_cpu/models/test_moe_model.py
@@ -10,7 +10,7 @@ from ...helpers import get_model_path, transformers_version
 
 
-gpt_oss_name_or_path = get_model_path("unsloth/gpt-oss-20b-BF16")
+gpt_oss_name_or_path = get_model_path("openai/gpt-oss-20b")
 llama4_name_or_path = get_model_path("meta-llama/Llama-4-Scout-17B-16E-Instruct")
get_model_path("meta-llama/Llama-4-Scout-17B-16E-Instruct") qwen3_vl_moe_name_or_path = get_model_path("Qwen/Qwen3-VL-30B-A3B-Instruct") # local path for debug diff --git a/test/test_cuda/models/test_moe_model.py b/test/test_cuda/models/test_moe_model.py index 40c545015..2ef8b5a3d 100644 --- a/test/test_cuda/models/test_moe_model.py +++ b/test/test_cuda/models/test_moe_model.py @@ -15,7 +15,7 @@ @pytest.fixture def setup_gpt_oss(): """Fixture to set up the GPT-OSS model and tokenizer.""" - model_name = "/models/gpt-oss-20b-BF16" + model_name = "openai/gpt-oss-20b" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) config.num_hidden_layers = 1 # Reduce layers for testing From 65ae337b25488b9838ee60da945b9f744a5e165d Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Fri, 6 Feb 2026 11:10:08 +0800 Subject: [PATCH 2/2] update with load_kwargs Signed-off-by: He, Xin3 --- auto_round/utils/model.py | 44 +++++++++++++-------------------------- 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 8a6bdcea8..2f4f4d9dc 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -310,57 +310,41 @@ def llm_load_model( if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code: logger.warning("trust_remote_code is enabled by default, please ensure its correctness.") + # Build common kwargs for from_pretrained + load_kwargs = { + "torch_dtype": torch_dtype, + "trust_remote_code": trust_remote_code, + "device_map": "auto" if use_auto_mapping else None, + } + # Check if model is MXFP4 quantized and needs dequantization - quantization_config = None + # Only set quantization_config when explicitly needed, to avoid overriding model's built-in config if _is_mxfp4_model(pretrained_model_name_or_path): try: from transformers import Mxfp4Config - quantization_config = Mxfp4Config(dequantized=True) + load_kwargs["quantization_config"] = Mxfp4Config(dequantized=True) logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.") except ImportError: logger.warning("Mxfp4Config not available in current transformers version, loading without dequantization.") if _use_hpu_compile_mode(): - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - quantization_config=quantization_config, - ) + model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs) else: try: - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - quantization_config=quantization_config, - ) + model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs) except ValueError as e: if "FP8 quantized" in str(e): with override_cuda_device_capability(): - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - quantization_config=quantization_config, - ) + model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs) logger.warning("the support for fp8 model as input is experimental, please use with caution.") else: raise except OSError as e: logger.warning(f"fail to load 
-            model = model_cls.from_pretrained(
-                pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                trust_remote_code=False,
-                device_map="auto" if use_auto_mapping else None,
-                quantization_config=quantization_config,
-            )
+            load_kwargs["trust_remote_code"] = False
+            model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
 
     model = model.eval()
     check_and_mark_quantized_module(model)
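
Usage sketch (appended for review, not part of either patch): a minimal example of the loading behavior this series enables, assuming a transformers release that ships Mxfp4Config; the model id and bf16 dtype are illustrative, and the detection logic mirrors _is_mxfp4_model above.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, Mxfp4Config

    model_id = "openai/gpt-oss-20b"
    # MXFP4 checkpoints carry a quantization_config with quant_method == "mxfp4" in config.json.
    cfg = AutoConfig.from_pretrained(model_id)
    quant = getattr(cfg, "quantization_config", None) or {}
    is_mxfp4 = isinstance(quant, dict) and quant.get("quant_method") == "mxfp4"

    # Dequantize the MXFP4 weights to bf16 at load time, mirroring what
    # llm_load_model now does automatically for gpt_oss models.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        quantization_config=Mxfp4Config(dequantize=True) if is_mxfp4 else None,
    )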