diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py
index 7f4ebcf1..734f344e 100644
--- a/examples/quantization/brevitas/quantize_llm.py
+++ b/examples/quantization/brevitas/quantize_llm.py
@@ -81,7 +81,9 @@ def main(args):
 
     # Export to ONNX through optimum.exporters.
     export_manager = StdQCDQONNXManager
-    export_manager.change_weight_export(export_weight_q_node=True)
+    if args.qdq_weights:
+        export_manager.change_weight_export(export_weight_q_node=True)
+
     with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager):
         onnx_export_from_model(
             quantized_model,
@@ -154,11 +156,17 @@ def main(args):
         default="auto",
         help='Device to run the example on (e.g., "cpu", "cuda:0", "auto"). "auto" will automatically select the device using HuggingFace Accelerate (choices: [%(choices)s], default: %(default)s).',
     )
+    parser.add_argument(
+        "--qdq-weights",
+        action="store_true",
+        default=False,
+        help="In the ONNX export, save quantized weights as float32 and insert an additional QuantizeLinear node, TensorRT style (default: %(default)s).",
+    )
     parser.add_argument(
         "--onnx-output-path",
         type=str,
         default="llm_quantized_onnx",
-        help="Location to store the output ONNX model (default: %(default)s)",
+        help="Location to store the output ONNX model (default: %(default)s).",
     )
 
     args = parser.parse_args()
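Note: for reference, the pattern the example now gates behind --qdq-weights looks roughly like the sketch below. This is a minimal sketch, assuming a Brevitas-quantized quantized_model and that onnx_export_from_model comes from optimum.exporters.onnx as in the example script; the helper name export_quantized_llm is hypothetical.

import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
from optimum.exporters.onnx import onnx_export_from_model


def export_quantized_llm(quantized_model, output_dir, qdq_weights=False):
    # Default (DQ): weights are serialized as int8/uint8 constants feeding a
    # DequantizeLinear node. With qdq_weights=True (QDQ, TensorRT style), the
    # weights stay in float32 and an extra QuantizeLinear node is inserted.
    export_manager = StdQCDQONNXManager
    if qdq_weights:
        export_manager.change_weight_export(export_weight_q_node=True)

    with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager):
        onnx_export_from_model(quantized_model, output_dir)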
diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py
index 260b7ab5..b0ffcd8d 100644
--- a/tests/brevitas/test_onnx_export.py
+++ b/tests/brevitas/test_onnx_export.py
@@ -7,10 +7,12 @@
 from pathlib import Path
 from typing import Dict
 
+import numpy as np
 import onnx
 import torch
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
+from onnx import numpy_helper
 from parameterized import parameterized
 from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model
 
@@ -49,14 +51,28 @@ def _get_models_to_test(export_models_dict: Dict, library_name: str = "transformers"):
             library_name=library_name,
         )
-        models_to_test.append((f"{model_type}_{task}", model_type, model_name, task, onnx_config_constructor))
+        models_to_test.append(
+            (f"{model_type}_{task}_DQ", model_type, model_name, task, onnx_config_constructor, False)
+        )
+        models_to_test.append(
+            (f"{model_type}_{task}_QDQ", model_type, model_name, task, onnx_config_constructor, True)
+        )
 
     return sorted(models_to_test)
 
 
 def export_and_validate(
-    model: torch.nn.Module, task: str, export_output_dir: str, onnx_config_class_constructor, shapes_to_validate: Dict
+    model: torch.nn.Module,
+    task: str,
+    export_output_dir: str,
+    onnx_config_class_constructor,
+    shapes_to_validate: Dict,
+    qdq_weights: bool,
 ):
-    with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=StdQCDQONNXManager):
+    export_manager = StdQCDQONNXManager
+    if qdq_weights:
+        export_manager.change_weight_export(export_weight_q_node=True)
+
+    with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager):
         library_name = TasksManager._infer_library_from_model(model)
         framework = "pt"
         dtype = get_parameter_dtype(model) if framework == "pt" else model.dtype
@@ -121,12 +137,7 @@
 class TestOnnxExport(unittest.TestCase):
     @parameterized.expand(_get_models_to_test(SUPPORTED_MODELS_TINY))
     def test_dynamic_quantization(
-        self,
-        test_name,
-        model_type,
-        model_name,
-        task,
-        onnx_config_class_constructor,
+        self, test_name, model_type, model_name, task, onnx_config_class_constructor, qdq_weights: bool
     ):
         model = get_quantized_model(
             model_name,
@@ -144,10 +155,29 @@ def test_dynamic_quantization(
             export_output_dir=tmpdir,
             onnx_config_class_constructor=onnx_config_class_constructor,
             shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES,
+            qdq_weights=qdq_weights,
         )
 
         onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))
-        for node in onnx_model.graph.node:
-            # Check that we have MatmulInteger, etc.
-            pass
+        if qdq_weights:
+            # QDQ export stores weights in float32, so no int8/uint8 weight tensor should be serialized.
+            for node in onnx_model.graph.node:
+                if node.op_type == "Constant":
+                    for attrib in node.attribute:
+                        new_array = numpy_helper.to_array(attrib.t)
+                        if len(new_array.shape) >= 2 and new_array.dtype in [np.uint8, np.int8]:
+                            self.fail("Found an int8/uint8 serialized weight although QDQ export stores weights as float32")
+        else:
+            # DQ export serializes the quantized weights directly, so at least one int8/uint8 tensor must be present.
+            found_int_weight = False
+            for node in onnx_model.graph.node:
+                if node.op_type == "Constant":
+                    for attrib in node.attribute:
+                        new_array = numpy_helper.to_array(attrib.t)
+                        if len(new_array.shape) >= 2 and new_array.dtype in [np.uint8, np.int8]:
+                            found_int_weight = True
+                            break
+                if found_int_weight:
+                    break
+            self.assertTrue(found_int_weight, "Did not find an int8/uint8 serialized weight")
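Note: the graph inspection added above can also be useful outside the test suite. Below is a minimal sketch of the same check as a standalone helper, assuming only the onnx and numpy packages; the helper name has_serialized_int_weights is hypothetical.

import numpy as np
import onnx
from onnx import numpy_helper


def has_serialized_int_weights(onnx_path):
    # A >=2-D int8/uint8 Constant tensor is taken to be a serialized quantized
    # weight; quantization scales and zero points are 0-D or 1-D and are skipped.
    model = onnx.load(onnx_path)
    for node in model.graph.node:
        if node.op_type != "Constant":
            continue
        for attrib in node.attribute:
            if attrib.type != onnx.AttributeProto.TENSOR:
                continue  # skip value_ints / sparse_value variants of Constant
            array = numpy_helper.to_array(attrib.t)
            if array.ndim >= 2 and array.dtype in (np.int8, np.uint8):
                return True
    return False

Under the assumptions above, a DQ export should return True and a QDQ export False, matching the two parameterized test variants.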