support gpt-oss mxfp4 directly loading #1401
@@ -248,6 +248,27 @@ def _check_accelerate_version():
         )

+def _is_mxfp4_model(model_path: str) -> bool:
+    """Check if the model is quantized with MXFP4."""
+    supported_model_types = ["gpt_oss"]
+    try:
+        from transformers import AutoConfig
+
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        quant_config = getattr(config, "quantization_config", None)
+        if quant_config is None:
+            return False
+        quant_method = (
+            quant_config.get("quant_method", "")
+            if isinstance(quant_config, dict)
+            else getattr(quant_config, "quant_method", "")
+        )
+        model_type = getattr(config, "model_type", "")
+        return quant_method == "mxfp4" and model_type in supported_model_types
+    except Exception:
Contributor: Why use try/except here?

Contributor (author): Agreed, this should be replaced with a more efficient approach.
+        return False
+
+
 def llm_load_model(
     pretrained_model_name_or_path: str,
     platform: str = "hf",

@@ -289,42 +310,41 @@ def llm_load_model(
     if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code:
         logger.warning("trust_remote_code is enabled by default, please ensure its correctness.")

+    # Build common kwargs for from_pretrained
+    load_kwargs = {
+        "torch_dtype": torch_dtype,
+        "trust_remote_code": trust_remote_code,
+        "device_map": "auto" if use_auto_mapping else None,
+    }
+
+    # Check if model is MXFP4 quantized and needs dequantization
+    # Only set quantization_config when explicitly needed, to avoid overriding model's built-in config
+    if _is_mxfp4_model(pretrained_model_name_or_path):
Contributor: I have a small concern that this check might slow down Auto-Round initialization.

Contributor (author): Agreed, thanks!
+        try:
Contributor: In my opinion, I would prefer to check the version instead of using try/except. Too many try/except blocks might prevent some bugs from being exposed.
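A version gate along the lines the reviewer suggests might look roughly like the helper below; the helper name is hypothetical and the 4.55.0 floor is an assumption to verify against the transformers release notes (Mxfp4Config shipped alongside gpt-oss support):

```python
# Hypothetical helper, not part of this PR: gate Mxfp4Config usage on the
# installed transformers version instead of catching ImportError.
from packaging import version

import transformers


def _supports_mxfp4_config() -> bool:
    # Assumption: Mxfp4Config first appeared in transformers 4.55.0.
    return version.parse(transformers.__version__) >= version.parse("4.55.0")
```

With such a check in place, the branch below could set load_kwargs["quantization_config"] only when _supports_mxfp4_config() returns True and drop the ImportError handler.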
+            from transformers import Mxfp4Config
+
+            load_kwargs["quantization_config"] = Mxfp4Config(dequantized=True)
+            logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.")
+        except ImportError:
+            logger.warning("Mxfp4Config not available in current transformers version, loading without dequantization.")

     if _use_hpu_compile_mode():
-        model = model_cls.from_pretrained(
-            pretrained_model_name_or_path,
-            torch_dtype=torch_dtype,
-            trust_remote_code=trust_remote_code,
-            device_map="auto" if use_auto_mapping else None,
-        )
+        model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
Contributor: We currently don't have enough test coverage for HPU, so please make any changes carefully. If possible, adding more UTs would be really helpful!
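A cheap CPU-only unit test for the new detection helper could be one such addition. A sketch under assumptions: the helper is importable from auto_round.utils (adjust to the real module path) and AutoConfig.from_pretrained can be patched for the check:

```python
# Sketch only: CPU-only tests for _is_mxfp4_model; the auto_round.utils import
# path is an assumption about where the helper lives in this repo.
from types import SimpleNamespace
from unittest.mock import patch

from auto_round.utils import _is_mxfp4_model


def test_is_mxfp4_model_detects_gpt_oss():
    fake_config = SimpleNamespace(
        model_type="gpt_oss",
        quantization_config={"quant_method": "mxfp4"},
    )
    # _is_mxfp4_model imports AutoConfig lazily, so patching the classmethod works.
    with patch("transformers.AutoConfig.from_pretrained", return_value=fake_config):
        assert _is_mxfp4_model("openai/gpt-oss-20b") is True


def test_is_mxfp4_model_rejects_unquantized_model():
    fake_config = SimpleNamespace(model_type="llama", quantization_config=None)
    with patch("transformers.AutoConfig.from_pretrained", return_value=fake_config):
        assert _is_mxfp4_model("any/bf16-model") is False
```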
     else:
         try:
-            model = model_cls.from_pretrained(
-                pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                trust_remote_code=trust_remote_code,
-                device_map="auto" if use_auto_mapping else None,
-            )
+            model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
         except ValueError as e:
             if "FP8 quantized" in str(e):
                 with override_cuda_device_capability():
-                    model = model_cls.from_pretrained(
-                        pretrained_model_name_or_path,
-                        torch_dtype=torch_dtype,
-                        trust_remote_code=trust_remote_code,
-                        device_map="auto" if use_auto_mapping else None,
-                    )
+                    model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
                 logger.warning("the support for fp8 model as input is experimental, please use with caution.")
             else:
                 raise
         except OSError as e:
             logger.warning(f"fail to load {pretrained_model_name_or_path}, set trust_remote_code to False and retry.")
-            model = model_cls.from_pretrained(
-                pretrained_model_name_or_path,
-                torch_dtype=torch_dtype,
-                trust_remote_code=False,
-                device_map="auto" if use_auto_mapping else None,
-            )
+            load_kwargs["trust_remote_code"] = False
+            model = model_cls.from_pretrained(pretrained_model_name_or_path, **load_kwargs)

     model = model.eval()
     check_and_mark_quantized_module(model)
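For reference, the manual equivalent of what llm_load_model now does when it detects an MXFP4 gpt-oss checkpoint is roughly the following; the checkpoint id and dtype are illustrative, and the exact Mxfp4Config flag name should be verified against the installed transformers version (this PR passes dequantized=True):

```python
# Sketch only: loading an MXFP4 gpt-oss checkpoint dequantized to a high-precision
# dtype by hand, which is what the patched llm_load_model automates.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config

model_id = "openai/gpt-oss-20b"  # illustrative MXFP4 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantized=True),  # verify the flag name for your version
    device_map="auto",  # requires accelerate; drop for single-device loading
)
```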
@@ -15,7 +15,7 @@
 @pytest.fixture
 def setup_gpt_oss():
     """Fixture to set up the GPT-OSS model and tokenizer."""
-    model_name = "/models/gpt-oss-20b-BF16"
+    model_name = "openai/gpt-oss-20b"

Contributor: This path is currently used to load the BF16 gpt-oss model, so please keep it as is.
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     config.num_hidden_layers = 1  # Reduce layers for testing
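As an aside, the num_hidden_layers = 1 trick used here is a common way to keep such tests cheap; a minimal, hypothetical sketch of the pattern (not the repo's actual fixture):

```python
# Hypothetical sketch of the shrink-the-config test pattern; not the repo's fixture.
# Instantiating from a truncated config gives a one-layer model with random weights,
# so the test never has to load the full checkpoint.
from transformers import AutoConfig, AutoModelForCausalLM


def tiny_model_for_tests(model_name: str):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.num_hidden_layers = 1  # keep a single transformer block
    return AutoModelForCausalLM.from_config(config, trust_remote_code=True)
```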