diff --git a/docs/source/brevitas/usage_guide.mdx b/docs/source/brevitas/usage_guide.mdx index 1c105e6b..bd46bae7 100644 --- a/docs/source/brevitas/usage_guide.mdx +++ b/docs/source/brevitas/usage_guide.mdx @@ -74,14 +74,10 @@ Brevitas models can be exported to ONNX using Optimum: ```python import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.amd.brevitas.export import onnx_export_from_quantized_model # Export to ONNX through optimum.exporters. -with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=StdQCDQONNXManager): - onnx_export_from_model( - model, "llm_quantized_onnx", task="text-generation-with-past", do_validation=False, no_post_process=True - ) +onnx_export_from_quantized_model(model, "llm_quantized_onnx") ``` ## Complete example diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index 7f4ebcf1..e61fe8a5 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -1,13 +1,9 @@ from argparse import ArgumentParser -import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode - from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model -from optimum.exporters.onnx import onnx_export_from_model +from optimum.amd.brevitas.export import onnx_export_from_quantized_model from transformers import AutoTokenizer @@ -80,16 +76,7 @@ def main(args): quantized_model = quantized_model.to("cpu") # Export to ONNX through optimum.exporters. 
- export_manager = StdQCDQONNXManager - export_manager.change_weight_export(export_weight_q_node=True) - with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager): - onnx_export_from_model( - quantized_model, - args.onnx_output_path, - task="text-generation-with-past", - do_validation=False, - no_post_process=True, - ) + onnx_export_from_quantized_model(quantized_model, args.onnx_output_path) return return_val diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py new file mode 100644 index 00000000..50d1fe2a --- /dev/null +++ b/optimum/amd/brevitas/export.py @@ -0,0 +1,51 @@ +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode + +from optimum.exporters.onnx import onnx_export_from_model +from optimum.exporters.onnx.base import OnnxConfig +from transformers.modeling_utils import PreTrainedModel + + +def onnx_export_from_quantized_model( + quantized_model: "PreTrainedModel", + output: Union[str, Path], + opset: Optional[int] = None, + optimize: Optional[str] = None, + monolith: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, + custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + fn_get_submodels: Optional[Callable] = None, + _variant: str = "default", + preprocessors: Optional[List] = None, + device: str = "cpu", + no_dynamic_axes: bool = False, + task: str = "text-generation-with-past", + use_subprocess: bool = False, + do_constant_folding: bool = True, + **kwargs_shapes, +): + with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): + onnx_export_from_model( + quantized_model, + output, + opset=opset, + monolith=monolith, + optimize=optimize, + model_kwargs=model_kwargs, + custom_onnx_configs=custom_onnx_configs, + 
fn_get_submodels=fn_get_submodels, + _variant=_variant, + preprocessors=preprocessors, + device=device, + no_dynamic_axes=no_dynamic_axes, + use_subprocess=use_subprocess, + do_constant_folding=do_constant_folding, + task=task, + do_validation=False, + no_post_process=True, + **kwargs_shapes, + ) diff --git a/setup.py b/setup.py index 5f87421a..fc5427ab 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ EXTRAS_REQUIRE = { "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, - "brevitas": ["brevitas", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"], + "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"], } setup(