From 38a9db7e0f8907201c4a65afa3ae07de468d8c85 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 14:54:31 +0000 Subject: [PATCH 1/7] Feat: export dq only --- examples/quantization/brevitas/quantize_llm.py | 16 ++-------------- optimum/amd/brevitas/export.py | 15 +++++++++++++++ setup.py | 2 +- 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 optimum/amd/brevitas/export.py diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index 7f4ebcf1..4276789d 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -1,13 +1,10 @@ from argparse import ArgumentParser -import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.amd.brevitas.export import export_quantized_model from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model -from optimum.exporters.onnx import onnx_export_from_model from transformers import AutoTokenizer @@ -80,16 +77,7 @@ def main(args): quantized_model = quantized_model.to("cpu") # Export to ONNX through optimum.exporters. - export_manager = StdQCDQONNXManager - export_manager.change_weight_export(export_weight_q_node=True) - with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager): - onnx_export_from_model( - quantized_model, - args.onnx_output_path, - task="text-generation-with-past", - do_validation=False, - no_post_process=True, - ) + export_quantized_model(quantized_model, args.onnx_output_path) return return_val diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py new file mode 100644 index 00000000..507cb827 --- /dev/null +++ b/optimum/amd/brevitas/export.py @@ -0,0 +1,15 @@ + +import torch +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.exporters.onnx import onnx_export_from_model + + +def export_quantized_model(quantized_model, path, task="text-generation-with-past"): + with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): + onnx_export_from_model( + quantized_model, + path, + task=task, + do_validation=False, + no_post_process=True) \ No newline at end of file diff --git a/setup.py b/setup.py index 5f87421a..fc5427ab 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ EXTRAS_REQUIRE = { "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, - "brevitas": ["brevitas", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"], + "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"], } setup( From 624baa31ecd51bada2f9c1ba087a336fd9a7b4ff Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 16:43:06 +0000 Subject: [PATCH 2/7] fix --- optimum/amd/brevitas/export.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 507cb827..ce71b887 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,15 +1,9 @@ - import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from optimum.exporters.onnx import onnx_export_from_model def export_quantized_model(quantized_model, path, task="text-generation-with-past"): with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): - onnx_export_from_model( - quantized_model, - path, - task=task, - do_validation=False, - no_post_process=True) \ No newline at end of file + onnx_export_from_model(quantized_model, path, task=task, do_validation=False, no_post_process=True) From 28965dfdc41ee2b2b10d2e8be31c991de937120a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 16:49:21 +0000 Subject: [PATCH 3/7] fix --- examples/quantization/brevitas/quantize_llm.py | 3 +-- optimum/amd/brevitas/export.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index 4276789d..d2cec92a 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -1,10 +1,9 @@ from argparse import ArgumentParser - -from optimum.amd.brevitas.export import export_quantized_model from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model +from optimum.amd.brevitas.export import export_quantized_model from transformers import AutoTokenizer diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index ce71b887..e7e2de3f 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,6 +1,7 @@ import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode + from optimum.exporters.onnx import onnx_export_from_model From 3afeddff0cf1450c3ca77e563320ac3298eb7ae9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Mar 2024 14:01:45 +0000 Subject: [PATCH 4/7] Code review --- .../quantization/brevitas/quantize_llm.py | 4 +- optimum/amd/brevitas/export.py | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index d2cec92a..e61fe8a5 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -3,7 +3,7 @@ from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model -from optimum.amd.brevitas.export import export_quantized_model +from optimum.amd.brevitas.export import onnx_export_from_quantized_model from transformers import AutoTokenizer @@ -76,7 +76,7 @@ def main(args): quantized_model = quantized_model.to("cpu") # Export to ONNX through optimum.exporters. - export_quantized_model(quantized_model, args.onnx_output_path) + onnx_export_from_quantized_model(quantized_model, args.onnx_output_path) return return_val diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index e7e2de3f..10f5defb 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -3,8 +3,45 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from optimum.exporters.onnx import onnx_export_from_model +from transformers.modeling_utils import PreTrainedModel +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from pathlib import Path +from optimum.exporters.onnx.base import OnnxConfig -def export_quantized_model(quantized_model, path, task="text-generation-with-past"): +def onnx_export_from_quantized_model( + quantized_model: Union["PreTrainedModel"], + output: Union[str, Path], + opset: Optional[int] = None, + optimize: Optional[str] = None, + monolith: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, + custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + fn_get_submodels: Optional[Callable] = None, + _variant: str = "default", + preprocessors: List = None, + device: str = "cpu", + no_dynamic_axes: bool = False, + use_subprocess: bool = False, + do_constant_folding: bool = True, + **kwargs_shapes): with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): - onnx_export_from_model(quantized_model, path, task=task, do_validation=False, no_post_process=True) + onnx_export_from_model( + quantized_model, + output, + opset=opset, + monolith=monolith, + optimize=optimize, + model_kwargs=model_kwargs, + custom_onnx_configs=custom_onnx_configs, + fn_get_submodels=fn_get_submodels, + _variant=_variant, + preprocessors=preprocessors, + device=device, + no_dynamic_axes=no_dynamic_axes, + use_subprocess=use_subprocess, + do_constant_folding=do_constant_folding, + task="text-generation-with-past", + do_validation=False, + no_post_process=True, + **kwargs_shapes) From 122f95da3245913924cf5d33379dd2a0c7fafc00 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Mar 2024 14:14:22 +0000 Subject: [PATCH 5/7] Docs: update documentation --- docs/source/brevitas/usage_guide.mdx | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/source/brevitas/usage_guide.mdx b/docs/source/brevitas/usage_guide.mdx index 1c105e6b..bd46bae7 100644 --- a/docs/source/brevitas/usage_guide.mdx +++ b/docs/source/brevitas/usage_guide.mdx @@ -74,14 +74,10 @@ Brevitas models can be exported to ONNX using Optimum: ```python import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.amd.brevitas.export import onnx_export_from_quantized_model # Export to ONNX through optimum.exporters. -with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=StdQCDQONNXManager): - onnx_export_from_model( - model, "llm_quantized_onnx", task="text-generation-with-past", do_validation=False, no_post_process=True - ) +onnx_export_from_quantized_model(model, "llm_quantized_onnx") ``` ## Complete example From 434533ae8338be8b0a7e7b62b42a349ba5e17da5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Mar 2024 16:10:03 +0000 Subject: [PATCH 6/7] Formatting --- optimum/amd/brevitas/export.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 10f5defb..56bafd77 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,12 +1,13 @@ +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from optimum.exporters.onnx import onnx_export_from_model -from transformers.modeling_utils import PreTrainedModel -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -from pathlib import Path from optimum.exporters.onnx.base import OnnxConfig +from transformers.modeling_utils import PreTrainedModel def onnx_export_from_quantized_model( @@ -24,10 +25,11 @@ def onnx_export_from_quantized_model( no_dynamic_axes: bool = False, use_subprocess: bool = False, do_constant_folding: bool = True, - **kwargs_shapes): + **kwargs_shapes, +): with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): onnx_export_from_model( - quantized_model, + quantized_model, output, opset=opset, monolith=monolith, @@ -41,7 +43,8 @@ def onnx_export_from_quantized_model( no_dynamic_axes=no_dynamic_axes, use_subprocess=use_subprocess, do_constant_folding=do_constant_folding, - task="text-generation-with-past", - do_validation=False, + task="text-generation-with-past", + do_validation=False, no_post_process=True, - **kwargs_shapes) + **kwargs_shapes, + ) From 848eaa17bda20ae297c0a857647e88fd951057ce Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 10:43:11 +0100 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Mohit Sharma --- optimum/amd/brevitas/export.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 56bafd77..50d1fe2a 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -23,6 +23,7 @@ def onnx_export_from_quantized_model( preprocessors: List = None, device: str = "cpu", no_dynamic_axes: bool = False, + task: str = "text-generation-with-past", use_subprocess: bool = False, do_constant_folding: bool = True, **kwargs_shapes, @@ -43,7 +44,7 @@ def onnx_export_from_quantized_model( no_dynamic_axes=no_dynamic_axes, use_subprocess=use_subprocess, do_constant_folding=do_constant_folding, - task="text-generation-with-past", + task=task, do_validation=False, no_post_process=True, **kwargs_shapes,