From bb4f843b55a655f13c429e73d4c3fb0b3c1558f7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 14:54:31 +0000 Subject: [PATCH 01/15] Add ONNX rewriter --- .../quantization/brevitas/quantize_llm.py | 1 + examples/quantization/brevitas/rewriter.py | 227 ++++++++++++++++++ optimum/amd/brevitas/export.py | 18 ++ 3 files changed, 246 insertions(+) create mode 100644 examples/quantization/brevitas/rewriter.py diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index e61fe8a5..2753d154 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -5,6 +5,7 @@ from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model from optimum.amd.brevitas.export import onnx_export_from_quantized_model from transformers import AutoTokenizer +from rewriter import rewrite_graph def main(args): diff --git a/examples/quantization/brevitas/rewriter.py b/examples/quantization/brevitas/rewriter.py new file mode 100644 index 00000000..acdfdeb6 --- /dev/null +++ b/examples/quantization/brevitas/rewriter.py @@ -0,0 +1,227 @@ +import onnx +from onnx_tool import Model +from onnx_tool.node import create_node +from onnx_tool.tensor import Tensor + +from onnx_tool.fusion import FusionPattern + +from optimum.onnx.graph_transformations import check_and_save_model + +import pathlib + +import gc + + +MatMul = [ + { + 'name': 'deq_linear_0', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'transpose_0', 0]], + }, + { + 'name': 'transpose_0', + 'op': 'Transpose', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_0', 0]], + 'outport': [[0, 'matmul_0', 1]], + }, + { + 'name': 'quant_linear_1', + 'op': 'DynamicQuantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_1', 1], + [2, 'deq_linear_1', 2] + ], + }, + { + 'name': 'deq_linear_1', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [[0, 'quant_linear_1', 0], + [1, 'quant_linear_1', 1], + [2, 'quant_linear_1', 2],], + 'outport': [[0, 'matmul_0', 0]], + }, + { + 'name': 'matmul_0', + 'op': 'MatMul', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_1', 0], + [1, 'transpose_0', 0], + ], + 'outport': [], + }, +] + + + +GEMM = [ + { + 'name': 'deq_linear_0', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'gemm_0', 1]], + }, + { + 'name': 'quant_linear_1', + 'op': 'DynamicQuantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_1', 1], + [2, 'deq_linear_1', 2] + ], + }, + { + 'name': 'deq_linear_1', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [[0, 'quant_linear_1', 0], + [1, 'quant_linear_1', 1], + [2, 'quant_linear_1', 2],], + 'outport': [[0, 'gemm_0', 0]], + }, + { + 'name': 'gemm_0', + 'op': 'Gemm', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_0', 0], + ], + 'outport': [], + }, +] + +def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): + + if intermediate is None: + intermediate = [] + newnode = onnx.helper.make_node(op, inputs + intermediate, outputs, name=name, **kwargs) + newnode = create_node(newnode) + newnode.input = inputs + intermediate + newnode.output = outputs + for i in inputs: + if i in graph.consumedby: + graph.consumedby[i].append(name) + if i in graph.producedby.keys(): + newnode.prevnodes.append(graph.producedby[i]) + for o in outputs: + graph.producedby[o] = [name] + if o in graph.consumedby.keys(): + newnode.nextnodes.append(graph.consumedby[o]) + graph.nodemap[name] = newnode + graph.tensormap[name] = Tensor(name) + + return graph + +def replace_matmul_to_matmulinteger(compute_graph, found_nodes): + for i, found_pattern in enumerate(found_nodes): + deq_linear = compute_graph.nodemap[found_pattern[0]] + dyn_q = compute_graph.nodemap[found_pattern[2]] + dq_weight = deq_linear.prevnodes[0] + compute_graph.add_initial(f'dq_weights_0_{i}', dq_weight.value.transpose()) + compute_graph.add_initial(f'dq_weights_1_{i}', deq_linear.prevnodes[1].value) + compute_graph.add_initial(f'dq_weights_2_{i}', deq_linear.prevnodes[2].value) + + + matmul = compute_graph.nodemap[found_pattern[-1]] + for name in found_pattern: + if 'DynamicQuantizeLinear' in name: + continue + compute_graph.remove_node(name) + + compute_graph.remove_node(deq_linear.prevnodes[0].name) + compute_graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in compute_graph.nodemap: + compute_graph.remove_node(deq_linear.prevnodes[2].name) + + compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{i}', [dyn_q.output[0], f'dq_weights_0_{i}', dyn_q.output[2], f'dq_weights_2_{i}'], [f'matmul_integer_{i}']) + compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{i}', [f'matmul_integer_{i}'], [f'cast_{i}'], to=int(1)) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{i}', [dyn_q.output[1], f'dq_weights_1_{i}'], [f'mulscales_{i}']) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{i}', [f'mulscales_{i}', f'cast_{i}'], [matmul.output[0]]) + return compute_graph + + +def replace_gemm_to_matmulinteger(compute_graph, found_nodes): + k = 100 + for i, found_pattern in enumerate(found_nodes): + k = i + 100 + gemm = compute_graph.nodemap[found_pattern[-1]] + bias = gemm.input[-1] + deq_linear = compute_graph.nodemap[found_pattern[0]] + dyn_q = compute_graph.nodemap[found_pattern[1]] + dq_weight = deq_linear.prevnodes[0] + compute_graph.add_initial(f'dq_weights_0_{k}', dq_weight.value.transpose()) + compute_graph.add_initial(f'dq_weights_1_{k}', deq_linear.prevnodes[1].value) + compute_graph.add_initial(f'dq_weights_2_{k}', deq_linear.prevnodes[2].value) + + matmul = compute_graph.nodemap[found_pattern[-1]] + for name in found_pattern: + if 'DynamicQuantizeLinear' in name: + continue + compute_graph.remove_node(name) + compute_graph.remove_node(deq_linear.prevnodes[0].name) + compute_graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in compute_graph.nodemap: + compute_graph.remove_node(deq_linear.prevnodes[2].name) + + compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{k}', [dyn_q.output[0], f'dq_weights_0_{k}', dyn_q.output[2], f'dq_weights_2_{k}'], [f'matmul_integer_{k}']) + compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{k}', [f'matmul_integer_{k}'], [f'cast_{k}'], to=int(1)) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{k}', [dyn_q.output[1], f'dq_weights_1_{k}'], [f'mulscales_{k}']) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{k}', [f'mulscales_{k}', f'cast_{k}'], [f'mulvalues_{k}']) + compute_graph=create_nodes(compute_graph, 'Add', f'addbias_{k}', [bias, f'mulvalues_{k}'], [matmul.output[0]]) + return compute_graph + +def rewrite_graph(model_path): + print("Rewriting ONNX Graph for Ryzen AI") + + cfg={'constant_folding':False,'node_rename':False,'if_fixed_branch':None,'fixed_topk':0,'verbose':True} + original_output = onnx.load(model_path).graph.output + model = Model(model_path,cfg) + graph = model.graph + + print("Replacing MatMul with MatMulInteger") + pattern = FusionPattern(MatMul) + found_nodes = pattern.search_pattern(graph) + graph = replace_matmul_to_matmulinteger(graph, found_nodes) + + print("Replacing GEMM with MatMulInteger + Add") + pattern = FusionPattern(GEMM) + found_nodes = pattern.search_pattern(graph) + graph = replace_gemm_to_matmulinteger(graph, found_nodes) + + graph.graph_reorder_nodes() + + print("Saving the new ONNX model") + full_path = pathlib.Path(model_path) + + graph = graph.make_graph_onnx(graph.nodemap.keys(), 'graph', graph.input, graph.output, + with_initializer=True, with_shape_info=False) + attr = {} + attr['producer_name'] = 'onnx_tool' + model_to_save = onnx.helper.make_model(graph, **attr) + + for out in original_output: + if out not in model.graph.output: + model.graph.output.append(out) + + model_to_save.ir_version = model.mproto.ir_version + model_to_save.opset_import.pop() + for opset in model.mproto.opset_import: + model_to_save.opset_import.append(opset) + + check_and_save_model(model_to_save, full_path) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 50d1fe2a..dea5a8c7 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -49,3 +50,20 @@ def onnx_export_from_quantized_model( no_post_process=True, **kwargs_shapes, ) +======= + +import torch +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.exporters.onnx import onnx_export_from_model + + +def export_quantized_model(quantized_model, path, task="text-generation-with-past"): + with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): + onnx_export_from_model( + quantized_model, + path, + task=task, + do_validation=False, + no_post_process=True) +>>>>>>> Feat: export dq only From a2a4862c80aa80ccad25280d107b022ef7952ef1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 16:34:23 +0000 Subject: [PATCH 02/15] New interface for rewriter --- examples/quantization/brevitas/rewriter.py | 1 - optimum/amd/brevitas/export.py | 247 +++++++++++++++++++-- setup.py | 2 +- 3 files changed, 231 insertions(+), 19 deletions(-) diff --git a/examples/quantization/brevitas/rewriter.py b/examples/quantization/brevitas/rewriter.py index acdfdeb6..2b544fbe 100644 --- a/examples/quantization/brevitas/rewriter.py +++ b/examples/quantization/brevitas/rewriter.py @@ -9,7 +9,6 @@ import pathlib -import gc MatMul = [ diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index dea5a8c7..acfb3c12 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,4 +1,3 @@ -<<<<<<< HEAD from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -10,6 +9,232 @@ from optimum.exporters.onnx.base import OnnxConfig from transformers.modeling_utils import PreTrainedModel +import onnx +from onnx_tool import Model +from onnx_tool.node import create_node +from onnx_tool.tensor import Tensor + +from onnx_tool.fusion import FusionPattern + +from optimum.onnx.graph_transformations import check_and_save_model + +import pathlib + + +## Pattern to find and replace with MatMulInteger +MatMul = [ + { + 'name': 'deq_linear_0', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'transpose_0', 0]], + }, + { + 'name': 'transpose_0', + 'op': 'Transpose', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_0', 0]], + 'outport': [[0, 'matmul_0', 1]], + }, + { + 'name': 'quant_linear_1', + 'op': 'DynamicQuantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_1', 1], + [2, 'deq_linear_1', 2] + ], + }, + { + 'name': 'deq_linear_1', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [[0, 'quant_linear_1', 0], + [1, 'quant_linear_1', 1], + [2, 'quant_linear_1', 2],], + 'outport': [[0, 'matmul_0', 0]], + }, + { + 'name': 'matmul_0', + 'op': 'MatMul', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_1', 0], + [1, 'transpose_0', 0], + ], + 'outport': [], + }, +] + +GEMM = [ + { + 'name': 'deq_linear_0', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'gemm_0', 1]], + }, + { + 'name': 'quant_linear_1', + 'op': 'DynamicQuantizeLinear', + 'attrs': [ + ], + 'inport': [], + 'outport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_1', 1], + [2, 'deq_linear_1', 2] + ], + }, + { + 'name': 'deq_linear_1', + 'op': 'DequantizeLinear', + 'attrs': [ + ], + 'inport': [[0, 'quant_linear_1', 0], + [1, 'quant_linear_1', 1], + [2, 'quant_linear_1', 2],], + 'outport': [[0, 'gemm_0', 0]], + }, + { + 'name': 'gemm_0', + 'op': 'Gemm', + 'attrs': [ + ], + 'inport': [[0, 'deq_linear_1', 0], + [1, 'deq_linear_0', 0], + ], + 'outport': [], + }, +] + +def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): + + if intermediate is None: + intermediate = [] + newnode = onnx.helper.make_node(op, inputs + intermediate, outputs, name=name, **kwargs) + newnode = create_node(newnode) + newnode.input = inputs + intermediate + newnode.output = outputs + for i in inputs: + if i in graph.consumedby: + graph.consumedby[i].append(name) + if i in graph.producedby.keys(): + newnode.prevnodes.append(graph.producedby[i]) + for o in outputs: + graph.producedby[o] = [name] + if o in graph.consumedby.keys(): + newnode.nextnodes.append(graph.consumedby[o]) + graph.nodemap[name] = newnode + graph.tensormap[name] = Tensor(name) + + return graph + +def replace_matmul_to_matmulinteger(compute_graph, found_nodes): + for i, found_pattern in enumerate(found_nodes): + deq_linear = compute_graph.nodemap[found_pattern[0]] + dyn_q = compute_graph.nodemap[found_pattern[2]] + dq_weight = deq_linear.prevnodes[0] + compute_graph.add_initial(f'dq_weights_0_{i}', dq_weight.value.transpose()) + compute_graph.add_initial(f'dq_weights_1_{i}', deq_linear.prevnodes[1].value) + compute_graph.add_initial(f'dq_weights_2_{i}', deq_linear.prevnodes[2].value) + + + matmul = compute_graph.nodemap[found_pattern[-1]] + for name in found_pattern: + if 'DynamicQuantizeLinear' in name: + continue + compute_graph.remove_node(name) + + compute_graph.remove_node(deq_linear.prevnodes[0].name) + compute_graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in compute_graph.nodemap: + compute_graph.remove_node(deq_linear.prevnodes[2].name) + + compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{i}', [dyn_q.output[0], f'dq_weights_0_{i}', dyn_q.output[2], f'dq_weights_2_{i}'], [f'matmul_integer_{i}']) + compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{i}', [f'matmul_integer_{i}'], [f'cast_{i}'], to=int(1)) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{i}', [dyn_q.output[1], f'dq_weights_1_{i}'], [f'mulscales_{i}']) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{i}', [f'mulscales_{i}', f'cast_{i}'], [matmul.output[0]]) + return compute_graph + + +def replace_gemm_to_matmulinteger(compute_graph, found_nodes): + k = 100 + for i, found_pattern in enumerate(found_nodes): + k = i + 100 + gemm = compute_graph.nodemap[found_pattern[-1]] + bias = gemm.input[-1] + deq_linear = compute_graph.nodemap[found_pattern[0]] + dyn_q = compute_graph.nodemap[found_pattern[1]] + dq_weight = deq_linear.prevnodes[0] + compute_graph.add_initial(f'dq_weights_0_{k}', dq_weight.value.transpose()) + compute_graph.add_initial(f'dq_weights_1_{k}', deq_linear.prevnodes[1].value) + compute_graph.add_initial(f'dq_weights_2_{k}', deq_linear.prevnodes[2].value) + + matmul = compute_graph.nodemap[found_pattern[-1]] + for name in found_pattern: + if 'DynamicQuantizeLinear' in name: + continue + compute_graph.remove_node(name) + compute_graph.remove_node(deq_linear.prevnodes[0].name) + compute_graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in compute_graph.nodemap: + compute_graph.remove_node(deq_linear.prevnodes[2].name) + + compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{k}', [dyn_q.output[0], f'dq_weights_0_{k}', dyn_q.output[2], f'dq_weights_2_{k}'], [f'matmul_integer_{k}']) + compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{k}', [f'matmul_integer_{k}'], [f'cast_{k}'], to=int(1)) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{k}', [dyn_q.output[1], f'dq_weights_1_{k}'], [f'mulscales_{k}']) + compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{k}', [f'mulscales_{k}', f'cast_{k}'], [f'mulvalues_{k}']) + compute_graph=create_nodes(compute_graph, 'Add', f'addbias_{k}', [bias, f'mulvalues_{k}'], [matmul.output[0]]) + return compute_graph + +def find_and_insert_matmulinteger(model_path): + print("Rewriting ONNX Graph with MatMulInteger ") + + cfg={'constant_folding':False,'node_rename':False,'if_fixed_branch':None,'fixed_topk':0,'verbose':True} + original_output = onnx.load(model_path).graph.output + model = Model(model_path,cfg) + graph = model.graph + + print("Replacing MatMul with MatMulInteger") + pattern = FusionPattern(MatMul) + found_nodes = pattern.search_pattern(graph) + graph = replace_matmul_to_matmulinteger(graph, found_nodes) + + print("Replacing GEMM with MatMulInteger + Add") + pattern = FusionPattern(GEMM) + found_nodes = pattern.search_pattern(graph) + graph = replace_gemm_to_matmulinteger(graph, found_nodes) + + graph.graph_reorder_nodes() + + print("Saving the new ONNX model") + full_path = pathlib.Path(model_path) + + graph = graph.make_graph_onnx(graph.nodemap.keys(), 'graph', graph.input, graph.output, + with_initializer=True, with_shape_info=False) + + attr = {'producer_name': 'onnx_tool'} + model_to_save = onnx.helper.make_model(graph, **attr) + + # onnx_tools might remove the output nodes from the ONNX graph, so we need to restore it. + for out in original_output: + if out not in model.graph.output: + model.graph.output.append(out) + + model_to_save.ir_version = model.mproto.ir_version + model_to_save.opset_import.pop() + for opset in model.mproto.opset_import: + model_to_save.opset_import.append(opset) + + check_and_save_model(model_to_save, full_path) + def onnx_export_from_quantized_model( quantized_model: Union["PreTrainedModel"], @@ -27,6 +252,7 @@ def onnx_export_from_quantized_model( task: str = "text-generation-with-past", use_subprocess: bool = False, do_constant_folding: bool = True, + insert_matmulinteger: bool = True, **kwargs_shapes, ): with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): @@ -50,20 +276,7 @@ def onnx_export_from_quantized_model( no_post_process=True, **kwargs_shapes, ) -======= - -import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode -from optimum.exporters.onnx import onnx_export_from_model - -def export_quantized_model(quantized_model, path, task="text-generation-with-past"): - with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager): - onnx_export_from_model( - quantized_model, - path, - task=task, - do_validation=False, - no_post_process=True) ->>>>>>> Feat: export dq only + # Replace quantized GEMM and MatMul with MatMulInteger + if insert_matmulinteger: + find_and_insert_matmulinteger(output) diff --git a/setup.py b/setup.py index fc5427ab..972e3020 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ EXTRAS_REQUIRE = { "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, - "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"], + "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate", "onnx_tools"], } setup( From 0c0cc970c12189d4735b418fae9a0bc2300b12e4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 14 Mar 2024 16:37:02 +0000 Subject: [PATCH 03/15] remove rewriter file --- examples/quantization/brevitas/rewriter.py | 226 --------------------- 1 file changed, 226 deletions(-) delete mode 100644 examples/quantization/brevitas/rewriter.py diff --git a/examples/quantization/brevitas/rewriter.py b/examples/quantization/brevitas/rewriter.py deleted file mode 100644 index 2b544fbe..00000000 --- a/examples/quantization/brevitas/rewriter.py +++ /dev/null @@ -1,226 +0,0 @@ -import onnx -from onnx_tool import Model -from onnx_tool.node import create_node -from onnx_tool.tensor import Tensor - -from onnx_tool.fusion import FusionPattern - -from optimum.onnx.graph_transformations import check_and_save_model - -import pathlib - - - -MatMul = [ - { - 'name': 'deq_linear_0', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'transpose_0', 0]], - }, - { - 'name': 'transpose_0', - 'op': 'Transpose', - 'attrs': [ - ], - 'inport': [[0, 'deq_linear_0', 0]], - 'outport': [[0, 'matmul_0', 1]], - }, - { - 'name': 'quant_linear_1', - 'op': 'DynamicQuantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_1', 1], - [2, 'deq_linear_1', 2] - ], - }, - { - 'name': 'deq_linear_1', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [[0, 'quant_linear_1', 0], - [1, 'quant_linear_1', 1], - [2, 'quant_linear_1', 2],], - 'outport': [[0, 'matmul_0', 0]], - }, - { - 'name': 'matmul_0', - 'op': 'MatMul', - 'attrs': [ - ], - 'inport': [[0, 'deq_linear_1', 0], - [1, 'transpose_0', 0], - ], - 'outport': [], - }, -] - - - -GEMM = [ - { - 'name': 'deq_linear_0', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'gemm_0', 1]], - }, - { - 'name': 'quant_linear_1', - 'op': 'DynamicQuantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_1', 1], - [2, 'deq_linear_1', 2] - ], - }, - { - 'name': 'deq_linear_1', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [[0, 'quant_linear_1', 0], - [1, 'quant_linear_1', 1], - [2, 'quant_linear_1', 2],], - 'outport': [[0, 'gemm_0', 0]], - }, - { - 'name': 'gemm_0', - 'op': 'Gemm', - 'attrs': [ - ], - 'inport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_0', 0], - ], - 'outport': [], - }, -] - -def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): - - if intermediate is None: - intermediate = [] - newnode = onnx.helper.make_node(op, inputs + intermediate, outputs, name=name, **kwargs) - newnode = create_node(newnode) - newnode.input = inputs + intermediate - newnode.output = outputs - for i in inputs: - if i in graph.consumedby: - graph.consumedby[i].append(name) - if i in graph.producedby.keys(): - newnode.prevnodes.append(graph.producedby[i]) - for o in outputs: - graph.producedby[o] = [name] - if o in graph.consumedby.keys(): - newnode.nextnodes.append(graph.consumedby[o]) - graph.nodemap[name] = newnode - graph.tensormap[name] = Tensor(name) - - return graph - -def replace_matmul_to_matmulinteger(compute_graph, found_nodes): - for i, found_pattern in enumerate(found_nodes): - deq_linear = compute_graph.nodemap[found_pattern[0]] - dyn_q = compute_graph.nodemap[found_pattern[2]] - dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f'dq_weights_0_{i}', dq_weight.value.transpose()) - compute_graph.add_initial(f'dq_weights_1_{i}', deq_linear.prevnodes[1].value) - compute_graph.add_initial(f'dq_weights_2_{i}', deq_linear.prevnodes[2].value) - - - matmul = compute_graph.nodemap[found_pattern[-1]] - for name in found_pattern: - if 'DynamicQuantizeLinear' in name: - continue - compute_graph.remove_node(name) - - compute_graph.remove_node(deq_linear.prevnodes[0].name) - compute_graph.remove_node(deq_linear.prevnodes[1].name) - if deq_linear.prevnodes[2].name in compute_graph.nodemap: - compute_graph.remove_node(deq_linear.prevnodes[2].name) - - compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{i}', [dyn_q.output[0], f'dq_weights_0_{i}', dyn_q.output[2], f'dq_weights_2_{i}'], [f'matmul_integer_{i}']) - compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{i}', [f'matmul_integer_{i}'], [f'cast_{i}'], to=int(1)) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{i}', [dyn_q.output[1], f'dq_weights_1_{i}'], [f'mulscales_{i}']) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{i}', [f'mulscales_{i}', f'cast_{i}'], [matmul.output[0]]) - return compute_graph - - -def replace_gemm_to_matmulinteger(compute_graph, found_nodes): - k = 100 - for i, found_pattern in enumerate(found_nodes): - k = i + 100 - gemm = compute_graph.nodemap[found_pattern[-1]] - bias = gemm.input[-1] - deq_linear = compute_graph.nodemap[found_pattern[0]] - dyn_q = compute_graph.nodemap[found_pattern[1]] - dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f'dq_weights_0_{k}', dq_weight.value.transpose()) - compute_graph.add_initial(f'dq_weights_1_{k}', deq_linear.prevnodes[1].value) - compute_graph.add_initial(f'dq_weights_2_{k}', deq_linear.prevnodes[2].value) - - matmul = compute_graph.nodemap[found_pattern[-1]] - for name in found_pattern: - if 'DynamicQuantizeLinear' in name: - continue - compute_graph.remove_node(name) - compute_graph.remove_node(deq_linear.prevnodes[0].name) - compute_graph.remove_node(deq_linear.prevnodes[1].name) - if deq_linear.prevnodes[2].name in compute_graph.nodemap: - compute_graph.remove_node(deq_linear.prevnodes[2].name) - - compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{k}', [dyn_q.output[0], f'dq_weights_0_{k}', dyn_q.output[2], f'dq_weights_2_{k}'], [f'matmul_integer_{k}']) - compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{k}', [f'matmul_integer_{k}'], [f'cast_{k}'], to=int(1)) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{k}', [dyn_q.output[1], f'dq_weights_1_{k}'], [f'mulscales_{k}']) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{k}', [f'mulscales_{k}', f'cast_{k}'], [f'mulvalues_{k}']) - compute_graph=create_nodes(compute_graph, 'Add', f'addbias_{k}', [bias, f'mulvalues_{k}'], [matmul.output[0]]) - return compute_graph - -def rewrite_graph(model_path): - print("Rewriting ONNX Graph for Ryzen AI") - - cfg={'constant_folding':False,'node_rename':False,'if_fixed_branch':None,'fixed_topk':0,'verbose':True} - original_output = onnx.load(model_path).graph.output - model = Model(model_path,cfg) - graph = model.graph - - print("Replacing MatMul with MatMulInteger") - pattern = FusionPattern(MatMul) - found_nodes = pattern.search_pattern(graph) - graph = replace_matmul_to_matmulinteger(graph, found_nodes) - - print("Replacing GEMM with MatMulInteger + Add") - pattern = FusionPattern(GEMM) - found_nodes = pattern.search_pattern(graph) - graph = replace_gemm_to_matmulinteger(graph, found_nodes) - - graph.graph_reorder_nodes() - - print("Saving the new ONNX model") - full_path = pathlib.Path(model_path) - - graph = graph.make_graph_onnx(graph.nodemap.keys(), 'graph', graph.input, graph.output, - with_initializer=True, with_shape_info=False) - attr = {} - attr['producer_name'] = 'onnx_tool' - model_to_save = onnx.helper.make_model(graph, **attr) - - for out in original_output: - if out not in model.graph.output: - model.graph.output.append(out) - - model_to_save.ir_version = model.mproto.ir_version - model_to_save.opset_import.pop() - for opset in model.mproto.opset_import: - model_to_save.opset_import.append(opset) - - check_and_save_model(model_to_save, full_path) From 7b10fc3190645fa34c424c539d5979243477764e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 16 Mar 2024 23:20:54 +0000 Subject: [PATCH 04/15] Fixes --- examples/quantization/brevitas/quantize_llm.py | 1 - optimum/amd/brevitas/export.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/quantization/brevitas/quantize_llm.py b/examples/quantization/brevitas/quantize_llm.py index 2753d154..e61fe8a5 100644 --- a/examples/quantization/brevitas/quantize_llm.py +++ b/examples/quantization/brevitas/quantize_llm.py @@ -5,7 +5,6 @@ from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model from optimum.amd.brevitas.export import onnx_export_from_quantized_model from transformers import AutoTokenizer -from rewriter import rewrite_graph def main(args): diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index acfb3c12..126ac146 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -19,6 +19,7 @@ from optimum.onnx.graph_transformations import check_and_save_model import pathlib +import os ## Pattern to find and replace with MatMulInteger @@ -196,7 +197,7 @@ def replace_gemm_to_matmulinteger(compute_graph, found_nodes): def find_and_insert_matmulinteger(model_path): print("Rewriting ONNX Graph with MatMulInteger ") - + model_path = os.path.join(model_path, 'model.onnx') cfg={'constant_folding':False,'node_rename':False,'if_fixed_branch':None,'fixed_topk':0,'verbose':True} original_output = onnx.load(model_path).graph.output model = Model(model_path,cfg) @@ -225,8 +226,8 @@ def find_and_insert_matmulinteger(model_path): # onnx_tools might remove the output nodes from the ONNX graph, so we need to restore it. for out in original_output: - if out not in model.graph.output: - model.graph.output.append(out) + if out not in model_to_save.graph.output: + model_to_save.graph.output.append(out) model_to_save.ir_version = model.mproto.ir_version model_to_save.opset_import.pop() From f3c88007ffb3d0d008769a91bb7a7703cabf40d0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 16 Mar 2024 23:22:19 +0000 Subject: [PATCH 05/15] Fix and formatting --- optimum/amd/brevitas/export.py | 213 ++++++++++++++++++--------------- 1 file changed, 117 insertions(+), 96 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 126ac146..4f63db53 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -8,113 +8,106 @@ from optimum.exporters.onnx import onnx_export_from_model from optimum.exporters.onnx.base import OnnxConfig from transformers.modeling_utils import PreTrainedModel +import os import onnx +import torch +from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from onnx_tool import Model +from onnx_tool.fusion import FusionPattern from onnx_tool.node import create_node from onnx_tool.tensor import Tensor -from onnx_tool.fusion import FusionPattern - +from optimum.exporters.onnx import onnx_export_from_model from optimum.onnx.graph_transformations import check_and_save_model -import pathlib -import os - ## Pattern to find and replace with MatMulInteger MatMul = [ { - 'name': 'deq_linear_0', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'transpose_0', 0]], + "name": "deq_linear_0", + "op": "DequantizeLinear", + "attrs": [], + "inport": [], + "outport": [[0, "transpose_0", 0]], }, { - 'name': 'transpose_0', - 'op': 'Transpose', - 'attrs': [ - ], - 'inport': [[0, 'deq_linear_0', 0]], - 'outport': [[0, 'matmul_0', 1]], + "name": "transpose_0", + "op": "Transpose", + "attrs": [], + "inport": [[0, "deq_linear_0", 0]], + "outport": [[0, "matmul_0", 1]], }, { - 'name': 'quant_linear_1', - 'op': 'DynamicQuantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_1', 1], - [2, 'deq_linear_1', 2] - ], + "name": "quant_linear_1", + "op": "DynamicQuantizeLinear", + "attrs": [], + "inport": [], + "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]], }, { - 'name': 'deq_linear_1', - 'op': 'DequantizeLinear', - 'attrs': [ + "name": "deq_linear_1", + "op": "DequantizeLinear", + "attrs": [], + "inport": [ + [0, "quant_linear_1", 0], + [1, "quant_linear_1", 1], + [2, "quant_linear_1", 2], ], - 'inport': [[0, 'quant_linear_1', 0], - [1, 'quant_linear_1', 1], - [2, 'quant_linear_1', 2],], - 'outport': [[0, 'matmul_0', 0]], + "outport": [[0, "matmul_0", 0]], }, { - 'name': 'matmul_0', - 'op': 'MatMul', - 'attrs': [ + "name": "matmul_0", + "op": "MatMul", + "attrs": [], + "inport": [ + [0, "deq_linear_1", 0], + [1, "transpose_0", 0], ], - 'inport': [[0, 'deq_linear_1', 0], - [1, 'transpose_0', 0], - ], - 'outport': [], + "outport": [], }, ] GEMM = [ { - 'name': 'deq_linear_0', - 'op': 'DequantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'gemm_0', 1]], + "name": "deq_linear_0", + "op": "DequantizeLinear", + "attrs": [], + "inport": [], + "outport": [[0, "gemm_0", 1]], }, { - 'name': 'quant_linear_1', - 'op': 'DynamicQuantizeLinear', - 'attrs': [ - ], - 'inport': [], - 'outport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_1', 1], - [2, 'deq_linear_1', 2] - ], + "name": "quant_linear_1", + "op": "DynamicQuantizeLinear", + "attrs": [], + "inport": [], + "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]], }, { - 'name': 'deq_linear_1', - 'op': 'DequantizeLinear', - 'attrs': [ + "name": "deq_linear_1", + "op": "DequantizeLinear", + "attrs": [], + "inport": [ + [0, "quant_linear_1", 0], + [1, "quant_linear_1", 1], + [2, "quant_linear_1", 2], ], - 'inport': [[0, 'quant_linear_1', 0], - [1, 'quant_linear_1', 1], - [2, 'quant_linear_1', 2],], - 'outport': [[0, 'gemm_0', 0]], + "outport": [[0, "gemm_0", 0]], }, { - 'name': 'gemm_0', - 'op': 'Gemm', - 'attrs': [ + "name": "gemm_0", + "op": "Gemm", + "attrs": [], + "inport": [ + [0, "deq_linear_1", 0], + [1, "deq_linear_0", 0], ], - 'inport': [[0, 'deq_linear_1', 0], - [1, 'deq_linear_0', 0], - ], - 'outport': [], + "outport": [], }, ] + def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): if intermediate is None: @@ -137,19 +130,19 @@ def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): return graph + def replace_matmul_to_matmulinteger(compute_graph, found_nodes): for i, found_pattern in enumerate(found_nodes): deq_linear = compute_graph.nodemap[found_pattern[0]] dyn_q = compute_graph.nodemap[found_pattern[2]] dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f'dq_weights_0_{i}', dq_weight.value.transpose()) - compute_graph.add_initial(f'dq_weights_1_{i}', deq_linear.prevnodes[1].value) - compute_graph.add_initial(f'dq_weights_2_{i}', deq_linear.prevnodes[2].value) - + compute_graph.add_initial(f"dq_weights_0_{i}", dq_weight.value.transpose()) + compute_graph.add_initial(f"dq_weights_1_{i}", deq_linear.prevnodes[1].value) + compute_graph.add_initial(f"dq_weights_2_{i}", deq_linear.prevnodes[2].value) matmul = compute_graph.nodemap[found_pattern[-1]] for name in found_pattern: - if 'DynamicQuantizeLinear' in name: + if "DynamicQuantizeLinear" in name: continue compute_graph.remove_node(name) @@ -158,10 +151,22 @@ def replace_matmul_to_matmulinteger(compute_graph, found_nodes): if deq_linear.prevnodes[2].name in compute_graph.nodemap: compute_graph.remove_node(deq_linear.prevnodes[2].name) - compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{i}', [dyn_q.output[0], f'dq_weights_0_{i}', dyn_q.output[2], f'dq_weights_2_{i}'], [f'matmul_integer_{i}']) - compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{i}', [f'matmul_integer_{i}'], [f'cast_{i}'], to=int(1)) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{i}', [dyn_q.output[1], f'dq_weights_1_{i}'], [f'mulscales_{i}']) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{i}', [f'mulscales_{i}', f'cast_{i}'], [matmul.output[0]]) + compute_graph = create_nodes( + compute_graph, + "MatMulInteger", + f"matmul_integer_{i}", + [dyn_q.output[0], f"dq_weights_0_{i}", dyn_q.output[2], f"dq_weights_2_{i}"], + [f"matmul_integer_{i}"], + ) + compute_graph = create_nodes( + compute_graph, "Cast", f"cast_{i}", [f"matmul_integer_{i}"], [f"cast_{i}"], to=int(1) + ) + compute_graph = create_nodes( + compute_graph, "Mul", f"mulscales_{i}", [dyn_q.output[1], f"dq_weights_1_{i}"], [f"mulscales_{i}"] + ) + compute_graph = create_nodes( + compute_graph, "Mul", f"mulvalues_{i}", [f"mulscales_{i}", f"cast_{i}"], [matmul.output[0]] + ) return compute_graph @@ -174,13 +179,13 @@ def replace_gemm_to_matmulinteger(compute_graph, found_nodes): deq_linear = compute_graph.nodemap[found_pattern[0]] dyn_q = compute_graph.nodemap[found_pattern[1]] dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f'dq_weights_0_{k}', dq_weight.value.transpose()) - compute_graph.add_initial(f'dq_weights_1_{k}', deq_linear.prevnodes[1].value) - compute_graph.add_initial(f'dq_weights_2_{k}', deq_linear.prevnodes[2].value) - + compute_graph.add_initial(f"dq_weights_0_{k}", dq_weight.value.transpose()) + compute_graph.add_initial(f"dq_weights_1_{k}", deq_linear.prevnodes[1].value) + compute_graph.add_initial(f"dq_weights_2_{k}", deq_linear.prevnodes[2].value) + matmul = compute_graph.nodemap[found_pattern[-1]] for name in found_pattern: - if 'DynamicQuantizeLinear' in name: + if "DynamicQuantizeLinear" in name: continue compute_graph.remove_node(name) compute_graph.remove_node(deq_linear.prevnodes[0].name) @@ -188,19 +193,34 @@ def replace_gemm_to_matmulinteger(compute_graph, found_nodes): if deq_linear.prevnodes[2].name in compute_graph.nodemap: compute_graph.remove_node(deq_linear.prevnodes[2].name) - compute_graph=create_nodes(compute_graph, 'MatMulInteger', f'matmul_integer_{k}', [dyn_q.output[0], f'dq_weights_0_{k}', dyn_q.output[2], f'dq_weights_2_{k}'], [f'matmul_integer_{k}']) - compute_graph=create_nodes(compute_graph, 'Cast', f'cast_{k}', [f'matmul_integer_{k}'], [f'cast_{k}'], to=int(1)) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulscales_{k}', [dyn_q.output[1], f'dq_weights_1_{k}'], [f'mulscales_{k}']) - compute_graph=create_nodes(compute_graph, 'Mul', f'mulvalues_{k}', [f'mulscales_{k}', f'cast_{k}'], [f'mulvalues_{k}']) - compute_graph=create_nodes(compute_graph, 'Add', f'addbias_{k}', [bias, f'mulvalues_{k}'], [matmul.output[0]]) + compute_graph = create_nodes( + compute_graph, + "MatMulInteger", + f"matmul_integer_{k}", + [dyn_q.output[0], f"dq_weights_0_{k}", dyn_q.output[2], f"dq_weights_2_{k}"], + [f"matmul_integer_{k}"], + ) + compute_graph = create_nodes( + compute_graph, "Cast", f"cast_{k}", [f"matmul_integer_{k}"], [f"cast_{k}"], to=int(1) + ) + compute_graph = create_nodes( + compute_graph, "Mul", f"mulscales_{k}", [dyn_q.output[1], f"dq_weights_1_{k}"], [f"mulscales_{k}"] + ) + compute_graph = create_nodes( + compute_graph, "Mul", f"mulvalues_{k}", [f"mulscales_{k}", f"cast_{k}"], [f"mulvalues_{k}"] + ) + compute_graph = create_nodes( + compute_graph, "Add", f"addbias_{k}", [bias, f"mulvalues_{k}"], [matmul.output[0]] + ) return compute_graph + def find_and_insert_matmulinteger(model_path): print("Rewriting ONNX Graph with MatMulInteger ") - model_path = os.path.join(model_path, 'model.onnx') - cfg={'constant_folding':False,'node_rename':False,'if_fixed_branch':None,'fixed_topk':0,'verbose':True} + model_path = os.path.join(model_path, "model.onnx") + cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": True} original_output = onnx.load(model_path).graph.output - model = Model(model_path,cfg) + model = Model(model_path, cfg) graph = model.graph print("Replacing MatMul with MatMulInteger") @@ -212,16 +232,17 @@ def find_and_insert_matmulinteger(model_path): pattern = FusionPattern(GEMM) found_nodes = pattern.search_pattern(graph) graph = replace_gemm_to_matmulinteger(graph, found_nodes) - + graph.graph_reorder_nodes() print("Saving the new ONNX model") - full_path = pathlib.Path(model_path) + full_path = Path(model_path) - graph = graph.make_graph_onnx(graph.nodemap.keys(), 'graph', graph.input, graph.output, - with_initializer=True, with_shape_info=False) + graph = graph.make_graph_onnx( + graph.nodemap.keys(), "graph", graph.input, graph.output, with_initializer=True, with_shape_info=False + ) - attr = {'producer_name': 'onnx_tool'} + attr = {"producer_name": "onnx_tool"} model_to_save = onnx.helper.make_model(graph, **attr) # onnx_tools might remove the output nodes from the ONNX graph, so we need to restore it. From 333663d02ac388c3b675fa68c0e7e6022378a713 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 16 Mar 2024 23:29:22 +0000 Subject: [PATCH 06/15] Fix install --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 972e3020..ef37a213 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ EXTRAS_REQUIRE = { "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, - "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate", "onnx_tools"], + "brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate", "onnx-tool"], } setup( From 37866e7e8dc87bf5578fdaeaf7c3dc3dd3fad947 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 10:55:56 +0000 Subject: [PATCH 07/15] Formatting --- optimum/amd/brevitas/export.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index 4f63db53..db09fd97 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,15 +1,7 @@ +import os from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union -import torch -from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode - -from optimum.exporters.onnx import onnx_export_from_model -from optimum.exporters.onnx.base import OnnxConfig -from transformers.modeling_utils import PreTrainedModel -import os - import onnx import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager @@ -20,7 +12,9 @@ from onnx_tool.tensor import Tensor from optimum.exporters.onnx import onnx_export_from_model +from optimum.exporters.onnx.base import OnnxConfig from optimum.onnx.graph_transformations import check_and_save_model +from transformers.modeling_utils import PreTrainedModel ## Pattern to find and replace with MatMulInteger From fd64a24a77d8ba0182c9ee5255de384d8301cd46 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 12:26:17 +0000 Subject: [PATCH 08/15] Formatting --- optimum/amd/brevitas/export.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index db09fd97..a28a7a1c 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -103,7 +103,6 @@ def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): - if intermediate is None: intermediate = [] newnode = onnx.helper.make_node(op, inputs + intermediate, outputs, name=name, **kwargs) From dc0296778175d6de13fd9e9a017198e0f1c41276 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 13:29:12 +0000 Subject: [PATCH 09/15] Adding tests for MatmulInteger --- tests/brevitas/test_onnx_export.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py index 260b7ab5..97015f3b 100644 --- a/tests/brevitas/test_onnx_export.py +++ b/tests/brevitas/test_onnx_export.py @@ -11,6 +11,7 @@ import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode +from optimum.amd.brevitas.export import find_and_insert_matmulinteger from parameterized import parameterized from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model @@ -145,9 +146,27 @@ def test_dynamic_quantization( onnx_config_class_constructor=onnx_config_class_constructor, shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES, ) + original_matmul_gemm_counter = 0 + onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx")) + + for node in onnx_model.graph.node: + if node.op_type == "Gemm" or node.op_type == "MatMul": + original_matmul_gemm_counter += 1 + find_and_insert_matmulinteger(tmpdir) onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx")) + matmul_gemm_counter = 0 + matmulinteger_counter = 0 for node in onnx_model.graph.node: - # Check that we have MatmulInteger, etc. - pass + if node.op_type == "Gemm" or node.op_type == "MatMul": + matmul_gemm_counter += 1 + + if node.op_type == "MatMulInteger": + matmulinteger_counter += 1 + + # The number of Matmul+Gemm has to be less compared to the model pre-transformation + # This is not zero since there are matmul that are not linear layers so they are not replaced + # and some linears layers can be excluded from quantization + assert matmul_gemm_counter <= original_matmul_gemm_counter + assert matmulinteger_counter > 1 From 09e6cb735148c5fa9349a046b2e53bbbf852fa02 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 13:30:12 +0000 Subject: [PATCH 10/15] Formatting --- tests/brevitas/test_onnx_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py index 97015f3b..6d9a12aa 100644 --- a/tests/brevitas/test_onnx_export.py +++ b/tests/brevitas/test_onnx_export.py @@ -11,10 +11,10 @@ import torch from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode -from optimum.amd.brevitas.export import find_and_insert_matmulinteger from parameterized import parameterized from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model +from optimum.amd.brevitas.export import find_and_insert_matmulinteger from optimum.exporters import TasksManager from optimum.exporters.onnx import ( export_models, @@ -161,7 +161,7 @@ def test_dynamic_quantization( for node in onnx_model.graph.node: if node.op_type == "Gemm" or node.op_type == "MatMul": matmul_gemm_counter += 1 - + if node.op_type == "MatMulInteger": matmulinteger_counter += 1 From c577f97f1774d142b2d248998a26adbafe3ea33b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 19 Mar 2024 23:38:28 +0000 Subject: [PATCH 11/15] Bump python version to 3.9 for tests --- .github/workflows/test_brevitas.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_brevitas.yaml b/.github/workflows/test_brevitas.yaml index 72faf513..5db21168 100644 --- a/.github/workflows/test_brevitas.yaml +++ b/.github/workflows/test_brevitas.yaml @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8] + python-version: [3.9] os: [ubuntu-20.04, windows-2019, macos-latest] runs-on: ${{ matrix.os }} From d8de4efa057f132826103c0d60b1e99ca76bcfc0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 20 Mar 2024 10:01:52 +0000 Subject: [PATCH 12/15] Review --- optimum/amd/brevitas/export.py | 179 +++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 73 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index a28a7a1c..b8519b39 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -1,4 +1,7 @@ +import copy +import logging import os +import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -8,6 +11,7 @@ from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from onnx_tool import Model from onnx_tool.fusion import FusionPattern +from onnx_tool.graph import Graph from onnx_tool.node import create_node from onnx_tool.tensor import Tensor @@ -17,8 +21,11 @@ from transformers.modeling_utils import PreTrainedModel +LOGGER = logging.getLogger(__name__) + + ## Pattern to find and replace with MatMulInteger -MatMul = [ +MATMUL = [ { "name": "deq_linear_0", "op": "DequantizeLinear", @@ -102,12 +109,10 @@ ] -def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): - if intermediate is None: - intermediate = [] - newnode = onnx.helper.make_node(op, inputs + intermediate, outputs, name=name, **kwargs) +def create_nodes(graph: Graph, op: str, name: str, inputs: List[str], outputs: List[str], **kwargs): + newnode = onnx.helper.make_node(op, inputs, outputs, name=name, **kwargs) newnode = create_node(newnode) - newnode.input = inputs + intermediate + newnode.input = inputs newnode.output = outputs for i in inputs: if i in graph.consumedby: @@ -124,111 +129,139 @@ def create_nodes(graph, op, name, inputs, outputs, intermediate=None, **kwargs): return graph -def replace_matmul_to_matmulinteger(compute_graph, found_nodes): +def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0): for i, found_pattern in enumerate(found_nodes): - deq_linear = compute_graph.nodemap[found_pattern[0]] - dyn_q = compute_graph.nodemap[found_pattern[2]] + node_count += i + deq_linear = graph.nodemap[found_pattern[0]] + dyn_q = graph.nodemap[found_pattern[2]] dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f"dq_weights_0_{i}", dq_weight.value.transpose()) - compute_graph.add_initial(f"dq_weights_1_{i}", deq_linear.prevnodes[1].value) - compute_graph.add_initial(f"dq_weights_2_{i}", deq_linear.prevnodes[2].value) + graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose()) + graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value) + graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value) - matmul = compute_graph.nodemap[found_pattern[-1]] + matmul = graph.nodemap[found_pattern[-1]] for name in found_pattern: if "DynamicQuantizeLinear" in name: continue - compute_graph.remove_node(name) + graph.remove_node(name) - compute_graph.remove_node(deq_linear.prevnodes[0].name) - compute_graph.remove_node(deq_linear.prevnodes[1].name) - if deq_linear.prevnodes[2].name in compute_graph.nodemap: - compute_graph.remove_node(deq_linear.prevnodes[2].name) + graph.remove_node(deq_linear.prevnodes[0].name) + graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in graph.nodemap: + graph.remove_node(deq_linear.prevnodes[2].name) - compute_graph = create_nodes( - compute_graph, + graph = create_nodes( + graph, "MatMulInteger", - f"matmul_integer_{i}", - [dyn_q.output[0], f"dq_weights_0_{i}", dyn_q.output[2], f"dq_weights_2_{i}"], - [f"matmul_integer_{i}"], + f"matmul_integer_{node_count}", + [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"], + [f"matmul_integer_{node_count}"], ) - compute_graph = create_nodes( - compute_graph, "Cast", f"cast_{i}", [f"matmul_integer_{i}"], [f"cast_{i}"], to=int(1) + graph = create_nodes( + graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=int(1) ) - compute_graph = create_nodes( - compute_graph, "Mul", f"mulscales_{i}", [dyn_q.output[1], f"dq_weights_1_{i}"], [f"mulscales_{i}"] + graph = create_nodes( + graph, + "Mul", + f"mulscales_{node_count}", + [dyn_q.output[1], f"dq_weights_1_{node_count}"], + [f"mulscales_{node_count}"], ) - compute_graph = create_nodes( - compute_graph, "Mul", f"mulvalues_{i}", [f"mulscales_{i}", f"cast_{i}"], [matmul.output[0]] + graph = create_nodes( + graph, + "Mul", + f"mulvalues_{node_count}", + [f"mulscales_{node_count}", f"cast_{node_count}"], + [matmul.output[0]], ) - return compute_graph + return graph -def replace_gemm_to_matmulinteger(compute_graph, found_nodes): - k = 100 +def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0): for i, found_pattern in enumerate(found_nodes): - k = i + 100 - gemm = compute_graph.nodemap[found_pattern[-1]] + node_count += 1 + gemm = graph.nodemap[found_pattern[-1]] bias = gemm.input[-1] - deq_linear = compute_graph.nodemap[found_pattern[0]] - dyn_q = compute_graph.nodemap[found_pattern[1]] + deq_linear = graph.nodemap[found_pattern[0]] + dyn_q = graph.nodemap[found_pattern[1]] dq_weight = deq_linear.prevnodes[0] - compute_graph.add_initial(f"dq_weights_0_{k}", dq_weight.value.transpose()) - compute_graph.add_initial(f"dq_weights_1_{k}", deq_linear.prevnodes[1].value) - compute_graph.add_initial(f"dq_weights_2_{k}", deq_linear.prevnodes[2].value) + graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose()) + graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value) + graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value) - matmul = compute_graph.nodemap[found_pattern[-1]] + matmul = graph.nodemap[found_pattern[-1]] for name in found_pattern: if "DynamicQuantizeLinear" in name: continue - compute_graph.remove_node(name) - compute_graph.remove_node(deq_linear.prevnodes[0].name) - compute_graph.remove_node(deq_linear.prevnodes[1].name) - if deq_linear.prevnodes[2].name in compute_graph.nodemap: - compute_graph.remove_node(deq_linear.prevnodes[2].name) - - compute_graph = create_nodes( - compute_graph, + graph.remove_node(name) + graph.remove_node(deq_linear.prevnodes[0].name) + graph.remove_node(deq_linear.prevnodes[1].name) + if deq_linear.prevnodes[2].name in graph.nodemap: + graph.remove_node(deq_linear.prevnodes[2].name) + + graph = create_nodes( + graph, "MatMulInteger", - f"matmul_integer_{k}", - [dyn_q.output[0], f"dq_weights_0_{k}", dyn_q.output[2], f"dq_weights_2_{k}"], - [f"matmul_integer_{k}"], + f"matmul_integer_{node_count}", + [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"], + [f"matmul_integer_{node_count}"], ) - compute_graph = create_nodes( - compute_graph, "Cast", f"cast_{k}", [f"matmul_integer_{k}"], [f"cast_{k}"], to=int(1) + graph = create_nodes( + graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=int(1) ) - compute_graph = create_nodes( - compute_graph, "Mul", f"mulscales_{k}", [dyn_q.output[1], f"dq_weights_1_{k}"], [f"mulscales_{k}"] + graph = create_nodes( + graph, + "Mul", + f"mulscales_{node_count}", + [dyn_q.output[1], f"dq_weights_1_{node_count}"], + [f"mulscales_{node_count}"], ) - compute_graph = create_nodes( - compute_graph, "Mul", f"mulvalues_{k}", [f"mulscales_{k}", f"cast_{k}"], [f"mulvalues_{k}"] + graph = create_nodes( + graph, + "Mul", + f"mulvalues_{node_count}", + [f"mulscales_{node_count}", f"cast_{node_count}"], + [f"mulvalues_{node_count}"], ) - compute_graph = create_nodes( - compute_graph, "Add", f"addbias_{k}", [bias, f"mulvalues_{k}"], [matmul.output[0]] + graph = create_nodes( + graph, "Add", f"addbias_{node_count}", [bias, f"mulvalues_{node_count}"], [matmul.output[0]] ) - return compute_graph + return graph -def find_and_insert_matmulinteger(model_path): - print("Rewriting ONNX Graph with MatMulInteger ") +def find_and_insert_matmulinteger(model_path: str): + + # onnx_tool requires python 3.9+ + if sys.version_info[0] == 3 and sys.version_info[1] <= 8: + raise RuntimeError("onnx_tool requires Python 3.9 or higher") + + LOGGER.info("Rewriting ONNX Graph with MatMulInteger") model_path = os.path.join(model_path, "model.onnx") - cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": True} - original_output = onnx.load(model_path).graph.output - model = Model(model_path, cfg) + cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": False} + + onnx_model = onnx.load(model_path) + + # Extract model output + original_output = copy.deepcopy(onnx_model.graph.output) + + model = Model(onnx_model, cfg) graph = model.graph - print("Replacing MatMul with MatMulInteger") - pattern = FusionPattern(MatMul) - found_nodes = pattern.search_pattern(graph) - graph = replace_matmul_to_matmulinteger(graph, found_nodes) + pattern = FusionPattern(MATMUL) + found_matmul_nodes = pattern.search_pattern(graph) + matmul_node_count = len(found_matmul_nodes) + LOGGER.info(f"Replacing {matmul_node_count} MatMul nodes with MatMulInteger") + graph = replace_matmul_to_matmulinteger(graph, found_matmul_nodes) - print("Replacing GEMM with MatMulInteger + Add") pattern = FusionPattern(GEMM) - found_nodes = pattern.search_pattern(graph) - graph = replace_gemm_to_matmulinteger(graph, found_nodes) + found_gemm_nodes = pattern.search_pattern(graph) + gemm_node_count = len(found_gemm_nodes) + LOGGER.info(f"Replacing {gemm_node_count} Gemm nodes with MatMulInteger + Add") + graph = replace_gemm_to_matmulinteger(graph, found_gemm_nodes, matmul_node_count) graph.graph_reorder_nodes() - print("Saving the new ONNX model") + LOGGER.info("Saving the new ONNX model") full_path = Path(model_path) graph = graph.make_graph_onnx( From 21a05fa9dfb134b80565ca33379a5a45895b430c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 20 Mar 2024 11:28:22 +0000 Subject: [PATCH 13/15] Fix node names --- optimum/amd/brevitas/export.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index b8519b39..ba1e8cf2 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -23,6 +23,7 @@ LOGGER = logging.getLogger(__name__) +ONNX_FLOAT32_IDENTIFIER = int(1) ## Pattern to find and replace with MatMulInteger MATMUL = [ @@ -130,8 +131,9 @@ def create_nodes(graph: Graph, op: str, name: str, inputs: List[str], outputs: L def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0): - for i, found_pattern in enumerate(found_nodes): - node_count += i + for found_pattern in found_nodes: + node_count += 1 + deq_linear = graph.nodemap[found_pattern[0]] dyn_q = graph.nodemap[found_pattern[2]] dq_weight = deq_linear.prevnodes[0] @@ -158,7 +160,7 @@ def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], [f"matmul_integer_{node_count}"], ) graph = create_nodes( - graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=int(1) + graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=ONNX_FLOAT32_IDENTIFIER ) graph = create_nodes( graph, @@ -178,8 +180,9 @@ def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0): - for i, found_pattern in enumerate(found_nodes): + for found_pattern in found_nodes: node_count += 1 + gemm = graph.nodemap[found_pattern[-1]] bias = gemm.input[-1] deq_linear = graph.nodemap[found_pattern[0]] @@ -207,7 +210,7 @@ def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], no [f"matmul_integer_{node_count}"], ) graph = create_nodes( - graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=int(1) + graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=ONNX_FLOAT32_IDENTIFIER ) graph = create_nodes( graph, From fac256f45d1d4ff6b0c3a91f34a1917e95404883 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 20 Mar 2024 11:29:10 +0000 Subject: [PATCH 14/15] Formatting --- optimum/amd/brevitas/export.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/optimum/amd/brevitas/export.py b/optimum/amd/brevitas/export.py index ba1e8cf2..24525895 100644 --- a/optimum/amd/brevitas/export.py +++ b/optimum/amd/brevitas/export.py @@ -160,7 +160,12 @@ def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], [f"matmul_integer_{node_count}"], ) graph = create_nodes( - graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=ONNX_FLOAT32_IDENTIFIER + graph, + "Cast", + f"cast_{node_count}", + [f"matmul_integer_{node_count}"], + [f"cast_{node_count}"], + to=ONNX_FLOAT32_IDENTIFIER, ) graph = create_nodes( graph, @@ -210,7 +215,12 @@ def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], no [f"matmul_integer_{node_count}"], ) graph = create_nodes( - graph, "Cast", f"cast_{node_count}", [f"matmul_integer_{node_count}"], [f"cast_{node_count}"], to=ONNX_FLOAT32_IDENTIFIER + graph, + "Cast", + f"cast_{node_count}", + [f"matmul_integer_{node_count}"], + [f"cast_{node_count}"], + to=ONNX_FLOAT32_IDENTIFIER, ) graph = create_nodes( graph, @@ -233,7 +243,6 @@ def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], no def find_and_insert_matmulinteger(model_path: str): - # onnx_tool requires python 3.9+ if sys.version_info[0] == 3 and sys.version_info[1] <= 8: raise RuntimeError("onnx_tool requires Python 3.9 or higher") From 4e2523d0bca99f8c9aeb660dff8560949f99f808 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 22 Mar 2024 10:17:41 +0100 Subject: [PATCH 15/15] Change assert type --- tests/brevitas/test_onnx_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/test_onnx_export.py b/tests/brevitas/test_onnx_export.py index 6d9a12aa..836d21b8 100644 --- a/tests/brevitas/test_onnx_export.py +++ b/tests/brevitas/test_onnx_export.py @@ -168,5 +168,5 @@ def test_dynamic_quantization( # The number of Matmul+Gemm has to be less compared to the model pre-transformation # This is not zero since there are matmul that are not linear layers so they are not replaced # and some linears layers can be excluded from quantization - assert matmul_gemm_counter <= original_matmul_gemm_counter - assert matmulinteger_counter > 1 + self.assertTrue(matmul_gemm_counter <= original_matmul_gemm_counter) + self.assertTrue(matmulinteger_counter > 1)