8 changes: 6 additions & 2 deletions modelopt/onnx/autocast/__main__.py
@@ -66,8 +66,12 @@ def get_parser() -> argparse.ArgumentParser:
"--calibration_data",
"-d",
type=str,
help="File path to inputs for reference runner, either NPZ or Polygraphy JSON file. "
"If not provided, random inputs will be used",
help="File path to inputs for reference runner. Supports: "
"(1) NPZ file for single batch, "
"(2) Directory containing multiple NPZ files for multi-batch calibration, "
"(3) Polygraphy JSON file (supports multiple batches). "
"Multi-batch calibration aggregates statistics across all batches for more robust "
"precision conversion decisions. If not provided, random inputs will be used.",
)
parser.add_argument(
"--nodes_to_exclude",
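For context on the new directory option, here is a minimal data-preparation sketch showing how a multi-batch calibration directory might look. The input name, shape, and file names are assumptions for illustration only and must match the actual model inputs; this is not part of the PR.

# Hypothetical sketch (not part of this PR): one NPZ file per calibration batch.
import os

import numpy as np

calib_dir = "calibration_batches"
os.makedirs(calib_dir, exist_ok=True)

for i in range(4):  # four calibration batches
    # "input" and the shape are placeholders; use the model's real input names/shapes.
    batch = {"input": np.random.randn(1, 3, 224, 224).astype(np.float32)}
    np.savez(os.path.join(calib_dir, f"batch_{i:03d}.npz"), **batch)

# The directory path can then be passed via --calibration_data / -d.
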
69 changes: 61 additions & 8 deletions modelopt/onnx/autocast/nodeclassifier.py
@@ -150,20 +150,51 @@ def _log_skipped(self, node, **kwargs):


class IORangeRule(NodeRuleBase):
"""Rule for keeping nodes with out-of-range inputs/outputs in high precision."""
"""Rule for keeping nodes with out-of-range inputs/outputs in high precision.

Supports both single-batch (raw numpy arrays) and multi-batch (TensorStats objects)
reference data for flexible precision conversion decisions.
"""

def __init__(self, data_max, reference_data, node_to_init_map):
"""Initialize the rule.

Args:
data_max: Maximum absolute value allowed for node I/O.
reference_data: Reference data for checking I/O ranges.
reference_data: Reference data for checking I/O ranges. Can contain either
raw numpy arrays (single batch) or TensorStats objects (multi-batch aggregated).
node_to_init_map: Mapping from node names to their initializers.
"""
self.data_max = data_max
self.reference_data = reference_data
self.node_to_init_map = node_to_init_map
self.output_data = None
self.output_stats = None # For TensorStats

def _get_tensor_stats(self, ref_data):
"""Extract statistics from reference data (supports both numpy arrays and TensorStats).

Args:
ref_data: Either a numpy array or a TensorStats object.

Returns:
tuple: (absmax, min_val, max_val, size) statistics.
"""
# Import here to avoid circular imports
from modelopt.onnx.autocast.referencerunner import TensorStats

if isinstance(ref_data, TensorStats):
return ref_data.absmax, ref_data.min_val, ref_data.max_val, ref_data.size
else:
# Raw numpy array
if ref_data.size == 0:
return 0, 0, 0, 0
return (
np.max(np.abs(ref_data)),
np.min(ref_data),
np.max(ref_data),
ref_data.size,
)

def _check_inner(self, node):
def is_io_out_of_range(node, tensor_name):
@@ -176,18 +207,25 @@ def is_io_out_of_range(node, tensor_name):
f"Node {node.name}: Tensor {tensor_name} not found in reference data."
)
return False

ref_data = self.reference_data[tensor_name]
if ref_data.size == 0:
absmax, min_val, max_val, size = self._get_tensor_stats(ref_data)

if size == 0:
logger.debug(
f"Node {node.name}: Tensor {tensor_name} has size 0. Skipping I/O range check."
)
return False

logger.debug(
f"Node {node.name}: reference data: min={np.min(ref_data)}, max={np.max(ref_data)}"
f"Node {node.name}: reference data: min={min_val}, max={max_val}, absmax={absmax}"
)
if np.any(np.abs(ref_data) > self.data_max):

if absmax > self.data_max:
self.output_data = ref_data
self.output_stats = (absmax, min_val, max_val)
return True
return False

if node.op_type == "Constant":
return False
@@ -202,7 +240,13 @@ def is_io_out_of_range(node, tensor_name):

def _log_skipped(self, node, **kwargs):
"""Log information about skipped nodes with I/O range violations."""
if self.output_data is not None:
if self.output_stats is not None:
absmax, min_val, max_val = self.output_stats
logger.info(
f"Skipping node {node.name}: reference IO out of range: min={min_val}, "
f"max={max_val}, absmax={absmax}, range=[{-self.data_max}, {self.data_max}]"
)
elif self.output_data is not None:
logger.info(
f"Skipping node {node.name}: reference IO out of range: min={np.min(self.output_data)}, "
f"max={np.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]"
@@ -230,9 +274,18 @@ def __init__(self, max_depth_of_reduction, reference_data, node_to_init_map, ini
self.reduction_depth = 0

def _get_tensor_shape(self, tensor_name):
"""Get tensor shape from reference data."""
"""Get tensor shape from reference data.

Supports both raw numpy arrays and TensorStats objects.
"""
if tensor_name in self.reference_data:
return self.reference_data[tensor_name].shape
ref_data = self.reference_data[tensor_name]
# Both raw numpy arrays and TensorStats objects expose a .shape attribute.
return ref_data.shape
if tensor_name in self.initializer_map:
return self.initializer_map[tensor_name].dims
return None
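For reference, a minimal sketch of how the dual-path statistics extraction behaves for the two reference-data types accepted by IORangeRule. The tensor values and the FP16-style threshold are illustrative assumptions; the TensorStats import assumes the dataclass added in referencerunner.py below.

# Illustrative only: the same absmax comparison applies whether the reference
# data is a raw per-batch array or an aggregated TensorStats object.
import numpy as np

from modelopt.onnx.autocast.referencerunner import TensorStats

raw = np.array([[-3.0, 2.0], [7.5, -0.5]])  # single-batch reference data
agg = TensorStats(absmax=7.5, min_val=-3.0, max_val=7.5, shape=(2, 2))  # aggregated stats

assert np.max(np.abs(raw)) == agg.absmax == 7.5

data_max = 65504.0  # assumed FP16-style threshold, for illustration
print(agg.absmax > data_max)  # False -> node remains eligible for low precision
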
164 changes: 149 additions & 15 deletions modelopt/onnx/autocast/referencerunner.py
@@ -19,12 +19,16 @@
implementation. It supports both random input generation and user-provided inputs through
NPZ or Polygraphy JSON files. The runner is used to analyze model behavior and validate
outputs during precision conversion.

When multiple batches of calibration data are provided, the runner aggregates statistics
across all batches to provide more robust range information for precision conversion decisions.
"""

import copy
import io
import sys
from collections import OrderedDict
from dataclasses import dataclass

import numpy as np
import onnx
@@ -35,6 +39,35 @@
configure_logging()


@dataclass
class TensorStats:
"""Statistics for a tensor aggregated across multiple batches.

Attributes:
absmax: Maximum absolute value across all batches.
min_val: Minimum value across all batches.
max_val: Maximum value across all batches.
shape: Shape of the tensor (from first batch).
"""

absmax: float
min_val: float
max_val: float
shape: tuple

def __abs__(self):
"""Return the maximum absolute value (for compatibility with np.abs)."""
return self.absmax

@property
def size(self):
"""Return total number of elements."""
result = 1
for dim in self.shape:
result *= dim
return result


class ReferenceRunner:
"""A class to run ONNX models with ONNXRuntime for reference inference."""

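A quick illustration of the dataclass conveniences above; the values are made up for the example.

from modelopt.onnx.autocast.referencerunner import TensorStats

stats = TensorStats(absmax=4.2, min_val=-4.2, max_val=1.3, shape=(2, 3, 8))
print(abs(stats))  # 4.2 -- __abs__ returns the aggregated absmax
print(stats.size)  # 48  -- product of the shape dimensions, like ndarray.size
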
@@ -69,8 +102,29 @@ def _load_inputs_from_json(self, input_data_path):
return load_json(input_data_path, description="input data")

def _load_inputs_from_npz(self, input_data_path):
"""Load inputs from NPZ format."""
return [np.load(input_data_path)]
"""Load inputs from NPZ format.

Supports both a single NPZ file and a directory of NPZ files for multi-batch calibration.

Args:
input_data_path: Path to NPZ file or directory containing NPZ files.

Returns:
List of input dictionaries, one per batch.
"""
import os

if os.path.isdir(input_data_path):
# Load all NPZ files in the directory as multiple batches
npz_files = sorted(
[f for f in os.listdir(input_data_path) if f.endswith(".npz")]
)
if not npz_files:
raise ValueError(f"No NPZ files found in directory: {input_data_path}")
logger.info(f"Loading {len(npz_files)} NPZ files from directory for multi-batch calibration")
return [np.load(os.path.join(input_data_path, f)) for f in npz_files]
else:
return [np.load(input_data_path)]

def _validate_inputs(self, data_loader):
"""Validate that input names and shapes match the model."""
@@ -96,16 +150,18 @@ def _load_inputs(self, inputs):
# If no inputs are provided, use random inputs
data_loader = DataLoader(val_range={"": (-1, 1)})

import os

if inputs is not None:
if isinstance(inputs, str):
if inputs.endswith(".json"):
data_loader = self._load_inputs_from_json(inputs)
elif inputs.endswith(".npz"):
elif inputs.endswith(".npz") or os.path.isdir(inputs):
data_loader = self._load_inputs_from_npz(inputs)
else:
raise ValueError(
f"Invalid input file: {inputs}. Supported input file types: .json (Polygraphy JSON format), "
".npz (Numpy)"
f"Invalid input file: {inputs}. Supported input types: .json (Polygraphy JSON format), "
".npz (Numpy), or a directory containing .npz files"
)
elif isinstance(inputs, (dict, OrderedDict)):
data_loader = [inputs]
@@ -118,8 +174,71 @@ def _load_inputs(self, inputs):

return data_loader

def _aggregate_tensor_stats(self, all_batch_data: list[OrderedDict]) -> OrderedDict:
"""Aggregate tensor statistics across multiple batches.

Args:
all_batch_data: List of dictionaries containing tensor data for each batch.

Returns:
OrderedDict mapping tensor names to TensorStats objects.
"""
if len(all_batch_data) == 1:
# Single batch - return raw data for backward compatibility
return all_batch_data[0]

logger.info(f"Aggregating statistics across {len(all_batch_data)} batches...")

aggregated = OrderedDict()
tensor_names = all_batch_data[0].keys()

for name in tensor_names:
absmax = -np.inf
min_val = np.inf
max_val = -np.inf
shape = None

for batch_data in all_batch_data:
if name not in batch_data:
continue
data = batch_data[name]
if shape is None:
shape = data.shape

batch_absmax = np.max(np.abs(data)) if data.size > 0 else 0
batch_min = np.min(data) if data.size > 0 else 0
batch_max = np.max(data) if data.size > 0 else 0

absmax = max(absmax, batch_absmax)
min_val = min(min_val, batch_min)
max_val = max(max_val, batch_max)

if shape is not None:
aggregated[name] = TensorStats(
absmax=absmax,
min_val=min_val,
max_val=max_val,
shape=shape,
)

return aggregated

def run(self, inputs=None):
"""Run FP32 inference with provided or random inputs."""
"""Run FP32 inference with provided or random inputs.

When multiple batches of input data are provided, inference is run for each batch
and statistics are aggregated across all batches for more robust range estimation.

Args:
inputs: Optional input data. Can be:
- None: Random inputs will be generated
- str: Path to JSON file, NPZ file, or directory containing NPZ files
- dict/OrderedDict: Single batch of input data

Returns:
OrderedDict: Combined input and output data. For single batch, returns raw arrays.
For multiple batches, returns TensorStats objects with aggregated statistics.
"""
import onnxruntime as ort
from polygraphy import constants
from polygraphy.backend.onnx import BytesFromOnnx
@@ -156,15 +275,30 @@ def run(self, inputs=None):
logger.error(f"ONNXRuntime execution failed with output:\n{captured_output}")
raise Exception("ONNXRuntime failed to run, see logs for details")

# Get the output results
output_dict = OrderedDict(results[0][1][0])
# Collect all batch data (inputs + outputs)
all_batch_data = []
runner_results = results[0][1] # Get all iteration results for the first runner
data_loader_iter = iter(data_loader)

for iter_idx, iter_result in enumerate(runner_results):
output_dict = OrderedDict(iter_result)

# Get corresponding input data
try:
input_data = next(data_loader_iter)
except StopIteration:
# If data_loader is exhausted, it might be a DataLoader that generates random data
input_data = {}

# Include input data for completeness
input_data = next(iter(data_loader))
# Combine inputs and outputs for this batch
batch_dict = OrderedDict()
batch_dict.update(input_data)
batch_dict.update(output_dict)
all_batch_data.append(batch_dict)

# Combine inputs and outputs in the returned dictionary
combined_dict = OrderedDict()
combined_dict.update(input_data)
combined_dict.update(output_dict)
num_batches = len(all_batch_data)
if num_batches > 1:
logger.info(f"Processed {num_batches} batches of calibration data")

return combined_dict
# Aggregate statistics across all batches
return self._aggregate_tensor_stats(all_batch_data)
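To make the return-type contract of run() concrete, a small worked example of the aggregation semantics; the tensor name and values are illustrative, and the arithmetic mirrors _aggregate_tensor_stats rather than calling it (constructing a ReferenceRunner requires a model).

import numpy as np

# Two illustrative batches of reference data for one tensor.
batch_0 = {"logits": np.array([-3.0, 2.0])}
batch_1 = {"logits": np.array([1.0, 5.0])}

absmax = max(np.max(np.abs(batch_0["logits"])), np.max(np.abs(batch_1["logits"])))  # 5.0
min_val = min(np.min(batch_0["logits"]), np.min(batch_1["logits"]))  # -3.0
max_val = max(np.max(batch_0["logits"]), np.max(batch_1["logits"]))  # 5.0

# With a single batch, run() returns the raw array for "logits"; with multiple
# batches it returns TensorStats(absmax=5.0, min_val=-3.0, max_val=5.0, shape=(2,)).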