diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py
index cabeff733..dbd78b187 100644
--- a/modelopt/onnx/autocast/__main__.py
+++ b/modelopt/onnx/autocast/__main__.py
@@ -66,8 +66,12 @@ def get_parser() -> argparse.ArgumentParser:
         "--calibration_data",
         "-d",
         type=str,
-        help="File path to inputs for reference runner, either NPZ or Polygraphy JSON file. "
-        "If not provided, random inputs will be used",
+        help="File path to inputs for reference runner. Supports: "
+        "(1) NPZ file for single batch, "
+        "(2) Directory containing multiple NPZ files for multi-batch calibration, "
+        "(3) Polygraphy JSON file (supports multiple batches). "
+        "Multi-batch calibration aggregates statistics across all batches for more robust "
+        "precision conversion decisions. If not provided, random inputs will be used.",
     )
     parser.add_argument(
         "--nodes_to_exclude",
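A minimal sketch of how a calibration directory for the new multi-batch path could be prepared; the directory name "calib_data", the file names, and the input key "input" are made-up examples (the key must match the model's actual input name), not names introduced by this change:

import os

import numpy as np

# Write one NPZ file per calibration batch; AutoCast loads the files in sorted order.
os.makedirs("calib_data", exist_ok=True)
for i in range(4):
    batch = np.random.rand(1, 3, 224, 224).astype(np.float32)  # hypothetical input shape
    np.savez(os.path.join("calib_data", f"batch_{i:03d}.npz"), input=batch)

# The directory is then passed where a single NPZ file was accepted before, e.g.:
#   python -m modelopt.onnx.autocast <other arguments> --calibration_data calib_data
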
diff --git a/modelopt/onnx/autocast/nodeclassifier.py b/modelopt/onnx/autocast/nodeclassifier.py
index 0a7638429..cbfc00341 100644
--- a/modelopt/onnx/autocast/nodeclassifier.py
+++ b/modelopt/onnx/autocast/nodeclassifier.py
@@ -150,20 +150,51 @@ def _log_skipped(self, node, **kwargs):
 
 
 class IORangeRule(NodeRuleBase):
-    """Rule for keeping nodes with out-of-range inputs/outputs in high precision."""
+    """Rule for keeping nodes with out-of-range inputs/outputs in high precision.
+
+    Supports both single-batch (raw numpy arrays) and multi-batch (TensorStats objects)
+    reference data for flexible precision conversion decisions.
+    """
 
     def __init__(self, data_max, reference_data, node_to_init_map):
         """Initialize the rule.
 
         Args:
             data_max: Maximum absolute value allowed for node I/O.
-            reference_data: Reference data for checking I/O ranges.
+            reference_data: Reference data for checking I/O ranges. Can contain either
+                raw numpy arrays (single batch) or TensorStats objects (multi-batch aggregated).
             node_to_init_map: Mapping from node names to their initializers.
         """
         self.data_max = data_max
         self.reference_data = reference_data
         self.node_to_init_map = node_to_init_map
         self.output_data = None
+        self.output_stats = None  # For TensorStats
+
+    def _get_tensor_stats(self, ref_data):
+        """Extract statistics from reference data (supports both numpy arrays and TensorStats).
+
+        Args:
+            ref_data: Either a numpy array or a TensorStats object.
+
+        Returns:
+            tuple: (absmax, min_val, max_val, size) statistics.
+        """
+        # Import here to avoid circular imports
+        from modelopt.onnx.autocast.referencerunner import TensorStats
+
+        if isinstance(ref_data, TensorStats):
+            return ref_data.absmax, ref_data.min_val, ref_data.max_val, ref_data.size
+        else:
+            # Raw numpy array
+            if ref_data.size == 0:
+                return 0, 0, 0, 0
+            return (
+                np.max(np.abs(ref_data)),
+                np.min(ref_data),
+                np.max(ref_data),
+                ref_data.size,
+            )
 
     def _check_inner(self, node):
         def is_io_out_of_range(node, tensor_name):
@@ -176,18 +207,25 @@ def is_io_out_of_range(node, tensor_name):
                     f"Node {node.name}: Tensor {tensor_name} not found in reference data."
                 )
                 return False
+
             ref_data = self.reference_data[tensor_name]
-            if ref_data.size == 0:
+            absmax, min_val, max_val, size = self._get_tensor_stats(ref_data)
+
+            if size == 0:
                 logger.debug(
                     f"Node {node.name}: Tensor {tensor_name} has size 0. Skipping I/O range check."
                 )
                 return False
+
             logger.debug(
-                f"Node {node.name}: reference data: min={np.min(ref_data)}, max={np.max(ref_data)}"
+                f"Node {node.name}: reference data: min={min_val}, max={max_val}, absmax={absmax}"
             )
-            if np.any(np.abs(ref_data) > self.data_max):
+
+            if absmax > self.data_max:
                 self.output_data = ref_data
+                self.output_stats = (absmax, min_val, max_val)
                 return True
+
             return False
         if node.op_type == "Constant":
             return False
@@ -202,7 +240,13 @@ def is_io_out_of_range(node, tensor_name):
 
     def _log_skipped(self, node, **kwargs):
         """Log information about skipped nodes with I/O range violations."""
-        if self.output_data is not None:
+        if self.output_stats is not None:
+            absmax, min_val, max_val = self.output_stats
+            logger.info(
+                f"Skipping node {node.name}: reference IO out of range: min={min_val}, "
+                f"max={max_val}, absmax={absmax}, range=[{-self.data_max}, {self.data_max}]"
+            )
+        elif self.output_data is not None:
             logger.info(
                 f"Skipping node {node.name}: reference IO out of range: min={np.min(self.output_data)}, "
                 f"max={np.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]"
@@ -230,9 +274,18 @@ def __init__(self, max_depth_of_reduction, reference_data, node_to_init_map, ini
         self.reduction_depth = 0
 
     def _get_tensor_shape(self, tensor_name):
-        """Get tensor shape from reference data."""
+        """Get tensor shape from reference data.
+
+        Supports both raw numpy arrays and TensorStats objects.
+        """
         if tensor_name in self.reference_data:
-            return self.reference_data[tensor_name].shape
+            ref_data = self.reference_data[tensor_name]
+            # Import here to avoid circular imports
+            from modelopt.onnx.autocast.referencerunner import TensorStats
+
+            if isinstance(ref_data, TensorStats):
+                return ref_data.shape
+            return ref_data.shape
         if tensor_name in self.initializer_map:
             return self.initializer_map[tensor_name].dims
         return None
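For intuition, the check that the reworked IORangeRule now applies uniformly to both kinds of reference data, shown on invented values (65504 is FP16's largest finite value, used here only as an example data_max):

import numpy as np

DATA_MAX = 65504.0  # FP16 finite maximum, standing in for the rule's data_max

# Single-batch path: statistics are derived from the raw reference array.
raw = np.array([[-1.2e5, 3.0], [7.5, 0.1]], dtype=np.float32)
absmax = float(np.max(np.abs(raw))) if raw.size else 0.0
print(absmax > DATA_MAX)  # True -> the node is kept in high precision

# Multi-batch path: the aggregated absmax from TensorStats is compared directly.
aggregated_absmax = 9.8e4  # stand-in for TensorStats.absmax
print(aggregated_absmax > DATA_MAX)  # True as well
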
diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py
index 8dc91ff08..cd944b83c 100644
--- a/modelopt/onnx/autocast/referencerunner.py
+++ b/modelopt/onnx/autocast/referencerunner.py
@@ -19,12 +19,16 @@
 implementation. It supports both random input generation and user-provided inputs through
 NPZ or Polygraphy JSON files. The runner is used to analyze model behavior and validate
 outputs during precision conversion.
+
+When multiple batches of calibration data are provided, the runner aggregates statistics
+across all batches to provide more robust range information for precision conversion decisions.
 """
 
 import copy
 import io
 import sys
 from collections import OrderedDict
+from dataclasses import dataclass
 
 import numpy as np
 import onnx
@@ -35,6 +39,35 @@
 
 configure_logging()
 
 
+@dataclass
+class TensorStats:
+    """Statistics for a tensor aggregated across multiple batches.
+
+    Attributes:
+        absmax: Maximum absolute value across all batches.
+        min_val: Minimum value across all batches.
+        max_val: Maximum value across all batches.
+        shape: Shape of the tensor (from first batch).
+    """
+
+    absmax: float
+    min_val: float
+    max_val: float
+    shape: tuple
+
+    def __abs__(self):
+        """Return the maximum absolute value (for compatibility with np.abs)."""
+        return self.absmax
+
+    @property
+    def size(self):
+        """Return total number of elements."""
+        result = 1
+        for dim in self.shape:
+            result *= dim
+        return result
+
+
 class ReferenceRunner:
     """A class to run ONNX models with ONNXRuntime for reference inference."""
 
@@ -69,8 +102,29 @@ def _load_inputs_from_json(self, input_data_path):
         return load_json(input_data_path, description="input data")
 
     def _load_inputs_from_npz(self, input_data_path):
-        """Load inputs from NPZ format."""
-        return [np.load(input_data_path)]
+        """Load inputs from NPZ format.
+
+        Supports both a single NPZ file and a directory containing multiple NPZ files for multi-batch calibration.
+
+        Args:
+            input_data_path: Path to NPZ file or directory containing NPZ files.
+
+        Returns:
+            List of input dictionaries, one per batch.
+        """
+        import os
+
+        if os.path.isdir(input_data_path):
+            # Load all NPZ files in the directory as multiple batches
+            npz_files = sorted(
+                [f for f in os.listdir(input_data_path) if f.endswith(".npz")]
+            )
+            if not npz_files:
+                raise ValueError(f"No NPZ files found in directory: {input_data_path}")
+            logger.info(f"Loading {len(npz_files)} NPZ files from directory for multi-batch calibration")
+            return [np.load(os.path.join(input_data_path, f)) for f in npz_files]
+        else:
+            return [np.load(input_data_path)]
 
     def _validate_inputs(self, data_loader):
         """Validate that input names and shapes match the model."""
@@ -96,16 +150,18 @@ def _load_inputs(self, inputs):
         # If no inputs are provided, use random inputs
         data_loader = DataLoader(val_range={"": (-1, 1)})
 
+        import os
+
         if inputs is not None:
             if isinstance(inputs, str):
                 if inputs.endswith(".json"):
                     data_loader = self._load_inputs_from_json(inputs)
-                elif inputs.endswith(".npz"):
+                elif inputs.endswith(".npz") or os.path.isdir(inputs):
                     data_loader = self._load_inputs_from_npz(inputs)
                 else:
                     raise ValueError(
-                        f"Invalid input file: {inputs}. Supported input file types: .json (Polygraphy JSON format), "
-                        ".npz (Numpy)"
+                        f"Invalid input file: {inputs}. Supported input types: .json (Polygraphy JSON format), "
+                        ".npz (Numpy), or a directory containing .npz files"
                     )
             elif isinstance(inputs, (dict, OrderedDict)):
                 data_loader = [inputs]
@@ -118,8 +174,71 @@
 
         return data_loader
 
+    def _aggregate_tensor_stats(self, all_batch_data: list[OrderedDict]) -> OrderedDict:
+        """Aggregate tensor statistics across multiple batches.
+
+        Args:
+            all_batch_data: List of dictionaries containing tensor data for each batch.
+
+        Returns:
+            OrderedDict mapping tensor names to TensorStats objects.
+        """
+        if len(all_batch_data) == 1:
+            # Single batch - return raw data for backward compatibility
+            return all_batch_data[0]
+
+        logger.info(f"Aggregating statistics across {len(all_batch_data)} batches...")
+
+        aggregated = OrderedDict()
+        tensor_names = all_batch_data[0].keys()
+
+        for name in tensor_names:
+            absmax = -np.inf
+            min_val = np.inf
+            max_val = -np.inf
+            shape = None
+
+            for batch_data in all_batch_data:
+                if name not in batch_data:
+                    continue
+                data = batch_data[name]
+                if shape is None:
+                    shape = data.shape
+
+                batch_absmax = np.max(np.abs(data)) if data.size > 0 else 0
+                batch_min = np.min(data) if data.size > 0 else 0
+                batch_max = np.max(data) if data.size > 0 else 0
+
+                absmax = max(absmax, batch_absmax)
+                min_val = min(min_val, batch_min)
+                max_val = max(max_val, batch_max)
+
+            if shape is not None:
+                aggregated[name] = TensorStats(
+                    absmax=absmax,
+                    min_val=min_val,
+                    max_val=max_val,
+                    shape=shape,
+                )
+
+        return aggregated
+
     def run(self, inputs=None):
-        """Run FP32 inference with provided or random inputs."""
+        """Run FP32 inference with provided or random inputs.
+
+        When multiple batches of input data are provided, inference is run for each batch
+        and statistics are aggregated across all batches for more robust range estimation.
+
+        Args:
+            inputs: Optional input data. Can be:
+                - None: Random inputs will be generated
+                - str: Path to JSON file, NPZ file, or directory containing NPZ files
+                - dict/OrderedDict: Single batch of input data
+
+        Returns:
+            OrderedDict: Combined input and output data. For a single batch, returns raw arrays.
+                For multiple batches, returns TensorStats objects with aggregated statistics.
+        """
         import onnxruntime as ort
         from polygraphy import constants
         from polygraphy.backend.onnx import BytesFromOnnx
@@ -156,15 +275,30 @@
             logger.error(f"ONNXRuntime execution failed with output:\n{captured_output}")
             raise Exception("ONNXRuntime failed to run, see logs for details")
 
-        # Get the output results
-        output_dict = OrderedDict(results[0][1][0])
+        # Collect all batch data (inputs + outputs)
+        all_batch_data = []
+        runner_results = results[0][1]  # Get all iteration results for the first runner
+        data_loader_iter = iter(data_loader)
+
+        for iter_idx, iter_result in enumerate(runner_results):
+            output_dict = OrderedDict(iter_result)
+
+            # Get corresponding input data
+            try:
+                input_data = next(data_loader_iter)
+            except StopIteration:
+                # If data_loader is exhausted, it might be a DataLoader that generates random data
+                input_data = {}
 
-        # Include input data for completeness
-        input_data = next(iter(data_loader))
+            # Combine inputs and outputs for this batch
+            batch_dict = OrderedDict()
+            batch_dict.update(input_data)
+            batch_dict.update(output_dict)
+            all_batch_data.append(batch_dict)
 
-        # Combine inputs and outputs in the returned dictionary
-        combined_dict = OrderedDict()
-        combined_dict.update(input_data)
-        combined_dict.update(output_dict)
+        num_batches = len(all_batch_data)
+        if num_batches > 1:
+            logger.info(f"Processed {num_batches} batches of calibration data")
 
-        return combined_dict
+        # Aggregate statistics across all batches
+        return self._aggregate_tensor_stats(all_batch_data)
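To make the aggregation semantics concrete, a self-contained sketch of what _aggregate_tensor_stats produces for two toy batches (tensor name and values invented for illustration): per tensor, min_val and max_val are taken across all batches, absmax is the largest absolute value seen in any batch, and shape is recorded from the first batch.

from collections import OrderedDict

import numpy as np

# Two invented calibration batches for a tensor named "act"; in real use these
# dictionaries come from running the reference model once per batch.
batches = [
    OrderedDict(act=np.array([[-3.0, 2.0]], dtype=np.float32)),
    OrderedDict(act=np.array([[0.5, 7.0e4]], dtype=np.float32)),
]

# Reproduce the aggregation rules by hand.
values = [b["act"] for b in batches]
min_val = min(float(v.min()) for v in values)          # -3.0
max_val = max(float(v.max()) for v in values)          # 70000.0
absmax = max(float(np.abs(v).max()) for v in values)   # 70000.0, beyond the FP16 limit of 65504
shape = values[0].shape                                 # (1, 2), from the first batch

print(min_val, max_val, absmax, shape)

Downstream, IORangeRule only needs these scalars per tensor, which lets multi-batch runs return compact TensorStats objects rather than every batch's raw arrays.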