Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/test_brevitas.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
python-version: [3.9]
os: [ubuntu-20.04, windows-2019, macos-latest]

runs-on: ${{ matrix.os }}
Expand Down
291 changes: 291 additions & 0 deletions optimum/amd/brevitas/export.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,301 @@
import copy
import logging
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import onnx
import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
from onnx_tool import Model
from onnx_tool.fusion import FusionPattern
from onnx_tool.graph import Graph
from onnx_tool.node import create_node
from onnx_tool.tensor import Tensor

from optimum.exporters.onnx import onnx_export_from_model
from optimum.exporters.onnx.base import OnnxConfig
from optimum.onnx.graph_transformations import check_and_save_model
from transformers.modeling_utils import PreTrainedModel


LOGGER = logging.getLogger(__name__)

# ONNX TensorProto data-type code for FLOAT (float32) — onnx.TensorProto.FLOAT == 1.
# Used as the `to` attribute of the Cast nodes inserted after MatMulInteger.
# (The original `int(1)` wrapped a literal in a redundant int() call.)
ONNX_FLOAT32_IDENTIFIER = 1

## Pattern to find and replace with MatMulInteger
# onnx_tool FusionPattern specification: one dict per node of the subgraph to
# match. Each "inport"/"outport" entry appears to be a
# [port, peer_pattern_name, peer_port] link describing the required wiring
# between matched nodes — confirm against onnx_tool's FusionPattern docs.
#
# Subgraph matched (QDQ export of a dynamically quantized linear, no bias):
#   DequantizeLinear(weights) -> Transpose ----------------> MatMul (input 1)
#   DynamicQuantizeLinear(activations) -> DequantizeLinear -> MatMul (input 0)
MATMUL = [
    {
        "name": "deq_linear_0",
        "op": "DequantizeLinear",
        "attrs": [],
        "inport": [],
        "outport": [[0, "transpose_0", 0]],
    },
    {
        "name": "transpose_0",
        "op": "Transpose",
        "attrs": [],
        "inport": [[0, "deq_linear_0", 0]],
        "outport": [[0, "matmul_0", 1]],
    },
    {
        "name": "quant_linear_1",
        "op": "DynamicQuantizeLinear",
        "attrs": [],
        "inport": [],
        # Outputs: quantized tensor, scale, zero point — all consumed by deq_linear_1.
        "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]],
    },
    {
        "name": "deq_linear_1",
        "op": "DequantizeLinear",
        "attrs": [],
        "inport": [
            [0, "quant_linear_1", 0],
            [1, "quant_linear_1", 1],
            [2, "quant_linear_1", 2],
        ],
        "outport": [[0, "matmul_0", 0]],
    },
    {
        "name": "matmul_0",
        "op": "MatMul",
        "attrs": [],
        "inport": [
            [0, "deq_linear_1", 0],
            [1, "transpose_0", 0],
        ],
        "outport": [],
    },
]

# Same idea as MATMUL, but for Gemm: the weight DequantizeLinear feeds Gemm
# input 1 directly (no Transpose node; Gemm carries its own transB semantics),
# and the dynamically quantized activations feed Gemm input 0.
#
#   DequantizeLinear(weights) ------------------------------> Gemm (input 1)
#   DynamicQuantizeLinear(activations) -> DequantizeLinear --> Gemm (input 0)
GEMM = [
    {
        "name": "deq_linear_0",
        "op": "DequantizeLinear",
        "attrs": [],
        "inport": [],
        "outport": [[0, "gemm_0", 1]],
    },
    {
        "name": "quant_linear_1",
        "op": "DynamicQuantizeLinear",
        "attrs": [],
        "inport": [],
        # Outputs: quantized tensor, scale, zero point — all consumed by deq_linear_1.
        "outport": [[0, "deq_linear_1", 0], [1, "deq_linear_1", 1], [2, "deq_linear_1", 2]],
    },
    {
        "name": "deq_linear_1",
        "op": "DequantizeLinear",
        "attrs": [],
        "inport": [
            [0, "quant_linear_1", 0],
            [1, "quant_linear_1", 1],
            [2, "quant_linear_1", 2],
        ],
        "outport": [[0, "gemm_0", 0]],
    },
    {
        "name": "gemm_0",
        "op": "Gemm",
        "attrs": [],
        "inport": [
            [0, "deq_linear_1", 0],
            [1, "deq_linear_0", 0],
        ],
        "outport": [],
    },
]


def create_nodes(graph: Graph, op: str, name: str, inputs: List[str], outputs: List[str], **kwargs):
    """Build a new ONNX node and splice it into an onnx_tool ``Graph``.

    A node proto is created with ``onnx.helper.make_node`` (extra node
    attributes come in via ``**kwargs``, e.g. ``to=`` for Cast), converted to
    an onnx_tool node, and registered in the graph's bookkeeping maps:
    consumers of each input tensor, producer of each output tensor, and the
    node/tensor maps.

    Args:
        graph: Graph to mutate in place.
        op: ONNX op type of the new node.
        name: Unique node name; also used as the tensormap key.
        inputs: Input tensor names.
        outputs: Output tensor names.
        **kwargs: Extra ONNX node attributes.

    Returns:
        The mutated graph (same object, returned for chaining).
    """
    proto = onnx.helper.make_node(op, inputs, outputs, name=name, **kwargs)
    node = create_node(proto)
    node.input = inputs
    node.output = outputs

    # Wire the node into the graph's producer/consumer indices.
    for tensor_name in inputs:
        if tensor_name in graph.consumedby:
            graph.consumedby[tensor_name].append(name)
        if tensor_name in graph.producedby:
            node.prevnodes.append(graph.producedby[tensor_name])

    for tensor_name in outputs:
        graph.producedby[tensor_name] = [name]
        if tensor_name in graph.consumedby:
            node.nextnodes.append(graph.consumedby[tensor_name])

    graph.nodemap[name] = node
    # NOTE(review): tensormap is keyed by the *node* name here, not the output
    # tensor names — in current call sites the first output shares the node's
    # name, so this registers that tensor; confirm if outputs ever diverge.
    graph.tensormap[name] = Tensor(name)

    return graph


def replace_matmul_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
    """Replace each matched quantized-MatMul subgraph with a MatMulInteger chain.

    For every match in ``found_nodes`` (lists of node names from
    ``FusionPattern(MATMUL).search_pattern``, assumed to be in pattern order),
    the DequantizeLinear/Transpose/MatMul nodes are removed and replaced by:

        MatMulInteger -> Cast(float32) -> Mul(activation_scale * weight_scale)

    The DynamicQuantizeLinear node is kept: its outputs (quantized input,
    scale, zero point) feed the inserted nodes directly. The weights are
    re-registered as initializers, pre-transposed so the explicit Transpose
    node is no longer needed.

    Args:
        graph: onnx_tool Graph to mutate in place.
        found_nodes: One list of matched node names per pattern occurrence.
        node_count: Starting suffix for names of inserted nodes/initializers;
            pass a previous pass's total to keep names unique across passes.

    Returns:
        The mutated graph.
    """
    for found_pattern in found_nodes:
        node_count += 1

        # Pattern order: [deq_linear_0, transpose_0, quant_linear_1, deq_linear_1, matmul_0]
        deq_linear = graph.nodemap[found_pattern[0]]
        dyn_q = graph.nodemap[found_pattern[2]]
        dq_weight = deq_linear.prevnodes[0]
        # Transpose the weights now so the Transpose node can be dropped.
        graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose())
        graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value)  # weight scale
        graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value)  # weight zero point

        matmul = graph.nodemap[found_pattern[-1]]
        # Remove all matched nodes except DynamicQuantizeLinear, whose outputs
        # are reused below. NOTE(review): this relies on exported node names
        # containing the op type string — confirm against the exporter naming.
        for name in found_pattern:
            if "DynamicQuantizeLinear" in name:
                continue
            graph.remove_node(name)

        # Drop the producers of the weight/scale/zero-point tensors; the zero
        # point producer may already be gone (guarded lookup).
        graph.remove_node(deq_linear.prevnodes[0].name)
        graph.remove_node(deq_linear.prevnodes[1].name)
        if deq_linear.prevnodes[2].name in graph.nodemap:
            graph.remove_node(deq_linear.prevnodes[2].name)

        # Integer matmul: (quantized input, weights, input zero point, weight zero point).
        graph = create_nodes(
            graph,
            "MatMulInteger",
            f"matmul_integer_{node_count}",
            [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"],
            [f"matmul_integer_{node_count}"],
        )
        # Cast the integer accumulator back to float32 before rescaling.
        graph = create_nodes(
            graph,
            "Cast",
            f"cast_{node_count}",
            [f"matmul_integer_{node_count}"],
            [f"cast_{node_count}"],
            to=ONNX_FLOAT32_IDENTIFIER,
        )
        # Combined rescale factor: activation scale * weight scale.
        graph = create_nodes(
            graph,
            "Mul",
            f"mulscales_{node_count}",
            [dyn_q.output[1], f"dq_weights_1_{node_count}"],
            [f"mulscales_{node_count}"],
        )
        # Final Mul takes over the original MatMul's output tensor name, so
        # downstream consumers stay wired.
        graph = create_nodes(
            graph,
            "Mul",
            f"mulvalues_{node_count}",
            [f"mulscales_{node_count}", f"cast_{node_count}"],
            [matmul.output[0]],
        )
    return graph


def replace_gemm_to_matmulinteger(graph: Graph, found_nodes: List[List[str]], node_count: int = 0):
    """Replace each matched quantized-Gemm subgraph with a MatMulInteger chain.

    For every match in ``found_nodes`` (lists of node names from
    ``FusionPattern(GEMM).search_pattern``, assumed to be in pattern order),
    the DequantizeLinear/Gemm nodes are removed and replaced by:

        MatMulInteger -> Cast(float32) -> Mul(scales) -> Add(bias)

    The DynamicQuantizeLinear node is kept: its outputs (quantized input,
    scale, zero point) feed the inserted nodes directly. The weights are
    re-registered as initializers, pre-transposed to match MatMul layout.

    Fix vs. the original: the redundant second lookup of
    ``graph.nodemap[found_pattern[-1]]`` (which duplicated ``gemm``) is
    removed; ``gemm`` is used throughout.

    Args:
        graph: onnx_tool Graph to mutate in place.
        found_nodes: One list of matched node names per pattern occurrence.
        node_count: Starting suffix for names of inserted nodes/initializers;
            pass the MatMul pass's total to keep names unique across passes.

    Returns:
        The mutated graph.
    """
    for found_pattern in found_nodes:
        node_count += 1

        # Pattern order: [deq_linear_0, quant_linear_1, deq_linear_1, gemm_0]
        gemm = graph.nodemap[found_pattern[-1]]
        bias = gemm.input[-1]
        deq_linear = graph.nodemap[found_pattern[0]]
        dyn_q = graph.nodemap[found_pattern[1]]
        dq_weight = deq_linear.prevnodes[0]
        # Transpose the weights to MatMul layout (Gemm's transB is implied away).
        graph.add_initial(f"dq_weights_0_{node_count}", dq_weight.value.transpose())
        graph.add_initial(f"dq_weights_1_{node_count}", deq_linear.prevnodes[1].value)  # weight scale
        graph.add_initial(f"dq_weights_2_{node_count}", deq_linear.prevnodes[2].value)  # weight zero point

        # Remove all matched nodes except DynamicQuantizeLinear, whose outputs
        # are reused below. NOTE(review): this relies on exported node names
        # containing the op type string — confirm against the exporter naming.
        for name in found_pattern:
            if "DynamicQuantizeLinear" in name:
                continue
            graph.remove_node(name)
        # Drop the producers of the weight/scale/zero-point tensors; the zero
        # point producer may already be gone (guarded lookup).
        graph.remove_node(deq_linear.prevnodes[0].name)
        graph.remove_node(deq_linear.prevnodes[1].name)
        if deq_linear.prevnodes[2].name in graph.nodemap:
            graph.remove_node(deq_linear.prevnodes[2].name)

        # Integer matmul: (quantized input, weights, input zero point, weight zero point).
        graph = create_nodes(
            graph,
            "MatMulInteger",
            f"matmul_integer_{node_count}",
            [dyn_q.output[0], f"dq_weights_0_{node_count}", dyn_q.output[2], f"dq_weights_2_{node_count}"],
            [f"matmul_integer_{node_count}"],
        )
        # Cast the integer accumulator back to float32 before rescaling.
        graph = create_nodes(
            graph,
            "Cast",
            f"cast_{node_count}",
            [f"matmul_integer_{node_count}"],
            [f"cast_{node_count}"],
            to=ONNX_FLOAT32_IDENTIFIER,
        )
        # Combined rescale factor: activation scale * weight scale.
        graph = create_nodes(
            graph,
            "Mul",
            f"mulscales_{node_count}",
            [dyn_q.output[1], f"dq_weights_1_{node_count}"],
            [f"mulscales_{node_count}"],
        )
        graph = create_nodes(
            graph,
            "Mul",
            f"mulvalues_{node_count}",
            [f"mulscales_{node_count}", f"cast_{node_count}"],
            [f"mulvalues_{node_count}"],
        )
        # Final Add re-applies the Gemm bias and takes over the original Gemm's
        # output tensor name, so downstream consumers stay wired.
        graph = create_nodes(
            graph, "Add", f"addbias_{node_count}", [bias, f"mulvalues_{node_count}"], [gemm.output[0]]
        )
    return graph


def find_and_insert_matmulinteger(model_path: str):
    """Rewrite ``<model_path>/model.onnx`` in place, replacing quantized
    MatMul/Gemm subgraphs with MatMulInteger-based equivalents.

    Loads the model with onnx_tool, runs the MATMUL and GEMM pattern
    replacement passes, restores any graph outputs onnx_tool dropped, and
    saves the result back to the same file.

    Improvements vs. the original: idiomatic ``sys.version_info`` tuple
    comparison; the GraphProto no longer shadows the onnx_tool ``graph``
    variable; lazy ``%``-style logging arguments; pathlib for path handling.

    Args:
        model_path: Directory containing the exported ``model.onnx``.

    Raises:
        RuntimeError: When running on Python < 3.9 (required by onnx_tool).
    """
    # onnx_tool requires python 3.9+
    if sys.version_info < (3, 9):
        raise RuntimeError("onnx_tool requires Python 3.9 or higher")

    LOGGER.info("Rewriting ONNX Graph with MatMulInteger")
    model_file = Path(model_path) / "model.onnx"
    cfg = {"constant_folding": False, "node_rename": False, "if_fixed_branch": None, "fixed_topk": 0, "verbose": False}

    onnx_model = onnx.load(str(model_file))

    # Keep a copy of the declared outputs: onnx_tool may drop them while rewriting.
    original_output = copy.deepcopy(onnx_model.graph.output)

    model = Model(onnx_model, cfg)
    graph = model.graph

    pattern = FusionPattern(MATMUL)
    found_matmul_nodes = pattern.search_pattern(graph)
    matmul_node_count = len(found_matmul_nodes)
    LOGGER.info("Replacing %d MatMul nodes with MatMulInteger", matmul_node_count)
    graph = replace_matmul_to_matmulinteger(graph, found_matmul_nodes)

    pattern = FusionPattern(GEMM)
    found_gemm_nodes = pattern.search_pattern(graph)
    LOGGER.info("Replacing %d Gemm nodes with MatMulInteger + Add", len(found_gemm_nodes))
    # Continue the node counter so names inserted by the Gemm pass stay unique.
    graph = replace_gemm_to_matmulinteger(graph, found_gemm_nodes, matmul_node_count)

    graph.graph_reorder_nodes()

    LOGGER.info("Saving the new ONNX model")
    graph_proto = graph.make_graph_onnx(
        graph.nodemap.keys(), "graph", graph.input, graph.output, with_initializer=True, with_shape_info=False
    )
    model_to_save = onnx.helper.make_model(graph_proto, producer_name="onnx_tool")

    # onnx_tool might remove the output nodes from the ONNX graph, so we need to restore them.
    for out in original_output:
        if out not in model_to_save.graph.output:
            model_to_save.graph.output.append(out)

    # Preserve the original IR version and opset imports instead of the
    # defaults that make_model picked.
    model_to_save.ir_version = model.mproto.ir_version
    model_to_save.opset_import.pop()
    for opset in model.mproto.opset_import:
        model_to_save.opset_import.append(opset)

    check_and_save_model(model_to_save, model_file)


def onnx_export_from_quantized_model(
quantized_model: Union["PreTrainedModel"],
output: Union[str, Path],
Expand All @@ -26,6 +312,7 @@ def onnx_export_from_quantized_model(
task: str = "text-generation-with-past",
use_subprocess: bool = False,
do_constant_folding: bool = True,
insert_matmulinteger: bool = True,
**kwargs_shapes,
):
with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager):
Expand All @@ -49,3 +336,7 @@ def onnx_export_from_quantized_model(
no_post_process=True,
**kwargs_shapes,
)

# Replace quantized GEMM and MatMul with MatMulInteger
if insert_matmulinteger:
find_and_insert_matmulinteger(output)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
EXTRAS_REQUIRE = {
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
"brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate", "onnx-tool"],
}

setup(
Expand Down
23 changes: 21 additions & 2 deletions tests/brevitas/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from parameterized import parameterized
from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model

from optimum.amd.brevitas.export import find_and_insert_matmulinteger
from optimum.exporters import TasksManager
from optimum.exporters.onnx import (
export_models,
Expand Down Expand Up @@ -145,9 +146,27 @@ def test_dynamic_quantization(
onnx_config_class_constructor=onnx_config_class_constructor,
shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES,
)
original_matmul_gemm_counter = 0
onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

for node in onnx_model.graph.node:
if node.op_type == "Gemm" or node.op_type == "MatMul":
original_matmul_gemm_counter += 1

find_and_insert_matmulinteger(tmpdir)
onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

matmul_gemm_counter = 0
matmulinteger_counter = 0
for node in onnx_model.graph.node:
# Check that we have MatmulInteger, etc.
pass
if node.op_type == "Gemm" or node.op_type == "MatMul":
matmul_gemm_counter += 1

if node.op_type == "MatMulInteger":
matmulinteger_counter += 1

# The number of Matmul+Gemm has to be less compared to the model pre-transformation
# This is not zero since there are matmul that are not linear layers so they are not replaced
# and some linears layers can be excluded from quantization
self.assertTrue(matmul_gemm_counter <= original_matmul_gemm_counter)
self.assertTrue(matmulinteger_counter > 1)