diff --git a/.github/workflows/test_brevitas.yaml b/.github/workflows/test_brevitas.yaml
index 72faf513..5db21168 100644
--- a/.github/workflows/test_brevitas.yaml
+++ b/.github/workflows/test_brevitas.yaml
@@ -23,7 +23,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8]
+        python-version: [3.9]
         os: [ubuntu-20.04, windows-2019, macos-latest]
     runs-on: ${{ matrix.os }}
diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py
index 50d1fe2a..24525895 100644
--- a/optimum/amd/brevitas/export.py
+++ b/optimum/amd/brevitas/export.py
@@ -1,15 +1,301 @@
+import copy
+import logging
+import os
+import sys
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union
 
+import onnx
 import torch
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
+from onnx_tool import Model
+from onnx_tool.fusion import FusionPattern
+from onnx_tool.graph import Graph
+from onnx_tool.node import create_node
+from onnx_tool.tensor import Tensor
 from optimum.exporters.onnx import onnx_export_from_model
 from optimum.exporters.onnx.base import OnnxConfig
+from optimum.onnx.graph_transformations import check_and_save_model
 from transformers.modeling_utils import PreTrainedModel
 
+
+LOGGER = logging.getLogger(__name__)
+
+ONNX_FLOAT32_IDENTIFIER = 1  # onnx.TensorProto.FLOAT, used for the Cast `to` attribute
+
+# QDQ patterns to find and replace with MatMulInteger
+MATMUL = [
+    {
+        "name": "deq_linear_0",
+        "op": "DequantizeLinear",
+        "attrs": [],
+        "inport": [],
+        "outport": [[0, "transpose_0", 0]],
+    },
+    {
+        "name": "transpose_0",
+        "op": "Transpose",
+        "attrs": [],
+        "inport": [[0, "deq_linear_0", 0]],
+        "outport": [[0, "matmul_0", 1]],
+    },
+    {
+        "name": "quant_linear_1",
+        "op": "DynamicQuantizeLinear",
+        "attrs": [],
+        "inport": [],
+        "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]],
+    },
+    {
+        "name": "deq_linear_1",
+        "op": "DequantizeLinear",
+        "attrs": [],
+        "inport": [
+            [0, "quant_linear_1", 0],
+            [1, "quant_linear_1", 1],
+            [2, "quant_linear_1", 2],
+        ],
+        "outport": [[0, "matmul_0", 0]],
+    },
+    {
+        "name": "matmul_0",
+        "op": "MatMul",
+        "attrs": [],
+        "inport": [
+            [0, "deq_linear_1", 0],
+            [1, "transpose_0", 0],
+        ],
+        "outport": [],
+    },
+]
+
+GEMM = [
+    {
+        "name": "deq_linear_0",
+        "op": "DequantizeLinear",
+        "attrs": [],
+        "inport": [],
+        "outport": [[0, "gemm_0", 1]],
+    },
+    {
+        "name": "quant_linear_1",
+        "op": "DynamicQuantizeLinear",
+        "attrs": [],
+        "inport": [],
+        "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]],
+    },
+    {
+        "name": "deq_linear_1",
+        "op": "DequantizeLinear",
+        "attrs": [],
+        "inport": [
+            [0, "quant_linear_1", 0],
+            [1, "quant_linear_1", 1],
+            [2, "quant_linear_1", 2],
+        ],
+        "outport": [[0, "gemm_0", 0]],
+    },
+    {
+        "name": "gemm_0",
+        "op": "Gemm",
+        "attrs": [],
+        "inport": [
+            [0, "deq_linear_1", 0],
+            [1, "deq_linear_0", 0],
+        ],
+        "outport": [],
+    },
+]
+
+
+def create_nodes(graph: Graph, op: str, name: str, inputs: List[str], outputs: List[str], **kwargs):
+    newnode = onnx.helper.make_node(op, inputs, outputs, name=name, **kwargs)
+    newnode = create_node(newnode)
+    newnode.input = inputs
+    newnode.output = outputs
+    for i in inputs:
+        if i in graph.consumedby:
+            graph.consumedby[i].append(name)
+        if i in graph.producedby:
+            newnode.prevnodes.append(graph.producedby[i])
+    for o in outputs:
+        graph.producedby[o] = [name]
+        if o in graph.consumedby:
+            newnode.nextnodes.append(graph.consumedby[o])
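+    # Register the new node and a placeholder tensor entry under the node's name,
+    # so that onnx_tool's nodemap/tensormap bookkeeping stays consistent with the
+    # producer/consumer rewiring done above.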
+    graph.nodemap[name] = newnode
+    graph.tensormap[name] = Tensor(name)
+
+    return graph
+
+
+def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
+    for found_pattern in found_nodes:
+        node_count += 1
+
+        deq_linear = graph.nodemap[found_pattern[0]]
+        dyn_q = graph.nodemap[found_pattern[2]]
+        dq_weight = deq_linear.prevnodes[0]
+        graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose())
+        graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value)
+        graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value)
+
+        matmul = graph.nodemap[found_pattern[-1]]
+        for name in found_pattern:
+            if "DynamicQuantizeLinear" in name:
+                continue
+            graph.remove_node(name)
+
+        graph.remove_node(deq_linear.prevnodes[0].name)
+        graph.remove_node(deq_linear.prevnodes[1].name)
+        if deq_linear.prevnodes[2].name in graph.nodemap:
+            graph.remove_node(deq_linear.prevnodes[2].name)
+
+        graph = create_nodes(
+            graph,
+            "MatMulInteger",
+            f"matmul_integer_{node_count}",
+            [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"],
+            [f"matmul_integer_{node_count}"],
+        )
+        graph = create_nodes(
+            graph,
+            "Cast",
+            f"cast_{node_count}",
+            [f"matmul_integer_{node_count}"],
+            [f"cast_{node_count}"],
+            to=ONNX_FLOAT32_IDENTIFIER,
+        )
+        graph = create_nodes(
+            graph,
+            "Mul",
+            f"mulscales_{node_count}",
+            [dyn_q.output[1], f"dq_weights_1_{node_count}"],
+            [f"mulscales_{node_count}"],
+        )
+        graph = create_nodes(
+            graph,
+            "Mul",
+            f"mulvalues_{node_count}",
+            [f"mulscales_{node_count}", f"cast_{node_count}"],
+            [matmul.output[0]],
+        )
+    return graph
+
+
+def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
+    for found_pattern in found_nodes:
+        node_count += 1
+
+        gemm = graph.nodemap[found_pattern[-1]]
+        bias = gemm.input[-1]
+        deq_linear = graph.nodemap[found_pattern[0]]
+        dyn_q = graph.nodemap[found_pattern[1]]
+        dq_weight = deq_linear.prevnodes[0]
+        graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose())
+        graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value)
+        graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value)
+
+        matmul = graph.nodemap[found_pattern[-1]]
+        for name in found_pattern:
+            if "DynamicQuantizeLinear" in name:
+                continue
+            graph.remove_node(name)
+        graph.remove_node(deq_linear.prevnodes[0].name)
+        graph.remove_node(deq_linear.prevnodes[1].name)
+        if deq_linear.prevnodes[2].name in graph.nodemap:
+            graph.remove_node(deq_linear.prevnodes[2].name)
+
+        graph = create_nodes(
+            graph,
+            "MatMulInteger",
+            f"matmul_integer_{node_count}",
+            [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"],
+            [f"matmul_integer_{node_count}"],
+        )
+        graph = create_nodes(
+            graph,
+            "Cast",
+            f"cast_{node_count}",
+            [f"matmul_integer_{node_count}"],
+            [f"cast_{node_count}"],
+            to=ONNX_FLOAT32_IDENTIFIER,
+        )
+        graph = create_nodes(
+            graph,
+            "Mul",
+            f"mulscales_{node_count}",
+            [dyn_q.output[1], f"dq_weights_1_{node_count}"],
+            [f"mulscales_{node_count}"],
+        )
+        graph = create_nodes(
+            graph,
+            "Mul",
+            f"mulvalues_{node_count}",
+            [f"mulscales_{node_count}", f"cast_{node_count}"],
+            [f"mulvalues_{node_count}"],
+        )
+        graph = create_nodes(
+            graph, "Add", f"addbias_{node_count}", [bias, f"mulvalues_{node_count}"], [matmul.output[0]]
+        )
+    return graph
+
+
+def find_and_insert_matmulinteger(model_path: str):
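+    """Rewrite `model.onnx` under `model_path` in place, replacing the exported
+    QDQ patterns around MatMul/Gemm nodes with MatMulInteger subgraphs."""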
+    # onnx_tool requires Python 3.9+
+    if sys.version_info < (3, 9):
+        raise RuntimeError("onnx_tool requires Python 3.9 or higher")
+
+    LOGGER.info("Rewriting ONNX Graph with MatMulInteger")
+    model_path = os.path.join(model_path, "model.onnx")
+    cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": False}
+
+    onnx_model = onnx.load(model_path)
+
+    # Save the original graph outputs so they can be restored after the rewrite
+    original_output = copy.deepcopy(onnx_model.graph.output)
+
+    model = Model(onnx_model, cfg)
+    graph = model.graph
+
+    pattern = FusionPattern(MATMUL)
+    found_matmul_nodes = pattern.search_pattern(graph)
+    matmul_node_count = len(found_matmul_nodes)
+    LOGGER.info(f"Replacing {matmul_node_count} MatMul nodes with MatMulInteger")
+    graph = replace_matmul_to_matmulinteger(graph, found_matmul_nodes)
+
+    pattern = FusionPattern(GEMM)
+    found_gemm_nodes = pattern.search_pattern(graph)
+    gemm_node_count = len(found_gemm_nodes)
+    LOGGER.info(f"Replacing {gemm_node_count} Gemm nodes with MatMulInteger + Add")
+    graph = replace_gemm_to_matmulinteger(graph, found_gemm_nodes, matmul_node_count)
+
+    graph.graph_reorder_nodes()
+
+    LOGGER.info("Saving the new ONNX model")
+    full_path = Path(model_path)
+
+    graph = graph.make_graph_onnx(
+        graph.nodemap.keys(), "graph", graph.input, graph.output, with_initializer=True, with_shape_info=False
+    )
+
+    attr = {"producer_name": "onnx_tool"}
+    model_to_save = onnx.helper.make_model(graph, **attr)
+
+    # onnx_tool might remove the output nodes from the ONNX graph, so we need to restore them.
+    for out in original_output:
+        if out not in model_to_save.graph.output:
+            model_to_save.graph.output.append(out)
+
+    # Carry over the IR version and opsets from the original model.
+    model_to_save.ir_version = model.mproto.ir_version
+    model_to_save.opset_import.pop()
+    for opset in model.mproto.opset_import:
+        model_to_save.opset_import.append(opset)
+
+    check_and_save_model(model_to_save, full_path)
+
+
 def onnx_export_from_quantized_model(
     quantized_model: Union["PreTrainedModel"],
     output: Union[str, Path],
@@ -26,6 +312,7 @@ def onnx_export_from_quantized_model(
     task: str = "text-generation-with-past",
     use_subprocess: bool = False,
     do_constant_folding: bool = True,
+    insert_matmulinteger: bool = True,
     **kwargs_shapes,
 ):
     with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager):
@@ -49,3 +336,7 @@ def onnx_export_from_quantized_model(
             no_post_process=True,
             **kwargs_shapes,
         )
+
+    # Replace quantized Gemm and MatMul nodes with MatMulInteger
+    if insert_matmulinteger:
+        find_and_insert_matmulinteger(output)
diff --git a/setup.py b/setup.py
index fc5427ab..ef37a213 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,7 @@ EXTRAS_REQUIRE = {
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
-    "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
+    "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate", "onnx-tool"],
 }
 
 setup(
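Why the rewrite above is numerically safe: the exported DynamicQuantizeLinear -> DequantizeLinear -> MatMul chain computes the same values as MatMulInteger followed by a Cast to float and a Mul by the product of the activation and weight scales. The standalone numpy sketch below (not part of the diff; all names and quantization parameters in it are made up for illustration) checks that identity:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((4, 8)).astype(np.float32)

    # Dynamic uint8 quantization of the activation (what DynamicQuantizeLinear emits).
    a_scale = (a.max() - a.min()) / 255.0
    a_zp = int(np.clip(np.round(-a.min() / a_scale), 0, 255))
    a_q = np.clip(np.round(a / a_scale) + a_zp, 0, 255).astype(np.uint8)

    # Stand-in for the stored int8 weight and its quantization parameters.
    w_q = rng.integers(-127, 128, size=(8, 3), dtype=np.int8)
    w_scale, w_zp = 0.05, 0

    # MatMulInteger: integer product of the zero-point-shifted operands...
    acc = (a_q.astype(np.int32) - a_zp) @ (w_q.astype(np.int32) - w_zp)
    # ...then Cast to float and Mul by the product of the two scales.
    fused = acc.astype(np.float32) * np.float32(a_scale * w_scale)

    # Reference: MatMul over the explicitly dequantized operands.
    reference = ((a_q.astype(np.float32) - a_zp) * a_scale) @ ((w_q.astype(np.float32) - w_zp) * w_scale)
    assert np.allclose(fused, reference, rtol=1e-5)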
diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py
index 260b7ab5..836d21b8 100644
--- a/tests/brevitas/test_onnx_export.py
+++ b/tests/brevitas/test_onnx_export.py
@@ -14,6 +14,7 @@
 from parameterized import parameterized
 from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model
 
+from optimum.amd.brevitas.export import find_and_insert_matmulinteger
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import (
     export_models,
@@ -145,9 +146,27 @@ def test_dynamic_quantization(
             onnx_config_class_constructor=onnx_config_class_constructor,
             shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES,
         )
+        original_matmul_gemm_counter = 0
+        onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))
+
+        for node in onnx_model.graph.node:
+            if node.op_type == "Gemm" or node.op_type == "MatMul":
+                original_matmul_gemm_counter += 1
+        find_and_insert_matmulinteger(tmpdir)
         onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))
+        matmul_gemm_counter = 0
+        matmulinteger_counter = 0
         for node in onnx_model.graph.node:
-            # Check that we have MatmulInteger, etc.
-            pass
+            if node.op_type == "Gemm" or node.op_type == "MatMul":
+                matmul_gemm_counter += 1
+
+            if node.op_type == "MatMulInteger":
+                matmulinteger_counter += 1
+
+        # The number of MatMul+Gemm nodes must not exceed the count in the pre-transformation model.
+        # It is not zero: some MatMul nodes are not linear layers (so they are not replaced),
+        # and some linear layers can be excluded from quantization.
+        self.assertLessEqual(matmul_gemm_counter, original_matmul_gemm_counter)
+        self.assertGreater(matmulinteger_counter, 1)
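For reference, the new export flag would be exercised roughly as follows (a minimal sketch: obtaining `quantized_model` via Brevitas quantization is elided, and the output directory name is a placeholder):

    from optimum.amd.brevitas.export import onnx_export_from_quantized_model

    # `quantized_model` is a transformers PreTrainedModel already quantized with Brevitas.
    onnx_export_from_quantized_model(
        quantized_model,
        "quantized_onnx_model",        # output directory that will contain model.onnx
        insert_matmulinteger=True,     # rewrite the QDQ MatMul/Gemm patterns to MatMulInteger
    )

With `insert_matmulinteger=False`, the export keeps the plain QCDQ graph produced by Brevitas.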