From 6876ee2fff8c109dafc4862ebbb011f2554726cc Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 26 Sep 2021 15:45:48 +0800
Subject: [PATCH 1/7] add oneflow fx quantization aware training test script

---
 ci/test.sh                        |   3 +-
 .../quantization/test_resnet18.py | 137 ++++++++++++++++++
 2 files changed, 139 insertions(+), 1 deletion(-)
 create mode 100644 examples/oneflow2onnx/quantization/test_resnet18.py

diff --git a/ci/test.sh b/ci/test.sh
index e1cd180..ed2fb16 100755
--- a/ci/test.sh
+++ b/ci/test.sh
@@ -4,4 +4,5 @@ python3 -m pip install --user --upgrade pip
 if [ -f requirements.txt ]; then python3 -m pip install -r requirements.txt --user; fi
 python3 -m pip install oneflow --user -U -f https://staging.oneflow.info/branch/master/cu110
 python3 setup.py install
-python3 -m pytest examples/oneflow2onnx
+python3 -m pytest examples/oneflow2onnx/nodes
+python3 -m pytest examples/oneflow2onnx/models
diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
new file mode 100644
index 0000000..51ca5d8
--- /dev/null
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -0,0 +1,137 @@
+import tempfile
+import oneflow as flow
+import oneflow.nn as nn
+import oneflow.nn.functional as F
+from oneflow.fx.passes.quantization import quantization_aware_training
+from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion*planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion*planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion*planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion *
+                               planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion*planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion*planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion*planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512*block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1]*(num_blocks-1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def ResNet18():
+    return ResNet(BasicBlock, [2, 2, 2, 2])
+
+quantization_resnet18 = ResNet18()
+
+gm: flow.fx.GraphModule = flow.fx.symbolic_trace(quantization_resnet18)
+qconfig = {
+    'quantization_bit': 8,
+    'quantization_scheme': "symmetric",
+    'quantization_formula': "cambricon",
+    'per_layer_quantization': True,
+    'momentum': 0.95,
+}
+
+quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32), qconfig)
+quantization_resnet18 = quantization_resnet18.to("cuda")
+quantization_resnet18.eval()
+
+class ResNet18Graph(flow.nn.Graph):
+    def __init__(self):
+        super().__init__()
+        self.m = quantization_resnet18
+
+    def build(self, x):
+        out = self.m(x)
+        return out
+
+def test_resnet():
+
+    resnet_graph = ResNet18Graph()
+    resnet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        flow.save(quantization_resnet18.state_dict(), tmpdirname)
+        convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
+
+test_resnet()

From 93ba4b41ae97ee2650b20da36b8507e3eb0fc214 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 26 Sep 2021 20:24:03 +0800
Subject: [PATCH 2/7] update script where nn.Graph throws an error

---
 examples/oneflow2onnx/quantization/test_resnet18.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index 51ca5d8..e26ae24 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -101,9 +101,9 @@ def forward(self, x):
 def ResNet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
-quantization_resnet18 = ResNet18()
+resnet18 = ResNet18()
 
-gm: flow.fx.GraphModule = flow.fx.symbolic_trace(quantization_resnet18)
+gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
 qconfig = {
     'quantization_bit': 8,
     'quantization_scheme': "symmetric",
@@ -128,10 +128,11 @@ def build(self, x):
 def test_resnet():
 
     resnet_graph = ResNet18Graph()
-    resnet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))
+    resnet_graph.debug()
+    resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
 
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        flow.save(quantization_resnet18.state_dict(), tmpdirname)
-        convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
+    # with tempfile.TemporaryDirectory() as tmpdirname:
+    #     flow.save(quantization_resnet18.state_dict(), tmpdirname)
+    #     convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
 
 test_resnet()

From 0d1f9600506fd703d634b7b4850989bdfa5e25f9 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Mon, 27 Sep 2021 18:53:09 +0800
Subject: [PATCH 3/7] add inceptionv3 init_weights

---
 .../oneflow2onnx/models/test_inceptionv3.py | 20 ++++++++++----------
 .../quantization/test_resnet18.py           |  2 +-
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/examples/oneflow2onnx/models/test_inceptionv3.py b/examples/oneflow2onnx/models/test_inceptionv3.py
index cd66059..f790a5a 100644
--- a/examples/oneflow2onnx/models/test_inceptionv3.py
+++ b/examples/oneflow2onnx/models/test_inceptionv3.py
@@ -120,16 +120,16 @@ def __init__(
         self.dropout = nn.Dropout()
         self.fc = nn.Linear(2048, num_classes)
 
-        # if init_weights:
-        #     for m in self.modules():
-        #         if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-        #             stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
-        #             flow.nn.init.trunc_normal_(
-        #                 m.weight, mean=0.0, std=stddev, a=-2, b=2
-        #             )
-        #         elif isinstance(m, nn.BatchNorm2d):
-        #             nn.init.constant_(m.weight, 1)
-        #             nn.init.constant_(m.bias, 0)
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
+                    flow.nn.init.trunc_normal_(
+                        m.weight, mean=0.0, std=stddev, a=-2, b=2
+                    )
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
 
     def _transform_input(self, x: Tensor) -> Tensor:
         if self.transform_input:
diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index e26ae24..0918cdb 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -128,7 +128,7 @@ def build(self, x):
 def test_resnet():
 
     resnet_graph = ResNet18Graph()
-    resnet_graph.debug()
+    # resnet_graph.debug()
     resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
 
     # with tempfile.TemporaryDirectory() as tmpdirname:

From 020a8d3c8c433dbbbd4e6e8ebf3dcbf6b85a533b Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Tue, 28 Sep 2021 15:21:46 +0800
Subject: [PATCH 4/7] fix graph eval backward bug

---
 examples/oneflow2onnx/quantization/test_resnet18.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index 0918cdb..bfd5766 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -102,6 +102,7 @@ def ResNet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
 resnet18 = ResNet18()
+resnet18.eval()
 
 gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
 qconfig = {
@@ -115,6 +116,7 @@ def ResNet18():
 quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32), qconfig)
 quantization_resnet18 = quantization_resnet18.to("cuda")
 quantization_resnet18.eval()
+print(quantization_resnet18)
 
 class ResNet18Graph(flow.nn.Graph):
     def __init__(self):
@@ -128,7 +130,6 @@ def build(self, x):
 def test_resnet():
 
     resnet_graph = ResNet18Graph()
-    # resnet_graph.debug()
     resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
 
     # with tempfile.TemporaryDirectory() as tmpdirname:

From b84727b3899d151578557273b7fa83892bfedae3 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Thu, 30 Sep 2021 21:21:02 +0800
Subject: [PATCH 5/7] test resnet18 success

---
 .../quantization/test_resnet18.py | 22 +++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index bfd5766..4cb9ecb 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -3,6 +3,7 @@
 import oneflow.nn as nn
 import oneflow.nn.functional as F
 from oneflow.fx.passes.quantization import quantization_aware_training
+from oneflow.fx.passes.dequantization import dequantization_aware_training
 from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
 
 class BasicBlock(nn.Module):
@@ -102,26 +103,32 @@ def ResNet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
 resnet18 = ResNet18()
+resnet18 = resnet18.to("cuda")
 resnet18.eval()
 
 gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
 qconfig = {
     'quantization_bit': 8,
     'quantization_scheme': "symmetric",
-    'quantization_formula': "cambricon",
+    'quantization_formula': "google",
     'per_layer_quantization': True,
     'momentum': 0.95,
 }
 
-quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32), qconfig)
+quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
 quantization_resnet18 = quantization_resnet18.to("cuda")
 quantization_resnet18.eval()
-print(quantization_resnet18)
+
+origin_gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
+dequantization_resnet18 = dequantization_aware_training(origin_gm, gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
+dequantization_resnet18 = dequantization_resnet18.to("cuda")
+dequantization_resnet18.eval()
+print(dequantization_resnet18)
 
 class ResNet18Graph(flow.nn.Graph):
     def __init__(self):
         super().__init__()
-        self.m = quantization_resnet18
+        self.m = dequantization_resnet18
 
     def build(self, x):
         out = self.m(x)
@@ -131,9 +138,10 @@ def test_resnet():
 
     resnet_graph = ResNet18Graph()
     resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
+    # print(resnet_graph)
 
-    # with tempfile.TemporaryDirectory() as tmpdirname:
-    #     flow.save(quantization_resnet18.state_dict(), tmpdirname)
-    #     convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        flow.save(dequantization_resnet18.state_dict(), tmpdirname)
+        convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
 
 test_resnet()

From 28c8eeb384857b95535410a5613d182992ff1dba Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Fri, 1 Oct 2021 00:16:04 +0800
Subject: [PATCH 6/7] add tensorrt test success

---
 examples/oneflow2onnx/quantization/common.py | 326 ++++++++++++++++++
 .../quantization/test_resnet18.py            |  16 +-
 oneflow_onnx/oneflow2onnx/util.py            |   2 +-
 requirements.txt                             |   5 +
 4 files changed, 341 insertions(+), 8 deletions(-)
 create mode 100644 examples/oneflow2onnx/quantization/common.py

diff --git a/examples/oneflow2onnx/quantization/common.py b/examples/oneflow2onnx/quantization/common.py
new file mode 100644
index 0000000..d07de3b
--- /dev/null
+++ b/examples/oneflow2onnx/quantization/common.py
@@ -0,0 +1,326 @@
+#
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from itertools import chain
+import argparse
+import os
+
+import pycuda.driver as cuda
+import pycuda.autoinit
+import numpy as np
+
+import tensorrt as trt
+
+try:
+    # Sometimes python does not understand FileNotFoundError
+    FileNotFoundError
+except NameError:
+    FileNotFoundError = IOError
+
+EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+
+
+def GiB(val):
+    return val * 1 << 30
+
+
+def add_help(description):
+    parser = argparse.ArgumentParser(
+        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    args, _ = parser.parse_known_args()
+
+
+def find_sample_data(
+    description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""
+):
+    """
+    Parses sample arguments.
+    Args:
+        description (str): Description of the sample.
+        subfolder (str): The subfolder containing data relevant to this sample
+        find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.
+    Returns:
+        str: Path of data directory.
+    """
+
+    # Standard command-line arguments for all samples.
+    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
+    parser = argparse.ArgumentParser(
+        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "-d",
+        "--datadir",
+        help="Location of the TensorRT sample data directory, and any additional data directories.",
+        action="append",
+        default=[kDEFAULT_DATA_ROOT],
+    )
+    args, _ = parser.parse_known_args()
+
+    def get_data_path(data_dir):
+        # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
+        data_path = os.path.join(data_dir, subfolder)
+        if not os.path.exists(data_path):
+            if data_dir != kDEFAULT_DATA_ROOT:
+                print(
+                    "WARNING: "
+                    + data_path
+                    + " does not exist. Trying "
+                    + data_dir
+                    + " instead."
+                )
+            data_path = data_dir
+        # Make sure data directory exists.
+        if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
+            print(
+                "WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(
+                    data_path
+                )
+            )
+        return data_path
+
+    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
+    return data_paths, locate_files(data_paths, find_files, err_msg)
+
+
+def locate_files(data_paths, filenames, err_msg=""):
+    """
+    Locates the specified files in the specified data directories.
+    If a file exists in multiple data directories, the first directory is used.
+    Args:
+        data_paths (List[str]): The data directories.
+        filename (List[str]): The names of the files to find.
+    Returns:
+        List[str]: The absolute paths of the files.
+    Raises:
+        FileNotFoundError if a file could not be located.
+    """
+    found_files = [None] * len(filenames)
+    for data_path in data_paths:
+        # Find all requested files.
+        for index, (found, filename) in enumerate(zip(found_files, filenames)):
+            if not found:
+                file_path = os.path.abspath(os.path.join(data_path, filename))
+                if os.path.exists(file_path):
+                    found_files[index] = file_path
+
+    # Check that all files were found
+    for f, filename in zip(found_files, filenames):
+        if not f or not os.path.exists(f):
+            raise FileNotFoundError(
+                "Could not find {:}. Searched in data paths: {:}\n{:}".format(
+                    filename, data_paths, err_msg
+                )
+            )
+    return found_files
+
+
+# Simple helper data class that's a little nicer to use than a 2-tuple.
+class HostDeviceMem(object):
+    def __init__(self, host_mem, device_mem):
+        self.host = host_mem
+        self.device = device_mem
+
+    def __str__(self):
+        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
+def allocate_buffers(engine):
+    inputs = []
+    outputs = []
+    bindings = []
+    stream = cuda.Stream()
+    for binding in engine:
+        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+        dtype = trt.nptype(engine.get_binding_dtype(binding))
+        # Allocate host and device buffers
+        host_mem = cuda.pagelocked_empty(size, dtype)
+        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        # Append the device buffer to device bindings.
+        bindings.append(int(device_mem))
+        # Append to the appropriate list.
+        if engine.binding_is_input(binding):
+            inputs.append(HostDeviceMem(host_mem, device_mem))
+        else:
+            outputs.append(HostDeviceMem(host_mem, device_mem))
+    return inputs, outputs, bindings, stream
+
+
+# This function is generalized for multiple inputs/outputs.
+# inputs and outputs are expected to be lists of HostDeviceMem objects.
+def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    # Run inference.
+    context.execute_async(
+        batch_size=batch_size, bindings=bindings, stream_handle=stream.handle
+    )
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # Synchronize the stream
+    stream.synchronize()
+    # Return only the host outputs.
+    return [out.host for out in outputs]
+
+
+# This function is generalized for multiple inputs/outputs for full dimension networks.
+# inputs and outputs are expected to be lists of HostDeviceMem objects.
+def do_inference_v2(context, bindings, inputs, outputs, stream):
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    # Run inference.
+    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # Synchronize the stream
+    stream.synchronize()
+    # Return only the host outputs.
+    return [out.host for out in outputs]
+
+
+def generate_md5_checksum(local_path):
+    """Returns the MD5 checksum of a local file.
+    Keyword argument:
+    local_path -- path of the file whose checksum shall be generated
+    """
+    with open(local_path, "rb") as local_file:
+        data = local_file.read()
+    import hashlib
+
+    return hashlib.md5(data).hexdigest()
+
+
+def download_file(local_path, link, checksum_reference=None):
+    """Checks if a local file is present and downloads it from the specified path otherwise.
+    If checksum_reference is specified, the file's md5 checksum is compared against the
+    expected value.
+    Keyword arguments:
+    local_path -- path of the file whose checksum shall be generated
+    link -- link where the file shall be downloaded from if it is not found locally
+    checksum_reference -- expected MD5 checksum of the file
+    """
+    if not os.path.exists(local_path):
+        print("Downloading from %s, this may take a while..." % link)
+        import wget
+
+        wget.download(link, local_path)
+        print()
+    if checksum_reference is not None:
+        checksum = generate_md5_checksum(local_path)
+        if checksum != checksum_reference:
+            raise ValueError(
+                "The MD5 checksum of local file %s differs from %s, please manually remove \
+                the file and try again."
+                % (local_path, checksum_reference)
+            )
+    return local_path
+
+
+# `retry_call` and `retry` are used to wrap the function we want to try multiple times
+def retry_call(func, args=[], kwargs={}, n_retries=3):
+    """Wrap a function to retry it several times.
+    Args:
+        func: function to call
+        args (List): args parsed to func
+        kwargs (Dict): kwargs parsed to func
+        n_retries (int): maximum times of tries
+    """
+    for i_try in range(n_retries):
+        try:
+            func(*args, **kwargs)
+            break
+        except:
+            if i_try == n_retries - 1:
+                raise
+            print("retry...")
+
+
+# Usage: @retry(n_retries)
+def retry(n_retries=3):
+    """Wrap a function to retry it several times. Decorator version of `retry_call`.
+    Args:
+        n_retries (int): maximum times of tries
+    Usage:
+        @retry(n_retries)
+        def func(...):
+            pass
+    """
+
+    def wrapper(func):
+        def _wrapper(*args, **kwargs):
+            retry_call(func, args, kwargs, n_retries)
+
+        return _wrapper
+
+    return wrapper
+
+
+def build_qat_engine_from_onnx(model_file, verbose=False):
+    if verbose:
+        TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+    else:
+        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network_flags = network_flags | (
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)
+    )
+
+    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
+        flags=network_flags
+    ) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
+        with open(model_file, "rb") as model:
+            if not parser.parse(model.read()):
+                print("ERROR: Failed to parse the ONNX file.")
+                for error in range(parser.num_errors):
+                    print(parser.get_error(error))
+                return None
+        config = builder.create_builder_config()
+        config.max_workspace_size = 1 << 30
+        config.flags = config.flags | 1 << int(trt.BuilderFlag.INT8)
+        return builder.build_engine(network, config)
+
+
+def run_tensorrt(onnx_path, test_case):
+    with build_qat_engine_from_onnx(onnx_path) as engine:
+        inputs, outputs, bindings, stream = allocate_buffers(engine)
+        with engine.create_execution_context() as context:
+            batch_size = test_case.shape[0]
+            test_case = test_case.reshape(-1)
+            np.copyto(inputs[0].host, test_case)
+            trt_outputs = do_inference_v2(
+                context,
+                bindings=bindings,
+                inputs=inputs,
+                outputs=outputs,
+                stream=stream,
+            )
+            data = trt_outputs[0]
+    return data.reshape(batch_size, -1)
+
+
+def get_onnx_provider(ctx: str = "cpu"):
+    if ctx == "gpu":
+        return ["CUDAExecutionProvider"]
+    elif ctx == "cpu":
+        return ["CPUExecutionProvider"]
+    else:
+        raise NotImplementedError("Not supported device type.")
diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index 4cb9ecb..cc33257 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -1,11 +1,13 @@
+import os
+import shutil
 import tempfile
 import oneflow as flow
 import oneflow.nn as nn
 import oneflow.nn.functional as F
 from oneflow.fx.passes.quantization import quantization_aware_training
 from oneflow.fx.passes.dequantization import dequantization_aware_training
-from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
-
+from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check, run_onnx, compare_result
+from common import run_tensorrt, get_onnx_provider
 class BasicBlock(nn.Module):
     expansion = 1
 
@@ -123,7 +125,6 @@ def ResNet18():
 dequantization_resnet18 = dequantization_aware_training(origin_gm, gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
 dequantization_resnet18 = dequantization_resnet18.to("cuda")
 dequantization_resnet18.eval()
-print(dequantization_resnet18)
 
 class ResNet18Graph(flow.nn.Graph):
     def __init__(self):
@@ -134,14 +135,15 @@ def build(self, x):
         out = self.m(x)
         return out
 
-def test_resnet():
-
+def test_resnet():
     resnet_graph = ResNet18Graph()
     resnet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
-    # print(resnet_graph)
-
     with tempfile.TemporaryDirectory() as tmpdirname:
         flow.save(dequantization_resnet18.state_dict(), tmpdirname)
         convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
+    ipt_dict, onnx_res = run_onnx("/tmp/model.onnx", get_onnx_provider("cpu"))
+    trt_res = run_tensorrt("/tmp/model.onnx", ipt_dict[list(ipt_dict.keys())[0]])
+    compare_result(onnx_res, trt_res, atol=1e-4, print_outlier=True)
 
 test_resnet()
+
diff --git a/oneflow_onnx/oneflow2onnx/util.py b/oneflow_onnx/oneflow2onnx/util.py
index 1c7ec15..c1997a1 100644
--- a/oneflow_onnx/oneflow2onnx/util.py
+++ b/oneflow_onnx/oneflow2onnx/util.py
@@ -46,7 +46,7 @@ def run_onnx(
     if ipt_dict is None:
         ipt_dict = OrderedDict()
         for ipt in sess.get_inputs():
-            ipt_data = np.random.uniform(low=-10, high=10, size=ipt.shape).astype(
+            ipt_data = np.random.uniform(low=-0.5, high=0.5, size=ipt.shape).astype(
                 np.float32
             )
             ipt_dict[ipt.name] = ipt_data
diff --git a/requirements.txt b/requirements.txt
index 5d75644..0a4b271 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,7 @@
 onnx>=1.8.0
 onnxruntime-gpu>=1.8.0
+opencv-python
+pytest
+nvidia-tensorrt==8.0.0.3
+pycuda
+flake8

From 3a4cffed29b01968aa7207d28613c79dc925cbe5 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Fri, 1 Oct 2021 14:23:46 +0800
Subject: [PATCH 7/7] add resnet18 tensorrt test

---
 examples/oneflow2onnx/quantization/test_resnet18.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/oneflow2onnx/quantization/test_resnet18.py b/examples/oneflow2onnx/quantization/test_resnet18.py
index cc33257..8394d42 100644
--- a/examples/oneflow2onnx/quantization/test_resnet18.py
+++ b/examples/oneflow2onnx/quantization/test_resnet18.py
@@ -120,6 +120,8 @@ def ResNet18():
 quantization_resnet18 = quantization_aware_training(gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
 quantization_resnet18 = quantization_resnet18.to("cuda")
 quantization_resnet18.eval()
+checkpoint = flow.load('/home/zhangxiaoyu/oneflow-cifar/checkpoint/epoch_11_val_acc_83.280000')
+quantization_resnet18.load_state_dict(checkpoint)
 
 origin_gm: flow.fx.GraphModule = flow.fx.symbolic_trace(resnet18)
 dequantization_resnet18 = dequantization_aware_training(origin_gm, gm, flow.randn(1, 3, 32, 32).to("cuda"), qconfig)
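
Read end to end, the seven patches assemble one pipeline: trace the eager ResNet-18 with oneflow.fx, insert fake-quantization nodes (quantization_aware_training), fold them back into an export-friendly graph (dequantization_aware_training), wrap the result in flow.nn.Graph, convert to ONNX, and cross-check ONNX Runtime against a TensorRT INT8 engine. The sketch below condenses that flow into one script; it is illustrative rather than part of the series, it assumes the ResNet18 definition and the common.py helpers from the diffs above, and it omits the machine-specific checkpoint load from patch 7.

# Condensed sketch of the pipeline built up by patches 1-7 (assumes the
# ResNet18 class from patch 1 and the common.py helpers from patch 6).
import tempfile

import oneflow as flow
from oneflow.fx.passes.quantization import quantization_aware_training
from oneflow.fx.passes.dequantization import dequantization_aware_training
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check, run_onnx, compare_result
from common import run_tensorrt, get_onnx_provider

qconfig = {
    'quantization_bit': 8,
    'quantization_scheme': "symmetric",
    'quantization_formula': "google",   # patch 5 switched this from "cambricon"
    'per_layer_quantization': True,
    'momentum': 0.95,
}

resnet18 = ResNet18().to("cuda")        # CIFAR-style ResNet-18 from patch 1
resnet18.eval()                         # patch 4: switch to eval before tracing
dummy = flow.randn(1, 3, 32, 32).to("cuda")

# 1. Trace the eager model and insert fake-quant nodes.
gm = flow.fx.symbolic_trace(resnet18)
quantization_resnet18 = quantization_aware_training(gm, dummy, qconfig).to("cuda").eval()
# (patch 7 loads QAT-trained weights into quantization_resnet18 here)

# 2. Fold the fake-quant nodes back so the exported graph carries the
#    quantization parameters that the ONNX/TensorRT side can consume.
origin_gm = flow.fx.symbolic_trace(resnet18)
dequantization_resnet18 = dequantization_aware_training(
    origin_gm, gm, dummy, qconfig).to("cuda").eval()

class ResNet18Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = dequantization_resnet18

    def build(self, x):
        return self.m(x)

# 3. Compile the static graph, export to ONNX, and cross-check the backends.
resnet_graph = ResNet18Graph()
resnet_graph._compile(dummy)
with tempfile.TemporaryDirectory() as tmpdirname:
    flow.save(dequantization_resnet18.state_dict(), tmpdirname)
    convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname,
                              onnx_model_path="/tmp", print_outlier=True)
ipt_dict, onnx_res = run_onnx("/tmp/model.onnx", get_onnx_provider("cpu"))
trt_res = run_tensorrt("/tmp/model.onnx", ipt_dict[list(ipt_dict.keys())[0]])
compare_result(onnx_res, trt_res, atol=1e-4, print_outlier=True)

Two details from the series matter when reproducing the comparison: the eager model is put into eval() before tracing (patch 4, addressing a graph eval/backward bug), and patch 6 narrows run_onnx's random inputs from [-10, 10] to [-0.5, 0.5], presumably so the INT8 quantization error stays within compare_result's atol=1e-4.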