diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..24b3a8b
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,43 @@
+FROM nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
+
+#COPY sources.list /etc/apt/sources.list
+ARG DEBIAN_FRONTEND=noninteractive
+
+# Base tools and repo setup
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        software-properties-common \
+        gnupg \
+        build-essential \
+        curl \
+        ca-certificates \
+        cmake \
+        vim \
+    && add-apt-repository -y ppa:deadsnakes/ppa \
+    && apt-get update && apt-get install -y --no-install-recommends \
+        python3.10 \
+        python3.10-dev \
+        python3.10-venv \
+        python3.10-distutils \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create a dedicated Python 3.10 venv and make it default on PATH
+RUN python3.10 -m venv /opt/py310 \
+    && /opt/py310/bin/python -m pip install --upgrade pip setuptools wheel
+
+ENV VIRTUAL_ENV=/opt/py310
+ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
+ENV PIP_NO_CACHE_DIR=1
+
+# Sanity check
+RUN python -V && pip -V
+
+RUN pip config set global.extra-index-url "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
+RUN pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118
+RUN pip install onnxruntime-gpu==1.16.0 onnx==1.14.1
+
+# install app
+WORKDIR /workspace
+ADD . Dipoorlet
+RUN cd Dipoorlet \
+    && pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple \
+    && python3 setup.py install
diff --git a/dipoorlet/__main__.py b/dipoorlet/__main__.py
index b9d9115..b05810c 100644
--- a/dipoorlet/__main__.py
+++ b/dipoorlet/__main__.py
@@ -52,6 +52,8 @@
 parser.add_argument("--optim_transformer", help="Transformer model optimization", default=False, action='store_true')
 parser.add_argument("--model_type", help="Transformer model type", choices=["unet"], default=None)
 parser.add_argument("--quant_format", default="QDQ", type=str, choices=["QOP", "QDQ"])
+parser.add_argument("--load_clip", help="Load clip values from a directory or act_clip_val.json path and skip calibration",
+                    default=None)
 
 args = parser.parse_args()
 if args.slurm:
@@ -91,6 +93,18 @@
 args.optimzed_model_dir = os.path.join(args.output_dir, 'optim_model.onnx')
 logger.parent = None
 
+def _resolve_clip_dir(path):
+    if os.path.isdir(path):
+        return path
+    if os.path.isfile(path):
+        base = os.path.basename(path)
+        if base == 'act_clip_val.json' or base == 'weight_clip_val.json':
+            return os.path.dirname(path)
+    raise FileNotFoundError(
+        "load_clip expects a directory or a path to act_clip_val.json/weight_clip_val.json: {}".format(path)
+    )
+
+
 start = time.time()
 if args.optim_transformer:
     model = onnx.load(args.optimzed_model_dir)
@@ -116,16 +130,24 @@
 setattr(args, 'world_size', dist.get_world_size())
 if dist.get_rank() == 0:
     logger.info("Do tensor calibration...")
-act_clip_val, weight_clip_val = tensor_calibration(onnx_graph, args)
-tensor_range = copy.deepcopy(act_clip_val)
-save_clip_val(act_clip_val, weight_clip_val, args,
-              act_fname='act_clip_val.json.rank{}'.format(args.rank),
-              weight_fname='weight_clip_val.json.rank{}'.format(args.rank))
-dist.barrier()
-if dist.get_rank() == 0:
-    reduce_clip_val(dist.get_world_size(), args)
-dist.barrier()
-act_clip_val, weight_clip_val = load_clip_val(args)
+if args.load_clip:
+    clip_dir = _resolve_clip_dir(args.load_clip)
+    if dist.get_rank() == 0:
+        logger.info("Load clip values from: {}".format(clip_dir))
+    act_clip_val, weight_clip_val = load_clip_val(args, base_dir=clip_dir)
+    tensor_range = copy.deepcopy(act_clip_val)
+    dist.barrier()
+else:
+    act_clip_val, weight_clip_val = tensor_calibration(onnx_graph, args)
+    tensor_range = copy.deepcopy(act_clip_val)
+    save_clip_val(act_clip_val, weight_clip_val, args,
+                  act_fname='act_clip_val.json.rank{}'.format(args.rank),
+                  weight_fname='weight_clip_val.json.rank{}'.format(args.rank))
+    dist.barrier()
+    if dist.get_rank() == 0:
+        reduce_clip_val(dist.get_world_size(), args)
+    dist.barrier()
+    act_clip_val, weight_clip_val = load_clip_val(args)
 
 # Weight Transform.
 if dist.get_rank() == 0:
@@ -133,6 +155,8 @@
 graph, graph_ori, act_clip_val, weight_clip_val = \
     weight_calibration(onnx_graph, act_clip_val, weight_clip_val, args)
 dist.barrier()
+if dist.get_rank() == 0:
+    save_clip_val(act_clip_val, weight_clip_val, args)
 
 # Profiling Distributed.
 if dist.get_rank() == 0:
@@ -158,4 +182,4 @@
 if args.quant_format == 'QOP' and args.model_type is None:
     deploy_QOperator(graph.model, tensor_range, args)
 end = time.time()
-logger.info("Total time cost: {} seconds.".format(int(end - start)))
\ No newline at end of file
+logger.info("Total time cost: {} seconds.".format(int(end - start)))
diff --git a/dipoorlet/deploy/deploy_trt.py b/dipoorlet/deploy/deploy_trt.py
index 3d33281..e8e0f2e 100644
--- a/dipoorlet/deploy/deploy_trt.py
+++ b/dipoorlet/deploy/deploy_trt.py
@@ -1,5 +1,6 @@
 import json
 import os
+import numpy as np
 
 from .deploy_default import deploy_dispatcher
 
@@ -8,7 +9,9 @@
 def gen_trt_range(graph, clip_val, args, **kwargs):
     for k, v in clip_val.items():
         # max(-clip_min, clip_max)
-        clip_val[k] = max(-clip_val[k][0].astype(float), clip_val[k][1].astype(float))
+        v0 = np.min(clip_val[k][0])
+        v1 = np.max(clip_val[k][1])
+        clip_val[k] = max(-float(v0), float(v1))
     tensorrt_blob_json = dict()
     tensorrt_blob_json['blob_range'] = clip_val
diff --git a/dipoorlet/forward_net.py b/dipoorlet/forward_net.py
index 513fb2e..dc0e2c5 100644
--- a/dipoorlet/forward_net.py
+++ b/dipoorlet/forward_net.py
@@ -1,4 +1,5 @@
 import copy
+import os
 import time
 import sys
 from collections import OrderedDict
@@ -19,6 +20,111 @@
 ort.set_default_logger_severity(3)
 sys.setrecursionlimit(2000)
 
+_ORT_SESSION_OPTIONS = None
+
+
+def _parse_int_env(value):
+    if value is None:
+        return None
+    try:
+        return int(value)
+    except ValueError:
+        # Handle values like "4(x2)" by taking the leading integer.
+        token = value.split('(')[0]
+        try:
+            return int(token)
+        except ValueError:
+            digits = ''.join(ch for ch in value if ch.isdigit())
+            return int(digits) if digits else None
+
+
+def _get_local_world_size():
+    for key in ("LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_NTASKS_PER_NODE"):
+        val = _parse_int_env(os.environ.get(key))
+        if val and val > 0:
+            return val
+    return 1
+
+
+def _get_ort_session_options(args):
+    # Split the visible CPU cores evenly across local ranks; built once and cached per process.
+    global _ORT_SESSION_OPTIONS
+    if _ORT_SESSION_OPTIONS is not None:
+        return _ORT_SESSION_OPTIONS
+    try:
+        cores = len(os.sched_getaffinity(0))
+    except AttributeError:
+        cores = os.cpu_count() or 1
+    local_world = _get_local_world_size()
+    intra = max(1, cores // max(1, local_world))
+    inter = 1
+    sess_options = ort.SessionOptions()
+    sess_options.intra_op_num_threads = intra
+    sess_options.inter_op_num_threads = inter
+    _ORT_SESSION_OPTIONS = sess_options
+    return sess_options
+
+
+def _reshape_input(data, shape):
+    if len(shape) == 0:
+        return data
+    shape = list(shape)
+    if shape[0] == 0:
+        shape[0] = 1
+    return data.reshape(*shape)
+
+
+def _value_info_shape(shape):
+    # Keep rank; replace dynamic batch dim 0 with 1 to satisfy ORT shape inference.
+    shape = list(shape)
+    if len(shape) > 0 and shape[0] == 0:
+        shape[0] = 1
+    return shape
+
+
+def input_data_generator(input_dir, input_name_list, data_st_idx, data_ed_idx):
+    for idx in range(data_st_idx, data_ed_idx):
+        data = {}
+        for i in input_name_list:
+            data[i] = np.fromfile(f'{input_dir}/{i}/{idx}.bin', 'float32')
+        yield data
+
+
+def input_data_batch_generator(input_dir, input_name_list, data_st_idx, data_ed_idx, batch_size, tensor_shapes):
+    for batch_st in range(data_st_idx, data_ed_idx, batch_size):
+        batch_ed = min(batch_st + batch_size, data_ed_idx)
+        batch_idx = range(batch_st, batch_ed)
+        data = {}
+        for name in input_name_list:
+            shape = tensor_shapes[name]
+            samples = []
+            for idx in batch_idx:
+                raw = np.fromfile(f'{input_dir}/{name}/{idx}.bin', 'float32')
+                samples.append(_reshape_input(raw, shape))
+            if len(samples) == 1:
+                data[name] = samples[0]
+            else:
+                if len(shape) > 0 and shape[0] in (0, 1):
+                    data[name] = np.concatenate(samples, axis=0)
+                else:
+                    data[name] = np.stack(samples, axis=0)
+        yield data
+
+
+def _get_calib_batch(onnx_graph, args):
+    batch = max(1, int(getattr(args, "ada_bs", 1)))
+    if batch <= 1:
+        return 1
+    for name in onnx_graph.network_inputs:
+        shape = onnx_graph.get_tensor_shape(name)
+        if len(shape) > 0 and shape[0] != 0:
+            logger.warning(
+                "Calibration batch %d requested but input %s has fixed batch dim %s; falling back to batch=1.",
+                batch, name, shape[0]
+            )
+            return 1
+    return batch
+
+
 
 class ActivationCache(object):
     # We assume get tensor by sequence.
@@ -53,13 +159,13 @@ def fetch_input(self, in_tensor=None):
             for data in input_data_generator(self.args.input_dir, self.graph.network_inputs, self.st, self.ed):
                 for name in self.graph.network_inputs:
                     self.activation_cache[name].append(
-                        data[name][:].reshape(*self.graph.get_tensor_shape(name)).copy())
+                        _reshape_input(data[name], self.graph.get_tensor_shape(name)).copy())
         else:
             # Means We need specific tensor.
             self.activation_cache[in_tensor] = []
             for data in input_data_generator(self.args.input_dir, self.graph.network_inputs, self.st, self.ed):
                 self.activation_cache[in_tensor].append(
-                    data[in_tensor][:].reshape(*self.graph.get_tensor_shape(in_tensor)).copy())
+                    _reshape_input(data[in_tensor], self.graph.get_tensor_shape(in_tensor)).copy())
 
     def input_generator(self, tensor_name_list):
         # TODO batch generator.
@@ -95,7 +201,9 @@ def forward_subnet(self, subnet_name, input_list):
         sub_graph = self.graph_list[self.name_to_graph_id[subnet_name]]
         sub_net = sub_graph.model
         ort_inputs = {}
-        ort_session = ort.InferenceSession(sub_net.SerializeToString(), providers=self.providers)
+        sess_options = _get_ort_session_options(self.args)
+        ort_session = ort.InferenceSession(sub_net.SerializeToString(), sess_options=sess_options,
+                                           providers=self.providers)
         if 'CUDAExecutionProvider' not in ort_session.get_provider_options():
             logger.warning("CUDA may not used. Please check your ort/cuda/cudnn version.")
@@ -147,9 +255,7 @@ def _split_network(self):
                 continue
             if input not in self.graph.initializer:
                 in_type = self.graph.get_value_type(input)
-                shape = self.graph.get_tensor_shape(input)
-                if shape[0] == 0:
-                    shape = []
+                shape = _value_info_shape(self.graph.get_tensor_shape(input))
                 input_value = mtvi(input, in_type, shape)
                 inputs.append(input_value)
                 network_inputs.append(input)
@@ -160,9 +266,7 @@
             if output == '':
                 continue
             out_type = self.graph.get_value_type(output)
-            shape = self.graph.get_tensor_shape(output)
-            if shape[0] == 0:
-                shape = []
+            shape = _value_info_shape(self.graph.get_tensor_shape(output))
             output_value = mtvi(output, out_type, shape)
             outputs.append(output_value)
             network_outputs.append(output)
@@ -197,7 +301,8 @@ def forward_get_minmax(onnx_graph, args):
         if output_name not in [_o.name for _o in graph.output]:
             graph.output.insert(0, onnx.ValueInfoProto(name=output_name))
     providers = [("CUDAExecutionProvider", {'device_id': args.local_rank})]
-    ort_session = ort.InferenceSession(net.SerializeToString(), providers=providers)
+    sess_options = _get_ort_session_options(args)
+    ort_session = ort.InferenceSession(net.SerializeToString(), sess_options=sess_options, providers=providers)
     if 'CUDAExecutionProvider' not in ort_session.get_provider_options():
         logger.warning("CUDA may not used. Please check your ort/cuda/cudnn version.")
     # Start activation quantization.
@@ -207,10 +312,14 @@
     rank_num = args.data_num // args.world_size
     data_st_idx = args.rank * rank_num
     data_ed_idx = min((args.rank + 1) * rank_num, args.data_num)
-    for data in tqdm(input_data_generator(args.input_dir, onnx_graph.network_inputs, data_st_idx, data_ed_idx),
-                     desc='Minmax update'):
+    batch_size = _get_calib_batch(onnx_graph, args)
+    total = (data_ed_idx - data_st_idx + batch_size - 1) // batch_size
+    tensor_shapes = {name: onnx_graph.get_tensor_shape(name) for name in onnx_graph.network_inputs}
+    data_iter = input_data_batch_generator(args.input_dir, onnx_graph.network_inputs,
+                                           data_st_idx, data_ed_idx, batch_size, tensor_shapes)
+    for data in tqdm(data_iter, total=total, desc='Minmax update'):
         for name in onnx_graph.network_inputs:
-            ort_inputs[name] = data[name][:].reshape(onnx_graph.get_tensor_shape(name))
+            ort_inputs[name] = data[name]
         st = time.time()
         outputs = [output.name for output in ort_session.get_outputs()]
         ort_outputs = ort_session.run(outputs, ort_inputs)
@@ -245,7 +354,8 @@ def forward_get_hist(onnx_graph, stats_min_max, args):
         if output_name not in [_o.name for _o in graph.output]:
             graph.output.insert(0, onnx.ValueInfoProto(name=output_name))
     providers = [("CUDAExecutionProvider", {'device_id': args.local_rank})]
-    ort_session = ort.InferenceSession(net.SerializeToString(), providers=providers)
+    sess_options = _get_ort_session_options(args)
+    ort_session = ort.InferenceSession(net.SerializeToString(), sess_options=sess_options, providers=providers)
    if 'CUDAExecutionProvider' not in ort_session.get_provider_options():
         logger.warning("CUDA may not used. Please check your ort/cuda/cudnn version.")
     # Start activation quantization.
@@ -254,10 +364,14 @@ def forward_get_hist(onnx_graph, stats_min_max, args):
     rank_num = args.data_num // args.world_size
     data_st_idx = args.rank * rank_num
     data_ed_idx = min((args.rank + 1) * rank_num, args.data_num)
-    for data in tqdm(input_data_generator(args.input_dir, onnx_graph.network_inputs, data_st_idx, data_ed_idx),
-                     desc='Hist update: {}'.format(args.rank)):
+    batch_size = _get_calib_batch(onnx_graph, args)
+    total = (data_ed_idx - data_st_idx + batch_size - 1) // batch_size
+    tensor_shapes = {name: onnx_graph.get_tensor_shape(name) for name in onnx_graph.network_inputs}
+    data_iter = input_data_batch_generator(args.input_dir, onnx_graph.network_inputs,
+                                           data_st_idx, data_ed_idx, batch_size, tensor_shapes)
+    for data in tqdm(data_iter, total=total, desc='Hist update: {}'.format(args.rank)):
         for name in onnx_graph.network_inputs:
-            ort_inputs[name] = data[name][:].reshape(onnx_graph.get_tensor_shape(name))
+            ort_inputs[name] = data[name]
         outputs = [output.name for output in ort_session.get_outputs()]
         ort_outputs = ort_session.run(outputs, ort_inputs)
         ort_outs = OrderedDict(zip(outputs, ort_outputs))
@@ -290,7 +404,8 @@ def forward_net_octav(onnx_graph, args):
         if output_name not in [_o.name for _o in graph.output]:
             graph.output.insert(0, onnx.ValueInfoProto(name=output_name))
     providers = [("CUDAExecutionProvider", {'device_id': args.local_rank})]
-    ort_session = ort.InferenceSession(net.SerializeToString(), providers=providers)
+    sess_options = _get_ort_session_options(args)
+    ort_session = ort.InferenceSession(net.SerializeToString(), sess_options=sess_options, providers=providers)
     if 'CUDAExecutionProvider' not in ort_session.get_provider_options():
         logger.warning("CUDA may not used. Please check your ort/cuda/cudnn version.")
     # Start activation quantization.
@@ -456,14 +571,6 @@ def forward_net_octav_transformer(onnx_graph, args):
     return statistics
 
 
-def input_data_generator(input_dir, input_name_list, data_st_idx, data_ed_idx):
-    for idx in range(data_st_idx, data_ed_idx):
-        data = {}
-        for i in input_name_list:
-            data[i] = np.fromfile(f'{input_dir}/{i}/{idx}.bin', 'float32')
-        yield data
-
-
 def forward_get_tensor(graph, net, index, args):
     for node in graph.graph.node:
         if node.op_type in QUANT_NODE_NAME_LIST:
@@ -474,11 +581,12 @@
     rank = dist.get_rank()
     device = rank % torch.cuda.device_count()
     providers = [("CUDAExecutionProvider", {'device_id': device})]
-    ort_session = ort.InferenceSession(net.SerializeToString(), providers=providers)
+    sess_options = _get_ort_session_options(args)
+    ort_session = ort.InferenceSession(net.SerializeToString(), sess_options=sess_options, providers=providers)
     ort_inputs = {}
     for data in input_data_generator(args.input_dir, graph.network_inputs, index, index + 1):
         for name in graph.network_inputs:
-            ort_inputs[name] = data[name][:].reshape(graph.get_tensor_shape(name))
+            ort_inputs[name] = _reshape_input(data[name], graph.get_tensor_shape(name))
     outputs = [output.name for output in ort_session.get_outputs()]
     ort_outputs = ort_session.run(outputs, ort_inputs)
     ort_outs = OrderedDict(zip(outputs, ort_outputs))
diff --git a/dipoorlet/profiling.py b/dipoorlet/profiling.py
index b24b894..a7fb445 100644
--- a/dipoorlet/profiling.py
+++ b/dipoorlet/profiling.py
@@ -213,15 +213,20 @@
     ranges_all.update(weight_clip_val)
     for name, range in ranges_all.items():
         tensor_shape = graph.get_tensor_shape(name)
-        if isinstance(range[0], np.ndarray):
+        r0, r1 = range[0], range[1]
+        if isinstance(r0, (list, tuple)):
+            r0 = np.array(r0)
+        if isinstance(r1, (list, tuple)):
+            r1 = np.array(r1)
+        if isinstance(r0, np.ndarray) or isinstance(r1, np.ndarray):
             per_channel = ""
             if 'per_channel' in platform_setting_table[args.deploy]['qw_params'] and \
                     platform_setting_table[args.deploy]['qw_params']['per_channel']:
                 per_channel = "per channel "
             logger.info("{:<30} Shape: {:<20} Range: {}[{:<10f} {:<10f}]".format(name, str(tensor_shape),
-                        per_channel, range[0].min(), range[1].max()))
+                        per_channel, np.min(r0), np.max(r1)))
         else:
-            logger.info("{:<30} Shape: {:<20} Range: [{:<10f} {:<10f}]".format(name, str(tensor_shape), range[0], range[1]))
+            logger.info("{:<30} Shape: {:<20} Range: [{:<10f} {:<10f}]".format(name, str(tensor_shape), r0, r1))
 
 
 def weight_need_perchannel(graph, args):
diff --git a/dipoorlet/tensor_cali/basic_algorithm.py b/dipoorlet/tensor_cali/basic_algorithm.py
index 487de65..7ab52b2 100644
--- a/dipoorlet/tensor_cali/basic_algorithm.py
+++ b/dipoorlet/tensor_cali/basic_algorithm.py
@@ -44,7 +44,9 @@ def find_clip_val_hist(onnx_graph, args, store_stats=None, **kwargs):
     for i in range(len(hist)):
         accum += hist[i]
         if accum >= args.threshold:
-            clip_value = (i + 0.5) * (data_max / args.bins)
+            bins = int(args.bins)
+            dmax = float(data_max)
+            clip_value = np.float32((i + 0.5) * (dmax / bins))
             clip_val[name] = [max(-clip_value, np.min(stats_min_max[name]['min'])),
                               min(clip_value, np.max(stats_min_max[name]['max']))]
             break
diff --git a/dipoorlet/utils.py b/dipoorlet/utils.py
index 69bd8a4..3d5fdc2 100644
--- a/dipoorlet/utils.py
+++ b/dipoorlet/utils.py
@@ -311,12 +311,16 @@ def update_model_path(name, args):
 def save_clip_val(act_clip_val, weight_clip_val, args, act_fname='act_clip_val.json',
                   weight_fname='weight_clip_val.json'):
+    def _jsonify(val):
+        # Convert numpy values/arrays to JSON-serializable Python types.
+        return val.tolist() if hasattr(val, "tolist") else val
+
     for k, v in act_clip_val.items():
-        act_clip_val[k][0] = act_clip_val[k][0].tolist()
-        act_clip_val[k][1] = act_clip_val[k][1].tolist()
+        act_clip_val[k][0] = _jsonify(act_clip_val[k][0])
+        act_clip_val[k][1] = _jsonify(act_clip_val[k][1])
     for k, v in weight_clip_val.items():
-        weight_clip_val[k][0] = weight_clip_val[k][0].tolist()
-        weight_clip_val[k][1] = weight_clip_val[k][1].tolist()
+        weight_clip_val[k][0] = _jsonify(weight_clip_val[k][0])
+        weight_clip_val[k][1] = _jsonify(weight_clip_val[k][1])
     with open(os.path.join(args.output_dir, act_fname), 'w') as f:
         json.dump(act_clip_val, f, indent=4)
     with open(os.path.join(args.output_dir, weight_fname), 'w') as f:
@@ -345,16 +349,17 @@
         save_clip_val(act_clip_val, weight_clip_val, args)
 
 
-def load_clip_val(args, act_fname='act_clip_val.json', weight_fname='weight_clip_val.json'):
+def load_clip_val(args, act_fname='act_clip_val.json', weight_fname='weight_clip_val.json', base_dir=None):
+    base_dir = args.output_dir if base_dir is None else base_dir
     act_clip_val = {}
     weight_clip_val = {}
-    with open(os.path.join(args.output_dir, act_fname), 'r') as f:
+    with open(os.path.join(base_dir, act_fname), 'r') as f:
         act_clip_val = json.load(f)
     for k, v in act_clip_val.items():
         # We need scalar here.
         act_clip_val[k][0] = np.float64(act_clip_val[k][0])
         act_clip_val[k][1] = np.float64(act_clip_val[k][1])
-    with open(os.path.join(args.output_dir, weight_fname), 'r') as f:
+    with open(os.path.join(base_dir, weight_fname), 'r') as f:
         per_channel = False
         if 'per_channel' in platform_setting_table[args.deploy]['qw_params']:
             per_channel = platform_setting_table[args.deploy]['qw_params']['per_channel']
@@ -432,4 +437,4 @@
         None, args.skip_layers, op_types_to_quantize)
     quantizer.quantize_model()
     model_output = os.path.join(args.output_dir, 'qop_model.onnx')
-    quantizer.model.save_model_to_file(model_output)
\ No newline at end of file
+    quantizer.model.save_model_to_file(model_output)
diff --git a/dipoorlet/weight_transform/ada_quant_layer.py b/dipoorlet/weight_transform/ada_quant_layer.py
index 0ab161b..a6c64ea 100644
--- a/dipoorlet/weight_transform/ada_quant_layer.py
+++ b/dipoorlet/weight_transform/ada_quant_layer.py
@@ -174,7 +174,7 @@ def build_torch_conv(self, node, weight, bias):
         conv.weight.data = weight.data
         conv.weight.requires_grad = False
         if bias is not None:
-            conv.bias.data = torch.from_numpy(bias).cuda().data
+            conv.bias.data = torch.from_numpy(bias.copy()).cuda().data
             conv.bias.requires_grad = False
         return conv
 
@@ -186,7 +186,7 @@ def build_torch_linear(self, node, weight, bias):
         linear.weight.data = weight.data
         linear.weight.requires_grad = False
         if bias is not None:
-            linear.bias.data = torch.from_numpy(bias).cuda().data
+            linear.bias.data = torch.from_numpy(bias.copy()).cuda().data
             linear.bias.requires_grad = False
         return linear
 
@@ -208,7 +208,7 @@ def build_torch_deconv(self, node, weight, bias):
         deconv.weight.data = weight.data
         deconv.weight.requires_grad = False
         if bias is not None:
-            deconv.bias.data = torch.from_numpy(bias).cuda().data
+            deconv.bias.data = torch.from_numpy(bias.copy()).cuda().data
             deconv.bias.requires_grad = False
         return deconv
 
diff --git a/dipoorlet/weight_transform/adaround.py b/dipoorlet/weight_transform/adaround.py
index 8ccd471..1892760 100644
--- a/dipoorlet/weight_transform/adaround.py
+++ b/dipoorlet/weight_transform/adaround.py
@@ -54,7 +54,7 @@ def adaround(graph_ori, graph, act_clip_val, weight_clip_val, args):
         prev_act_cache = q_act_cache.activation_cache.copy()
 
         # Get weight and build torch conv.
-        weight = numpy_helper.to_array(graph_ada.initializer[node.input[1]][0])
+        weight = numpy_helper.to_array(graph_ada.initializer[node.input[1]][0]).copy()
         bias = None
         if len(node.input) == 3:
             bias = numpy_helper.to_array(graph_ada.initializer[node.input[2]][0])
diff --git a/dipoorlet/weight_transform/bias_correction.py b/dipoorlet/weight_transform/bias_correction.py
index 2fbcaff..d51f281 100644
--- a/dipoorlet/weight_transform/bias_correction.py
+++ b/dipoorlet/weight_transform/bias_correction.py
@@ -1,4 +1,7 @@
+import math
 import numpy as np
+import torch
+import torch.distributed as dist
 from onnx import numpy_helper
 
 from ..forward_net import ActivationCache
@@ -6,11 +9,44 @@
 from ..utils import ONNXGraph, logger
 
 
-def update_conv_node_bias(graph_bc, node, fp_activations, q_activations):
-    bias_diff = np.stack(fp_activations, axis=0) \
-        - np.stack(q_activations, axis=0)
-    axis = (0, 2, 3) if node.op_type == 'Conv' else (0)
-    bias_diff = np.squeeze(bias_diff, axis=1).mean(axis=axis)
+def _get_bias_shape(graph_bc, node):
+    if len(node.input) > 2 and node.input[2] in graph_bc.initializer:
+        return numpy_helper.to_array(graph_bc.initializer[node.input[2]][0]).shape
+    weight = numpy_helper.to_array(graph_bc.initializer[node.input[1]][0])
+    return (weight.shape[0],)
+
+
+def _reduce_bias_diff(fp_activations, q_activations, node, graph_bc, args):
+    axis = (0, 2, 3) if node.op_type == 'Conv' else (0,)
+    if len(fp_activations) == 0:
+        local_sum = np.zeros(_get_bias_shape(graph_bc, node), dtype=np.float64)
+        count = 0
+    else:
+        bias_diff = np.stack(fp_activations, axis=0) \
+            - np.stack(q_activations, axis=0)
+        bias_diff = np.squeeze(bias_diff, axis=1)
+        local_sum = bias_diff.sum(axis=axis).astype(np.float64)  # same dtype on every rank for all_reduce
+        count = 1
+        for ax in axis:
+            count *= bias_diff.shape[ax]
+
+    if dist.is_available() and dist.is_initialized():
+        device = torch.device("cuda", args.local_rank) if torch.cuda.is_available() else torch.device("cpu")
+        sum_tensor = torch.from_numpy(local_sum).to(device)
+        cnt_tensor = torch.tensor([count], device=device, dtype=torch.float64)
+        dist.all_reduce(sum_tensor, op=dist.ReduceOp.SUM)
+        dist.all_reduce(cnt_tensor, op=dist.ReduceOp.SUM)
+        total = cnt_tensor.item()
+        if total == 0:
+            return local_sum
+        return (sum_tensor / total).cpu().numpy()
+    if count == 0:
+        return local_sum
+    return local_sum / count
+
+
+def update_conv_node_bias(graph_bc, node, fp_activations, q_activations, args):
+    bias_diff = _reduce_bias_diff(fp_activations, q_activations, node, graph_bc, args)
     if len(node.input) > 2:
         ori_bias = numpy_helper.to_array(graph_bc.initializer[node.input[2]][0])
         corrected_bias = ori_bias + bias_diff
@@ -37,19 +73,25 @@ def bias_correction(graph, act_clip_val, weight_clip_val, args):
     clip_val.update(weight_clip_val)
     graph_bc = ONNXGraph()
     graph_bc.copy_from(graph)
-    fp_cache = ActivationCache(graph, args)
+    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+    rank_num = int(math.ceil(args.data_num / float(world_size)))
+    rank_st = rank * rank_num
+    rank_ed = min(rank_st + rank_num, args.data_num)
+    fp_cache = ActivationCache(graph, args, rank_st, rank_ed)
     prev_act = None
     for node in graph.graph.node:
         if node.op_type in bias_correction_node_type:
             logger.info("Update bias for node: {}".format(node.name))
             # We should do incremental update here.
             graph_q, _ = quant_graph(graph_bc, clip_val, args)
-            q_cache = ActivationCache(graph_q, args)
+            q_cache = ActivationCache(graph_q, args, rank_st, rank_ed)
             if prev_act is not None:
                 q_cache.activation_cache = prev_act
             _ = q_cache[node.input[0]]
             prev_act = q_cache.activation_cache.copy()
-            update_conv_node_bias(graph_bc, node, fp_cache[node.output[0]], q_cache[node.output[0]])
+            update_conv_node_bias(graph_bc, node, fp_cache[node.output[0]], q_cache[node.output[0]], args)
             graph_bc.update_model()
-    graph_bc.save_onnx_model('update_bias_model')
+    if rank == 0:
+        graph_bc.save_onnx_model('update_bias_model')
diff --git a/dipoorlet/weight_transform/brecq.py b/dipoorlet/weight_transform/brecq.py
index 24a563e..ebf1ea7 100644
--- a/dipoorlet/weight_transform/brecq.py
+++ b/dipoorlet/weight_transform/brecq.py
@@ -64,7 +64,7 @@ def brecq(graph_ori, graph, act_clip_val, weight_clip_val, args):
         ada_layer_list = []
         # Get weight and build torch conv.
         for _node in block_layer_list:
-            weight = numpy_helper.to_array(graph_brecq.initializer[_node.input[1]][0])
+            weight = numpy_helper.to_array(graph_brecq.initializer[_node.input[1]][0]).copy()
             weight = torch.from_numpy(weight).cuda()
             bias = None
             if len(_node.input) == 3:
@@ -127,7 +127,7 @@
             ada_block, reg, args.ada_bs, args.ada_epoch * len(block_layer_list), args.drop)
         # Deploy new weight.
         for idx, _node in enumerate(block_layer_list):
-            weight = numpy_helper.to_array(graph_brecq.initializer[_node.input[1]][0])
+            weight = numpy_helper.to_array(graph_brecq.initializer[_node.input[1]][0]).copy()
             weight = torch.from_numpy(weight).cuda()
             round_mask = round_mask_list[idx]
             if args.deploy != 'nnie':
diff --git a/dipoorlet/weight_transform/sparse_quant.py b/dipoorlet/weight_transform/sparse_quant.py
index cd219c9..51f7d39 100644
--- a/dipoorlet/weight_transform/sparse_quant.py
+++ b/dipoorlet/weight_transform/sparse_quant.py
@@ -55,7 +55,7 @@ def sparse_quant(graph_ori, graph, act_clip_val, weight_clip_val, args):
         prev_act_cache = q_act_cache.activation_cache.copy()
 
         # Get weight and build torch conv.
-        weight = numpy_helper.to_array(graph_ada.initializer[node.input[1]][0])
+        weight = numpy_helper.to_array(graph_ada.initializer[node.input[1]][0]).copy()
         bias = None
         if len(node.input) == 3:
             bias = numpy_helper.to_array(graph_ada.initializer[node.input[2]][0])
diff --git a/dipoorlet/weight_transform/sparse_quant_layer.py b/dipoorlet/weight_transform/sparse_quant_layer.py
index b682e99..fb0e68e 100644
--- a/dipoorlet/weight_transform/sparse_quant_layer.py
+++ b/dipoorlet/weight_transform/sparse_quant_layer.py
@@ -105,7 +105,7 @@ def build_torch_conv(self, node, weight, bias):
         conv.weight.data = weight.data
         conv.weight.requires_grad = True
         if bias is not None:
-            conv.bias.data = torch.from_numpy(bias).cuda().data
+            conv.bias.data = torch.from_numpy(bias.copy()).cuda().data
             conv.bias.requires_grad = True
         return conv
 
@@ -117,7 +117,7 @@ def build_torch_linear(self, node, weight, bias):
         linear.weight.data = weight.data
         linear.weight.requires_grad = True
         if bias is not None:
-            linear.bias.data = torch.from_numpy(bias).cuda().data
+            linear.bias.data = torch.from_numpy(bias.copy()).cuda().data
             linear.bias.requires_grad = True
         return linear
 
@@ -139,7 +139,7 @@ def build_torch_deconv(self, node, weight, bias):
         deconv.weight.data = weight.data
         deconv.weight.requires_grad = True
         if bias is not None:
-            deconv.bias.data = torch.from_numpy(bias).cuda().data
+            deconv.bias.data = torch.from_numpy(bias.copy()).cuda().data
             deconv.bias.requires_grad = True
         return deconv
 
@@ -173,4 +173,4 @@ def forward(self, x):
                                 self.layer.dilation)
         if self.relu_flag:
             x = F.relu(x)
-        return x
\ No newline at end of file
+        return x
diff --git a/dipoorlet/weight_transform/weight_trans_base.py b/dipoorlet/weight_transform/weight_trans_base.py
index fd3912b..ffae7dc 100644
--- a/dipoorlet/weight_transform/weight_trans_base.py
+++ b/dipoorlet/weight_transform/weight_trans_base.py
@@ -19,12 +19,11 @@ def weight_calibration(onnx_graph, act_clip_val, weight_clip_val, args):
     graph_after_wt = ONNXGraph()
     graph_after_wt.copy_from(onnx_graph)
     if args.bc:
-        if dist.get_rank() == 0:
-            bias_correction(graph_after_wt, act_clip_val, weight_clip_val, args)
+        bias_correction(graph_after_wt, act_clip_val, weight_clip_val, args)
         dist.barrier()
         update_model_path('update_bias_model', args)
         model = onnx.load(args.model)
-        graph_after_wt = ONNXGraph(model, args.output_dir)
+        graph_after_wt = ONNXGraph(model, args.output_dir, args.deploy, args.model_type)
         # Update bias range.
         weight_clip_val = find_clip_val_minmax_weight(graph_after_wt, args)
@@ -34,7 +33,7 @@
         dist.barrier()
         update_model_path('weight_equal_model', args)
         model = onnx.load(args.model)
-        graph_after_wt = ONNXGraph(model, args.output_dir)
+        graph_after_wt = ONNXGraph(model, args.output_dir, args.deploy, args.model_type)
         act_clip_val, weight_clip_val = tensor_calibration(graph_after_wt, args)
 
     if args.update_bn:
@@ -43,7 +42,7 @@
         dist.barrier()
         update_model_path('update_bn_model', args)
         model = onnx.load(args.model)
-        graph_after_wt = ONNXGraph(model, args.output_dir)
+        graph_after_wt = ONNXGraph(model, args.output_dir, args.deploy, args.model_type)
         if dist.get_rank() == 0:
             logger.info("Re calibration...")
         act_clip_val, weight_clip_val = tensor_calibration(graph_after_wt, args)
diff --git a/requirements.txt b/requirements.txt
index f05f51e..0960c57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,7 @@ torch
 onnx>=1.10.0
 onnxsim
 onnxruntime-gpu
-numpy
+numpy==1.26.4
 tqdm
+termcolor
+pyyaml