3 changes: 2 additions & 1 deletion ci/test.sh
@@ -4,4 +4,5 @@ python3 -m pip install --user --upgrade pip
if [ -f requirements.txt ]; then python3 -m pip install -r requirements.txt --user; fi
python3 -m pip install oneflow --user -U -f https://staging.oneflow.info/branch/master/cu110
python3 setup.py install
python3 -m pytest examples/oneflow2onnx
python3 -m pytest examples/oneflow2onnx/nodes
python3 -m pytest examples/oneflow2onnx/models
20 changes: 10 additions & 10 deletions examples/oneflow2onnx/models/test_inceptionv3.py
@@ -120,16 +120,16 @@ def __init__(
self.dropout = nn.Dropout()
self.fc = nn.Linear(2048, num_classes)

# if init_weights:
# for m in self.modules():
# if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
# stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
# flow.nn.init.trunc_normal_(
# m.weight, mean=0.0, std=stddev, a=-2, b=2
# )
# elif isinstance(m, nn.BatchNorm2d):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
if init_weights:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
flow.nn.init.trunc_normal_(
m.weight, mean=0.0, std=stddev, a=-2, b=2
)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _transform_input(self, x: Tensor) -> Tensor:
if self.transform_input:
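The change above re-enables the previously commented-out weight initialization: Conv2d and Linear weights are drawn from a truncated normal (clipped at two standard deviations, using each module's own stddev attribute when present), and BatchNorm2d is reset to identity. A minimal standalone sketch of the same pattern, assuming only oneflow's PyTorch-style nn API:

import oneflow as flow
import oneflow.nn as nn

# Toy module, just to exercise the init pattern re-enabled in this diff.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Truncated normal: samples falling outside [a, b] are redrawn.
        flow.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.1, a=-2, b=2)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)  # scale starts at 1
        nn.init.constant_(m.bias, 0)    # shift starts at 0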
326 changes: 326 additions & 0 deletions examples/oneflow2onnx/quantization/common.py
@@ -0,0 +1,326 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from itertools import chain
import argparse
import os

import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

import tensorrt as trt

try:
    # FileNotFoundError does not exist on Python 2; fall back to IOError there.
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


def GiB(val):
return val * 1 << 30


def add_help(description):
    # Present so that running a sample with -h prints its description and defaults.
    parser = argparse.ArgumentParser(
        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    args, _ = parser.parse_known_args()


def find_sample_data(
description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""
):
"""
Parses sample arguments.
Args:
description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
        err_msg (str): Message appended to the error raised when a file cannot be located.
    Returns:
        Tuple[List[str], List[str]]: The resolved data directories and the absolute paths of the located files.
"""

# Standard command-line arguments for all samples.
kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-d",
"--datadir",
help="Location of the TensorRT sample data directory, and any additional data directories.",
action="append",
default=[kDEFAULT_DATA_ROOT],
)
args, _ = parser.parse_known_args()

def get_data_path(data_dir):
# If the subfolder exists, append it to the path, otherwise use the provided path as-is.
data_path = os.path.join(data_dir, subfolder)
if not os.path.exists(data_path):
if data_dir != kDEFAULT_DATA_ROOT:
print(
"WARNING: "
+ data_path
+ " does not exist. Trying "
+ data_dir
+ " instead."
)
data_path = data_dir
# Make sure data directory exists.
if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
print(
"WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(
data_path
)
)
return data_path

data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
return data_paths, locate_files(data_paths, find_files, err_msg)
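A usage sketch for the helper above; the subfolder and filename here are placeholders, not values from this PR:

# Hypothetical call: searches each -d directory (default /usr/src/tensorrt/data)
# for a "quantization" subfolder containing model.onnx.
data_paths, [onnx_file] = find_sample_data(
    description="QAT TensorRT sample",
    subfolder="quantization",
    find_files=["model.onnx"],
    err_msg="Export the ONNX model before running this sample.",
)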


def locate_files(data_paths, filenames, err_msg=""):
"""
Locates the specified files in the specified data directories.
If a file exists in multiple data directories, the first directory is used.
Args:
data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.
Returns:
List[str]: The absolute paths of the files.
Raises:
FileNotFoundError if a file could not be located.
"""
found_files = [None] * len(filenames)
for data_path in data_paths:
# Find all requested files.
for index, (found, filename) in enumerate(zip(found_files, filenames)):
if not found:
file_path = os.path.abspath(os.path.join(data_path, filename))
if os.path.exists(file_path):
found_files[index] = file_path

# Check that all files were found
for f, filename in zip(found_files, filenames):
if not f or not os.path.exists(f):
raise FileNotFoundError(
"Could not find {:}. Searched in data paths: {:}\n{:}".format(
filename, data_paths, err_msg
)
)
return found_files


# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
def __init__(self, host_mem, device_mem):
self.host = host_mem
self.device = device_mem

def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

def __repr__(self):
return self.__str__()


# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in engine:
size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
dtype = trt.nptype(engine.get_binding_dtype(binding))
# Allocate host and device buffers
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
outputs.append(HostDeviceMem(host_mem, device_mem))
return inputs, outputs, bindings, stream


# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
# Transfer input data to the GPU.
[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
# Run inference.
context.execute_async(
batch_size=batch_size, bindings=bindings, stream_handle=stream.handle
)
# Transfer predictions back from the GPU.
[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
# Synchronize the stream
stream.synchronize()
# Return only the host outputs.
return [out.host for out in outputs]


# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
# Transfer input data to the GPU.
[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
# Run inference.
context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
# Transfer predictions back from the GPU.
[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
# Synchronize the stream
stream.synchronize()
# Return only the host outputs.
return [out.host for out in outputs]


def generate_md5_checksum(local_path):
"""Returns the MD5 checksum of a local file.
Keyword argument:
local_path -- path of the file whose checksum shall be generated
"""
with open(local_path, "rb") as local_file:
data = local_file.read()
import hashlib

return hashlib.md5(data).hexdigest()


def download_file(local_path, link, checksum_reference=None):
"""Checks if a local file is present and downloads it from the specified path otherwise.
If checksum_reference is specified, the file's md5 checksum is compared against the
expected value.
Keyword arguments:
local_path -- path of the file whose checksum shall be generated
link -- link where the file shall be downloaded from if it is not found locally
checksum_reference -- expected MD5 checksum of the file
"""
if not os.path.exists(local_path):
print("Downloading from %s, this may take a while..." % link)
import wget

wget.download(link, local_path)
print()
if checksum_reference is not None:
checksum = generate_md5_checksum(local_path)
if checksum != checksum_reference:
            raise ValueError(
                "The MD5 checksum of local file %s differs from %s; "
                "please remove the file manually and try again."
                % (local_path, checksum_reference)
            )
return local_path


# `retry_call` and `retry` are used to wrap the function we want to try multiple times
def retry_call(func, args=[], kwargs={}, n_retries=3):
    """Call a function, retrying it several times on failure.
    Args:
        func: function to call
        args (List): positional arguments passed to func
        kwargs (Dict): keyword arguments passed to func
        n_retries (int): maximum number of attempts
    Returns:
        Whatever func returns on the first successful attempt.
    """
    for i_try in range(n_retries):
        try:
            return func(*args, **kwargs)
        except Exception:
            if i_try == n_retries - 1:
                raise
            print("retry...")


# Usage: @retry(n_retries)
def retry(n_retries=3):
"""Wrap a function to retry it several times. Decorator version of `retry_call`.
Args:
        n_retries (int): maximum number of attempts
Usage:
@retry(n_retries)
def func(...):
pass
"""

def wrapper(func):
def _wrapper(*args, **kwargs):
retry_call(func, args, kwargs, n_retries)

return _wrapper

return wrapper
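As a usage illustration, the decorator can guard the download helper defined earlier; the URL is a placeholder:

# Hypothetical: attempt a flaky download up to three times before raising.
@retry(n_retries=3)
def fetch_model():
    return download_file("model.onnx", "https://example.com/model.onnx")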


def build_qat_engine_from_onnx(model_file, verbose=False):
if verbose:
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
else:
TRT_LOGGER = trt.Logger(trt.Logger.INFO)

network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network_flags = network_flags | (
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)
)

with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
flags=network_flags
) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
with open(model_file, "rb") as model:
if not parser.parse(model.read()):
print("ERROR: Failed to parse the ONNX file.")
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
config.flags = config.flags | 1 << int(trt.BuilderFlag.INT8)
return builder.build_engine(network, config)
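Building an engine is comparatively slow, so a caller might serialize it once and reload it later; a sketch assuming the TensorRT Python API targeted by this PR (paths are placeholders):

# Hypothetical: build the INT8/QAT engine once and cache it on disk.
engine = build_qat_engine_from_onnx("model.onnx")
if engine is not None:
    with open("model.engine", "wb") as f:
        f.write(engine.serialize())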


def run_tensorrt(onnx_path, test_case):
with build_qat_engine_from_onnx(onnx_path) as engine:
inputs, outputs, bindings, stream = allocate_buffers(engine)
with engine.create_execution_context() as context:
batch_size = test_case.shape[0]
test_case = test_case.reshape(-1)
np.copyto(inputs[0].host, test_case)
trt_outputs = do_inference_v2(
context,
bindings=bindings,
inputs=inputs,
outputs=outputs,
stream=stream,
)
data = trt_outputs[0]
return data.reshape(batch_size, -1)


def get_onnx_provider(ctx: str = "cpu"):
if ctx == "gpu":
return ["CUDAExecutionProvider"]
elif ctx == "cpu":
return ["CPUExecutionProvider"]
    else:
        raise NotImplementedError("Unsupported device type: {}".format(ctx))
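Together these helpers allow a cross-check of the TensorRT engine against ONNX Runtime; a sketch assuming onnxruntime is installed and a single-input model, with a loose tolerance since the INT8 engine will not match FP32 output exactly:

import onnxruntime as ort

def check_against_onnxruntime(onnx_path, test_case, atol=1e-2):
    # TensorRT path, via the helpers defined above.
    trt_out = run_tensorrt(onnx_path, test_case)
    # ONNX Runtime reference path.
    sess = ort.InferenceSession(onnx_path, providers=get_onnx_provider("cpu"))
    input_name = sess.get_inputs()[0].name
    ort_out = sess.run(None, {input_name: test_case})[0]
    assert np.allclose(trt_out, ort_out.reshape(trt_out.shape), atol=atol)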