From 2a26ca2651fc57e6675072d2b58dc4fca056dede Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Mon, 22 Dec 2025 15:38:13 +0200 Subject: [PATCH 01/11] Draft: AutoCast local implementation for type inference Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- docs/source/guides/8_autocast.rst | 18 + modelopt/onnx/autocast/__main__.py | 11 + modelopt/onnx/autocast/convert.py | 20 +- modelopt/onnx/autocast/precisionconverter.py | 63 ++- modelopt/onnx/utils.py | 371 ++++++++++++++++++ .../onnx/autocast/test_precisionconverter.py | 261 ++++++++---- 6 files changed, 658 insertions(+), 86 deletions(-) diff --git a/docs/source/guides/8_autocast.rst b/docs/source/guides/8_autocast.rst index 4ad39e969..75f94b57d 100644 --- a/docs/source/guides/8_autocast.rst +++ b/docs/source/guides/8_autocast.rst @@ -42,6 +42,7 @@ AutoCast can also be used programmatically through its Python API: trt_plugins=[], # list of TensorRT plugin library paths in .so format max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16) + use_local_type_inference=False, # use local type inference instead of ONNX's infer_shapes (WAR) ) # Save the converted model @@ -82,6 +83,9 @@ AutoCast follows these steps to convert a model: - Converts eligible nodes to lower precision - Automatically inserts necessary cast operations - Automatically replaces initializers with lower precision values + - Performs type inference to propagate types through the graph + - By default, uses ONNX's ``infer_shapes`` which performs both shape and type inference using the ONNX infer_shapes API. + - Use ``use_local_type_inference=True`` to use a local type-only inference implementation (experimental). #. **Validation and Export**: @@ -145,6 +149,14 @@ Best Practices - A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues. - The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19). +#. **Type Inference Control** + + - By default, AutoCast uses ONNX's ``infer_shapes`` which performs both shape and type inference. + - Use ``--use_local_type_inference`` to enable a local type-only inference implementation. + - This is a workaround for cases where shape inference fails for any reason, which allows us to bypass the dependency in ONNX's shape inference logic. + - The local implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape. + - Note: The local type inference may be less robust than ONNX's implementation for edge cases, but avoids unnecessary shape inference overhead and possible failures. + Limitations and Restrictions ---------------------------- - AutoCast does not yet support quantized models. @@ -198,3 +210,9 @@ Convert to BF16 with a specific opset: .. code-block:: bash python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22 + +Use local type inference instead of ONNX's infer_shapes: + +.. 
code-block:: bash + + python -m modelopt.onnx.autocast --onnx_path model.onnx --use_local_type_inference diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py index cabeff733..d2dd74566 100644 --- a/modelopt/onnx/autocast/__main__.py +++ b/modelopt/onnx/autocast/__main__.py @@ -185,6 +185,16 @@ def get_parser() -> argparse.ArgumentParser: "higher version." ), ) + parser.add_argument( + "--use_local_type_inference", + action="store_true", + default=False, + help=( + "Use local type inference implementation instead of ONNX's infer_shapes (experimental)." + "This is a workaround for cases where shape inference fails for any reason." + "Default: False (uses ONNX's infer_shapes which does both shape and type inference)." + ), + ) return parser @@ -218,6 +228,7 @@ def main(argv=None): trt_plugins_precision=args.trt_plugins_precision, max_depth_of_reduction=args.max_depth_of_reduction, opset=args.opset, + use_local_type_inference=args.use_local_type_inference, ) output_path = args.output_path diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py index 4328c9fc2..0e8775341 100644 --- a/modelopt/onnx/autocast/convert.py +++ b/modelopt/onnx/autocast/convert.py @@ -61,6 +61,7 @@ def convert_to_mixed_precision( trt_plugins_precision: list[str] = [], max_depth_of_reduction: int | None = None, opset: int | None = None, + use_local_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision. @@ -85,6 +86,9 @@ def convert_to_mixed_precision( opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations require a higher version. + use_local_type_inference: If True, use local type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. Returns: onnx.ModelProto: The converted mixed precision model. @@ -132,7 +136,10 @@ def convert_to_mixed_precision( model = graph_sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + if use_local_type_inference: + model = onnx_utils.infer_types(model) + else: + model = onnx_utils.infer_shapes(model) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) # Automatically add 'trt' to list of providers if custom ops are detected @@ -164,6 +171,7 @@ def convert_to_mixed_precision( low_precision_type=low_precision_type, init_conversion_max_bytes=init_conversion_max_bytes, custom_ops=graph_sanitizer.custom_ops, + use_local_type_inference=use_local_type_inference, ) # Obtain reference data @@ -196,6 +204,7 @@ def convert_to_f16( op_block_list: list[str] = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, trt_plugins: list[str] | None = [], + use_local_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision, using PrecisionConverter. @@ -208,6 +217,9 @@ def convert_to_f16( op_block_list: List of operation types that should remain in FP32. tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library). + use_local_type_inference: If True, use local type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. 
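+
+    Example (illustrative sketch only: ``low_precision_type`` and ``use_local_type_inference``
+    come from this signature, but passing the model as the first positional argument and the
+    file paths are assumptions for illustration)::
+
+        import onnx
+
+        from modelopt.onnx.autocast.convert import convert_to_f16
+
+        model = onnx.load("model.onnx")  # placeholder path to an FP32 ONNX model
+        converted = convert_to_f16(
+            model,
+            low_precision_type="fp16",
+            use_local_type_inference=True,  # infer types locally instead of ONNX's infer_shapes
+        )
+        onnx.save(converted, "model_fp16.onnx")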
""" assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16" @@ -225,7 +237,10 @@ def convert_to_f16( model = sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + if use_local_type_inference: + model = onnx_utils.infer_types(model) + else: + model = onnx_utils.infer_shapes(model) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) precision_converter = PrecisionConverter( @@ -237,6 +252,7 @@ def convert_to_f16( low_precision_type=low_precision_type, custom_ops=sanitizer.custom_ops, tensor_block_dict=tensor_block_dict, + use_local_type_inference=use_local_type_inference, ) high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list] low_precision_nodes = [ diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index eae589d8a..f351c4fdb 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -97,6 +97,7 @@ def __init__( max_ir_version: int | None = None, trt_plugins: list[str] | None = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, + use_local_type_inference: bool = False, ) -> None: """Initialize PrecisionConverter. @@ -140,6 +141,7 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins + self.use_local_type_inference = use_local_type_inference # Detect additional ops not supported in low precision according to the model's opset version self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( @@ -157,6 +159,20 @@ def __init__( self._warned_values_clamp_max = False self._warned_values_clamp_min = False + def infer_types(self, **kwargs): + """Infers types (and optionally shapes) based on the use_local_type_inference flag. + + Args: + **kwargs: Additional arguments passed to infer_shapes when not using local type inference. + + Returns: + onnx.ModelProto: Model with inferred types (and shapes if not using local type inference). + """ + if self.use_local_type_inference: + return onnx_utils.infer_types(self.model) + else: + return onnx_utils.infer_shapes(self.model, **kwargs) + def convert( self, high_precision_nodes: list[str], @@ -251,13 +267,39 @@ def convert( # Populate type information with inferred types self.model = self._propagate_types_shapes_custom_ops(self.model) else: + # Preserve original output types (before clearing) + # Store in instance variable so we can restore after cleanup + # Always preserve non-float types, and preserve all types if keep_io_types is True + self._original_output_types = {} + for out in self.model.graph.output: + if out.type.HasField("tensor_type"): + original_type = out.type.tensor_type.elem_type + # Always preserve non-float types (INT64, INT32, etc.) 
+ # Also preserve all types if keep_io_types is True + if original_type not in ONNX_TYPES or self.keep_io_types: + self._original_output_types[out.name] = original_type + # Clear type/shape information for intermediates and outputs (including subgraphs) self._clear_types_and_shapes_recursive(self.model.graph) + # Populate type information with inferred types - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False) + self.model = self.infer_types(strict_mode=True, check_type=False) self._ensure_types_are_defined() - # Sanity check: Verify type correctness - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + + # Restore original output types (non-float types and all types if keep_io_types is True) + if hasattr(self, "_original_output_types"): + for out in self.model.graph.output: + if out.name in self._original_output_types: + out.type.tensor_type.elem_type = self._original_output_types[out.name] + + # Sanity check: Verify type correctness (only when using ONNX's infer_shapes) + if not self.use_local_type_inference: + self.model = self.infer_types(strict_mode=True, check_type=True) + # Restore output types again after second inference + if hasattr(self, "_original_output_types"): + for out in self.model.graph.output: + if out.name in self._original_output_types: + out.type.tensor_type.elem_type = self._original_output_types[out.name] # Update value_info_map and initializer_map with casts we added self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( @@ -267,6 +309,14 @@ def convert( # Remove redundant casts self._cleanup() + # Restore original output types after cleanup + # (cleanup may have modified outputs, so we need to restore types again) + # Always restore non-float types, and all types if keep_io_types is True + if hasattr(self, "_original_output_types"): + for out in self.model.graph.output: + if out.name in self._original_output_types: + out.type.tensor_type.elem_type = self._original_output_types[out.name] + self._sanity_check() return self.model @@ -1175,8 +1225,9 @@ def _remove_redundant_casts(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True) - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = self.infer_types(strict_mode=True) + if not self.use_local_type_inference: + self.model = self.infer_types(strict_mode=True, check_type=True) nodes_to_remove = [] for node in self.model.graph.node: @@ -1261,7 +1312,7 @@ def _fix_network_output_names(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = self.infer_types(strict_mode=True, check_type=True) self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( self.model ) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index a6b37758e..ef08bd5c0 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -728,6 +728,377 @@ def get_attribute(node: onnx.NodeProto, attr_name: str) -> Any: raise ValueError(f"Attribute {attr_name} not found in node {node.name}") +def infer_types(model: onnx.ModelProto) -> onnx.ModelProto: + """Infers types (but not shapes) of the onnx graph using local implementation. + + This is a workaround for cases when ONNX's shape inference fails. 
+ ONNX's infer_shapes performs both shape and type inference together, but for AutoCast, we only + need type inference. + + Args: + model: ONNX model to infer types for. + + Returns: + onnx.ModelProto: Model with inferred types updated in value_info and outputs. + """ + from modelopt.onnx.autocast import utils as autocast_utils + + # Get opset version + opset = get_opset_version(model) + + # Process each graph (main graph and all subgraphs) recursively + def infer_types_for_graph( + graph: onnx.GraphProto, parent_node: onnx.NodeProto = None, is_subgraph: bool = False + ) -> None: + """Infer types for a single graph (main or subgraph). + + Args: + graph: The graph to infer types for. + parent_node: The parent node containing this subgraph (None for main graph). + is_subgraph: Whether this is a subgraph (True) or the main graph (False). + """ + # Use graphsurgeon to topologically sort nodes for efficient single-pass traversal + # Create a temporary model with just this graph for graphsurgeon + temp_model = onnx.ModelProto() + temp_model.graph.CopyFrom(graph) + temp_model.opset_import.add().version = opset + temp_model.ir_version = model.ir_version + + try: + gs_graph = gs.import_onnx(temp_model) + gs_graph.toposort() + except Exception as e: + logger.debug( + f"Graphsurgeon toposort failed for {'subgraph' if is_subgraph else 'main graph'}," + f"using original order: {e!s}" + ) + # Fallback: process nodes in original order + gs_graph = None + + # Create mappings for quick lookup for this graph + initializer_map = {init.name: init for init in graph.initializer} + value_info_map = {vi.name: vi for vi in graph.value_info} + output_names = {out.name for out in graph.output} + + # Map tensor names to their inferred types (scoped to this graph) + tensor_types = {} + + # Initialize types from inputs and initializers + for inp in graph.input: + if inp.type.HasField("tensor_type"): + tensor_types[inp.name] = inp.type.tensor_type.elem_type + + for init_name, init in initializer_map.items(): + tensor_types[init_name] = init.data_type + + # Helper function to get tensor type + def get_tensor_type(tensor_name: str) -> int | None: + if tensor_name in tensor_types: + return tensor_types[tensor_name] + if tensor_name in value_info_map: + vi = value_info_map[tensor_name] + if vi.type.HasField("tensor_type"): + return vi.type.tensor_type.elem_type + return None + + def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: + """Converts a string representation of a tensor dtype to an onnx.TensorProto.DataType.""" + _str_to_tensor_dtype = { + "float": onnx.TensorProto.FLOAT, + "uint8": onnx.TensorProto.UINT8, + "int8": onnx.TensorProto.INT8, + "uint16": onnx.TensorProto.UINT16, + "int16": onnx.TensorProto.INT16, + "int32": onnx.TensorProto.INT32, + "int64": onnx.TensorProto.INT64, + "string": onnx.TensorProto.STRING, + "bool": onnx.TensorProto.BOOL, + "float16": onnx.TensorProto.FLOAT16, + "double": onnx.TensorProto.DOUBLE, + "uint32": onnx.TensorProto.UINT32, + "uint64": onnx.TensorProto.UINT64, + "complex64": onnx.TensorProto.COMPLEX64, + "complex128": onnx.TensorProto.COMPLEX128, + "bfloat16": onnx.TensorProto.BFLOAT16, + "float8e4m3fn": onnx.TensorProto.FLOAT8E4M3FN, + "float8e4m3fnuz": onnx.TensorProto.FLOAT8E4M3FNUZ, + "float8e5m2": onnx.TensorProto.FLOAT8E5M2, + "float8e5m2fnuz": onnx.TensorProto.FLOAT8E5M2FNUZ, + "uint4": onnx.TensorProto.UINT4, + "int4": onnx.TensorProto.INT4, + "float4e2m1": onnx.TensorProto.FLOAT4E2M1, + "float8e8m0": onnx.TensorProto.FLOAT8E8M0, + } + try: + str_sanitized 
= dtype_str.replace("tensor(", "").replace(")", "") + return _str_to_tensor_dtype[str_sanitized] + except KeyError: + raise ValueError(f"Invalid tensor dtype string: {str_sanitized}") + + # Create mapping from node name to ONNX node for efficient lookup + node_name_to_onnx = {node.name: node for node in graph.node} + + # Get nodes to process (from graphsurgeon if available, otherwise from graph directly) + if gs_graph is not None: + nodes_to_process = gs_graph.nodes + else: + nodes_to_process = graph.node + + # Process nodes in topological order (single pass) + for gs_node_or_onnx_node in nodes_to_process: + # Get corresponding ONNX node + if gs_graph is not None: + # From graphsurgeon + node = node_name_to_onnx.get(gs_node_or_onnx_node.name) + else: + # Direct from graph + node = gs_node_or_onnx_node + + if node is None: + if gs_graph is not None: + logger.debug( + f"Could not find ONNX node for graphsurgeon node: {gs_node_or_onnx_node.name}" + ) + continue + + # Get input types for this node + input_types = [] + for inp_name in node.input: + if not inp_name: + continue + inp_type = get_tensor_type(inp_name) + if inp_type is not None: + input_types.append(inp_type) + else: + # In topologically sorted order, this shouldn't happen unless + # the input is from an initializer/input we missed or there's a cycle + logger.debug(f"Warning: Input {inp_name} of node {node.name} has unknown type") + # Use FLOAT as fallback + input_types.append(onnx.TensorProto.FLOAT) + + # Infer output types for this node + output_types = [] + + if node.op_type == "Cast": + # Cast node: output type is the 'to' attribute + cast_to_type = None + for attr in node.attribute: + if attr.name == "to": + cast_to_type = attr.i + break + if cast_to_type is not None: + output_types = [cast_to_type] * len(node.output) + else: + # Fallback: use input type if cast target unknown + output_types = ( + input_types[: len(node.output)] + if input_types + else [onnx.TensorProto.FLOAT] * len(node.output) + ) + elif node.op_type == "DequantizeLinear": + # DequantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the scale type (input[1]) + # inputs: [data, scale, zero_point (optional)] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] * len(node.output) + elif len(node.input) >= 2 and node.input[1]: + scale_type = get_tensor_type(node.input[1]) + if scale_type is not None: + output_types = [scale_type] * len(node.output) + else: + # Fallback: use first input type or FLOAT + output_types = [ + input_types[0] if input_types else onnx.TensorProto.FLOAT + ] * len(node.output) + else: + output_types = [ + input_types[0] if input_types else onnx.TensorProto.FLOAT + ] * len(node.output) + elif node.op_type == "QuantizeLinear": + # QuantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the zero_point type (input[2]) + # inputs: [data, scale, zero_point] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] * len(node.output) + elif len(node.input) >= 3 and node.input[2]: + zero_point_type = get_tensor_type(node.input[2]) + if zero_point_type is not None: + output_types = [zero_point_type] * len(node.output) + else: + # Fallback: typically UINT8 or INT8 for quantized types + output_types = 
[onnx.TensorProto.UINT8] * len(node.output) + else: + # Fallback: typically UINT8 or INT8 for quantized types + output_types = [onnx.TensorProto.UINT8] * len(node.output) + elif node.op_type == "Constant": + # Constant: output type is from the value attribute's tensor data_type + const_type = None + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + if const_type is not None: + output_types = [const_type] * len(node.output) + else: + # Fallback: use FLOAT if type cannot be determined + output_types = [onnx.TensorProto.FLOAT] * len(node.output) + elif node.op_type == "ConstantOfShape": + # ConstantOfShape: output type is from the value attribute's tensor data_type + # If no value attribute, defaults to FLOAT + # Note: Schema allows multiple types, so we need to check the value attribute + const_type = onnx.TensorProto.FLOAT # default + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + output_types = [const_type] * len(node.output) + elif node.op_type == "Split": + # Split schema allows multiple outputs, but the schema only specifies one output type + output_types = [input_types[0]] * len(node.output) + else: + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + # Use ONNX operator schema to determine output types + try: + schema = onnx.defs.get_schema(node.op_type, opset, domain=node.domain or "") + assert schema.outputs and len(schema.outputs) >= len(node.output) + except Exception as e: + # Fallback: if schema lookup fails, propagate first input type + logger.debug( + f"Node {node.name}: Failed to get schema for {node.op_type}: {e}, propagate first input type" + ) + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + output_types = [default_type] * len(node.output) + + # Try to infer from schema + input_schemas = [schema.inputs[i].type_str for i in range(len(schema.inputs))] + output_schemas = [schema.outputs[i].type_str for i in range(len(schema.outputs))] + output_types = [None] * len(node.output) + + for output_idx in range(len(node.output)): + # explicit type is set in schema, use it + if "tensor" in output_schemas[output_idx]: + found_type = str_to_tensor_dtype(output_schemas[output_idx]) + output_types[output_idx] = found_type + continue + # sometimes output type is set with a placeholder name despite supporting a single type + # e.g. 
Shape operator is constrained to int64, but the type_str is "T1" + for constraint in schema.type_constraints: + # If output type constraint has only one allowed type, use it directly + if constraint.type_param_str == output_schemas[output_idx]: + if len(constraint.allowed_type_strs) == 1: + found_type = str_to_tensor_dtype(constraint.allowed_type_strs[0]) + output_types[output_idx] = found_type + break + else: # we have a placeholder name "T", "T1", "T2", etc that should match one of the input types + try: + input_match_idx = input_schemas.index(output_schemas[output_idx]) + except ValueError: + input_match_idx = None + if input_match_idx is not None: + found_type = input_types[input_match_idx] + else: + found_type = default_type + logger.debug( + f"Node {node.name}: Failed to infer type for output #{output_idx}, propagate first " + "input type" + ) + output_types[output_idx] = found_type + + # Update output tensor types + for out_idx, out_name in enumerate(node.output): + if not out_name or out_idx >= len(output_types): + continue + + output_type = output_types[out_idx] + tensor_types[out_name] = output_type + + # Update value_info if it exists + if out_name in value_info_map: + value_info_map[out_name].type.tensor_type.elem_type = output_type + elif out_name not in output_names: + # Create new value_info for intermediate tensor + new_vi = graph.value_info.add() + new_vi.name = out_name + new_vi.type.tensor_type.elem_type = output_type + value_info_map[out_name] = new_vi + + # Update output types for this graph + for out in graph.output: + if out.name in tensor_types: + out.type.tensor_type.elem_type = tensor_types[out.name] + + # Process main graph and all subgraphs recursively + autocast_utils.walk_subgraphs_recursive(model.graph, infer_types_for_graph, is_subgraph=False) + infer_types_verification(model) + return model + + +def infer_types_verification(model: onnx.ModelProto) -> onnx.ModelProto: + """Verify that all reachable tensors have a defined type. + + This is necessary because some nodes may be removed during the inference process, + leaving unreachable value_info entries. + """ + reachable_tensors = set() + + # Add graph inputs as reachable + for inp in model.graph.input: + reachable_tensors.add(inp.name) + + # Add initializers as reachable + for init in model.graph.initializer: + reachable_tensors.add(init.name) + + # Traverse nodes to find all reachable tensor outputs + for node in model.graph.node: + # A node is reachable if any of its inputs are reachable + # (or if it has no inputs - rare but possible) + node_is_reachable = not node.input or any( + inp in reachable_tensors for inp in node.input if inp + ) + + if node_is_reachable: + # All outputs of a reachable node are reachable + for out in node.output: + if out: # Skip empty output names + reachable_tensors.add(out) + + is_undefined = False + # Check value_info for reachable tensors + for vi in model.graph.value_info: + if vi.name in reachable_tensors: + if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error( + f"Infer types verification failed. Value info {vi.name} has undefined type" + ) + is_undefined = True + + # Graph outputs should always be reachable + for out in model.graph.output: + if out.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error(f"Infer types verification failed. Output {out.name} has undefined type") + is_undefined = True + if is_undefined: + raise ValueError( + "Infer types verification failed. Undefined types found in the model - see logs for details." 
+ ) + return model + + def infer_shapes(model: onnx.ModelProto, **kwargs): """Infers shapes of the onnx graph, handles large models.""" if model.ByteSize() > (2 * (1024**3)): # 2GB limit diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index 4fb02a230..c6f9d5627 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -32,6 +32,18 @@ def low_precision_onnx_type(low_precision_type_str): return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16 +def setup_mappings( + model: onnx.ModelProto, use_local_type_inference: bool = False +) -> tuple[onnx.ModelProto, dict, dict, dict]: + # Setup internal mappings + if use_local_type_inference: + model = onnx_utils.infer_types(model) + else: + model = onnx_utils.infer_shapes(model) + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + return model, value_info_map, initializer_map, node_to_init_map + + #################################################################################################### # Testing with a basic GEMM->Add->Relu graph #################################################################################################### @@ -56,16 +68,21 @@ def simple_model(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_graph_converter_init(simple_model): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_graph_converter_init(simple_model, use_local_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=True + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + use_local_type_inference=use_local_type_inference, ) assert converter.model == model assert converter.value_info_map == value_info_map @@ -75,7 +92,8 @@ def test_graph_converter_init(simple_model): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_simple_convert(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_simple_convert(simple_model, keep_io_types, low_precision_type, use_local_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -84,6 +102,7 @@ def test_simple_convert(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Convert add node to fp16, keep mul in fp32 @@ -133,7 +152,10 @@ def test_unsupported_precision_type(simple_model, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_convert_no_disabled_nodes( + simple_model, keep_io_types, low_precision_type, 
use_local_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -142,6 +164,7 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Convert all nodes to fp16 @@ -167,7 +190,10 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_get_tensors_to_cast( + simple_model, keep_io_types, low_precision_type, use_local_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -176,6 +202,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Test when relu node is in low precision @@ -196,7 +223,8 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_keep_io_names(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_keep_io_names(simple_model, keep_io_types, low_precision_type, use_local_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -205,6 +233,7 @@ def test_keep_io_names(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Convert all nodes to low precision @@ -258,16 +287,16 @@ def model_with_multiple_consumers(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_convert_with_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_local_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -277,6 +306,7 @@ def test_convert_with_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Only gemm1 and add1 are converted to fp32, gemm2 and add2 are fp16 @@ -300,8 +330,9 @@ def test_convert_with_multiple_consumers( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_local_type_inference", [True, 
False]) def test_get_tensors_to_cast_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_local_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -311,6 +342,7 @@ def test_get_tensors_to_cast_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Test when gemm2 and add1 nodes are in low precision @@ -327,7 +359,10 @@ def test_get_tensors_to_cast_multiple_consumers( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_initializers(model_with_multiple_consumers, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_convert_initializers( + model_with_multiple_consumers, low_precision_type, use_local_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( model, @@ -335,6 +370,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Test successful cast, add1 and add2 share add_init and operate in different precisions @@ -361,6 +397,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) add1_node = next(n for n in converter2.model.graph.node if n.name == "add1") add2_node = next(n for n in converter2.model.graph.node if n.name == "add2") @@ -384,6 +421,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) add1_node = next(n for n in converter3.model.graph.node if n.name == "add1") add2_node = next(n for n in converter3.model.graph.node if n.name == "add2") @@ -404,7 +442,10 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) assert f"add_init_{low_precision_type}" in init_names -def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_clamping_fp16_initializers_out_of_range( + model_with_multiple_consumers, use_local_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, node is converted to FP16 @@ -412,7 +453,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): add_init = numpy_helper.from_array(add_init_out_of_range, name="add_init") model.graph.initializer[1].CopyFrom(add_init) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) # Verify initializer is clamped @@ -427,7 +474,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert add_init_converted_array[0, 1] == 
np.finfo(np.float16).max # Initializer is out of FP16 range, node is kept in FP32 - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) converter2._convert_initializers(low_precision_nodes=[], high_precision_nodes=["add1", "add2"]) # Verify initializer is not clamped @@ -441,7 +494,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_converted_array == add_init_out_of_range) # Initializer is out of FP16 range, one consumer is converted to FP16, the other is kept in FP32 - converter3 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter3 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) converter3._convert_initializers(low_precision_nodes=["add1"], high_precision_nodes=["add2"]) # Verify initializer is duplicated, and the FP16 copy is clamped @@ -462,7 +521,10 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_fp32_array == add_init_out_of_range) -def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_bf16_no_clamping_initializers_out_of_range( + model_with_multiple_consumers, use_local_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, but that does not affect BF16 conversion @@ -476,6 +538,7 @@ def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumer initializer_map, node_to_init_map, low_precision_type="bf16", + use_local_type_inference=use_local_type_inference, ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) @@ -511,13 +574,13 @@ def model_with_dynamic_shapes(): matmul_node = helper.make_node("MatMul", ["X", "weight"], ["matmul_out"], name="matmul") transpose_node = helper.make_node("Transpose", ["Y"], ["transpose_out"], name="transpose") concat_node = helper.make_node( - "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat", axis=0 + "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat1", axis=0 ) size_y = helper.make_node("Size", ["concat_out"], ["total_size"], name="size") const_4 = numpy_helper.from_array(np.array([4], dtype=np.int64), name="const_4") first_dim = helper.make_node("Div", ["total_size", "const_4"], ["first_dim"], name="div") concat_dims_node = helper.make_node( - "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat", axis=0 + "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat2", axis=0 ) reshape_node = helper.make_node("Reshape", ["concat_out", "final_shape"], ["Z"], name="reshape") @@ -540,18 +603,23 @@ def model_with_dynamic_shapes(): model = helper.make_model(graph, producer_name="model_dynamic") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_dynamic_model_conversion(model_with_dynamic_shapes): 
+@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_dynamic_model_conversion(model_with_dynamic_shapes, use_local_type_inference): model, value_info_map, initializer_map, node_to_init_map = model_with_dynamic_shapes # Test mixed precision conversion - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) high_precision_nodes = ["matmul"] low_precision_nodes = ["transpose", "concat", "size", "div", "concat_dims", "reshape"] @@ -563,7 +631,8 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes): #################################################################################################### # Cast cleanup logic #################################################################################################### -def test_cast_output_pattern(): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_cast_output_pattern(use_local_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -583,10 +652,14 @@ def test_cast_output_pattern(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) # Setting all nodes to FP16 means that the final graph should have no cast nodes converted_model = converter.convert( @@ -602,7 +675,8 @@ def test_cast_output_pattern(): assert converted_model.graph.output[i].name == model.graph.output[i].name -def test_cast_output_pattern_mixed_precision(): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_cast_output_pattern_mixed_precision(use_local_type_inference): x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [3, 4]) x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [3, 4]) y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [3, 4]) @@ -625,10 +699,14 @@ def test_cast_output_pattern_mixed_precision(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) # Network output Y0 has two consumers, one is FP16 and the other is FP32 converted_model = converter.convert( @@ -641,7 +719,8 @@ def test_cast_output_pattern_mixed_precision(): @pytest.mark.parametrize("keep_io_types", [True, False]) -def test_chain_of_casts_pattern(keep_io_types): 
+@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_chain_of_casts_pattern(keep_io_types, use_local_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4]) @@ -690,11 +769,14 @@ def test_chain_of_casts_pattern(keep_io_types): model = helper.make_model(graph, producer_name="model_cast_chain") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=keep_io_types + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=keep_io_types, + use_local_type_inference=use_local_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -705,7 +787,8 @@ def test_chain_of_casts_pattern(keep_io_types): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_existing_low_precision_output(low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_existing_low_precision_output(low_precision_type, use_local_type_inference): # Create a simple model with FP16 output x = helper.make_tensor_value_info("X", low_precision_onnx_type(low_precision_type), [3, 4]) y = helper.make_tensor_value_info("Y", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -715,8 +798,7 @@ def test_existing_low_precision_output(low_precision_type): model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -725,6 +807,7 @@ def test_existing_low_precision_output(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -743,7 +826,8 @@ def test_existing_low_precision_output(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_output_cast_output_pattern(low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_output_cast_output_pattern(low_precision_type, use_local_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -764,9 +848,8 @@ def test_output_cast_output_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_output_cast_output") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -774,6 +857,7 @@ def test_output_cast_output_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Setting nodes precision to 
match I/O type means that the final graph should have no cast nodes @@ -790,7 +874,8 @@ def test_output_cast_output_pattern(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_cast_output_keep_io_types_pattern(low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_cast_output_keep_io_types_pattern(low_precision_type, use_local_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -809,9 +894,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_cast_output_keep_io_types") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -819,6 +902,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) converter.convert(high_precision_nodes=[], low_precision_nodes=["add1", "add2"]) @@ -827,7 +911,8 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): assert converter.model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT -def test_unsupported_op_types_model(): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_unsupported_op_types_model(use_local_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) roi = helper.make_tensor_value_info("roi", TensorProto.FLOAT, [3, 4]) scales = helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4]) @@ -848,17 +933,22 @@ def test_unsupported_op_types_model(): [], ) model = helper.make_model(graph, producer_name="model_celu") - model = onnx.shape_inference.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_local_type_inference=use_local_type_inference, + ) converter.convert(high_precision_nodes=[], low_precision_nodes=["celu", "resize", "nms"]) onnx.checker.check_model(converter.model) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("empty_tensor_target", ["low_precision", "high_precision"]) -def test_empty_tensor_handling(low_precision_type, empty_tensor_target): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_empty_tensor_handling(low_precision_type, empty_tensor_target, use_local_type_inference): """Test empty tensor handling for both low and high precision node targets.""" # Create model with empty float tensor from Constant layer x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]) @@ -888,8 +978,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = 
utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -897,6 +986,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Test empty tensor detection @@ -979,14 +1069,16 @@ def model_with_constant_cast_patterns(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_constant_cast_folding( + model_with_constant_cast_patterns, low_precision_type, use_local_type_inference +): """Test constant->cast folding as part of the full conversion process.""" model, value_info_map, initializer_map, node_to_init_map = model_with_constant_cast_patterns @@ -997,6 +1089,7 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_ node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Convert with some nodes in low precision to trigger cast insertion @@ -1077,15 +1170,15 @@ def model_with_multiple_output_node_casted_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_multiple_output_node_casted_to_output( - model_with_multiple_output_node_casted_to_output, low_precision_type + model_with_multiple_output_node_casted_to_output, low_precision_type, use_local_type_inference ): model, value_info_map, initializer_map, node_to_init_map = ( model_with_multiple_output_node_casted_to_output @@ -1098,6 +1191,7 @@ def test_multiple_output_node_casted_to_output( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"] @@ -1145,16 +1239,16 @@ def model_with_casted_input_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("keep_io_types", [True, False]) +@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_casted_input_to_output_model( - model_with_casted_input_to_output, low_precision_type, keep_io_types + model_with_casted_input_to_output, low_precision_type, 
keep_io_types, use_local_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output @@ -1168,6 +1262,7 @@ def test_casted_input_to_output_model( min_opset=22 if low_precision_type == "bf16" else 13, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, trt_plugins=[], + use_local_type_inference=use_local_type_inference, ) converted_model = converter.convert( high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"] @@ -1218,8 +1313,7 @@ def create_model_with_resize_op(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @@ -1276,16 +1370,16 @@ def create_model_with_resize_op_tensor_scales(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_resize_op_initializer_conversion( - create_model_with_resize_op, keep_io_types, low_precision_type + create_model_with_resize_op, keep_io_types, low_precision_type, use_local_type_inference ): model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op @@ -1296,6 +1390,7 @@ def test_resize_op_initializer_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1305,8 +1400,12 @@ def test_resize_op_initializer_conversion( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_resize_op_tensor_scales_conversion( - create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type + create_model_with_resize_op_tensor_scales, + keep_io_types, + low_precision_type, + use_local_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( create_model_with_resize_op_tensor_scales @@ -1319,6 +1418,7 @@ def test_resize_op_tensor_scales_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1409,15 +1509,15 @@ def model_with_if_subgraph(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("if_precision", ["low", "high"]) 
+@pytest.mark.parametrize("use_local_type_inference", [True, False]) def test_if_subgraph_initializer_conversion( - model_with_if_subgraph, low_precision_type, if_precision + model_with_if_subgraph, low_precision_type, if_precision, use_local_type_inference ): """Test that initializers in If subgraphs are converted based on parent node precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1429,6 +1529,7 @@ def test_if_subgraph_initializer_conversion( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # Classify the If node based on test parameter @@ -1482,7 +1583,10 @@ def test_if_subgraph_initializer_conversion( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precision_type): +@pytest.mark.parametrize("use_local_type_inference", [True, False]) +def test_if_subgraph_mixed_precision_boundary( + model_with_if_subgraph, low_precision_type, use_local_type_inference +): """Test that types are correctly handled at If subgraph boundaries in mixed precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1498,7 +1602,7 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis model.graph.output.append(output_tensor) # Refresh mappings - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -1507,6 +1611,7 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_local_type_inference=use_local_type_inference, ) # If in low precision, Add in high precision From ea9f7e86a18a279759430b25c63bf436778fbc45 Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Mon, 22 Dec 2025 17:22:31 +0200 Subject: [PATCH 02/11] revert precisionconverter type infer logic change created by cursor and committed by mistake Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 39 ++------------------ 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index f351c4fdb..72ea28cd0 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -115,6 +115,7 @@ def __init__( max_ir_version: Max IR version for conversion. trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library). tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. + use_local_type_inference: Use local type inference instead of ONNX's infer_shapes. 
""" self.model = deepcopy(model) self.value_info_map = value_info_map @@ -267,39 +268,13 @@ def convert( # Populate type information with inferred types self.model = self._propagate_types_shapes_custom_ops(self.model) else: - # Preserve original output types (before clearing) - # Store in instance variable so we can restore after cleanup - # Always preserve non-float types, and preserve all types if keep_io_types is True - self._original_output_types = {} - for out in self.model.graph.output: - if out.type.HasField("tensor_type"): - original_type = out.type.tensor_type.elem_type - # Always preserve non-float types (INT64, INT32, etc.) - # Also preserve all types if keep_io_types is True - if original_type not in ONNX_TYPES or self.keep_io_types: - self._original_output_types[out.name] = original_type - # Clear type/shape information for intermediates and outputs (including subgraphs) self._clear_types_and_shapes_recursive(self.model.graph) - # Populate type information with inferred types self.model = self.infer_types(strict_mode=True, check_type=False) self._ensure_types_are_defined() - - # Restore original output types (non-float types and all types if keep_io_types is True) - if hasattr(self, "_original_output_types"): - for out in self.model.graph.output: - if out.name in self._original_output_types: - out.type.tensor_type.elem_type = self._original_output_types[out.name] - - # Sanity check: Verify type correctness (only when using ONNX's infer_shapes) - if not self.use_local_type_inference: - self.model = self.infer_types(strict_mode=True, check_type=True) - # Restore output types again after second inference - if hasattr(self, "_original_output_types"): - for out in self.model.graph.output: - if out.name in self._original_output_types: - out.type.tensor_type.elem_type = self._original_output_types[out.name] + # Sanity check: Verify type correctness + self.model = self.infer_types(strict_mode=True, check_type=True) # Update value_info_map and initializer_map with casts we added self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( @@ -309,14 +284,6 @@ def convert( # Remove redundant casts self._cleanup() - # Restore original output types after cleanup - # (cleanup may have modified outputs, so we need to restore types again) - # Always restore non-float types, and all types if keep_io_types is True - if hasattr(self, "_original_output_types"): - for out in self.model.graph.output: - if out.name in self._original_output_types: - out.type.tensor_type.elem_type = self._original_output_types[out.name] - self._sanity_check() return self.model From 80d66f850af6c8f9d5fbf3ad3f9162fb4e859a6d Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Mon, 22 Dec 2025 17:43:45 +0200 Subject: [PATCH 03/11] fix logic for subgraphs Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/utils.py | 148 ++++++++++++++++++++++++++++------------- 1 file changed, 103 insertions(+), 45 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index ef08bd5c0..fec0545de 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -970,53 +970,111 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: # Split schema allows multiple outputs, but the schema only specifies one output type output_types = [input_types[0]] * len(node.output) else: - default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT - # Use ONNX operator schema to determine 
output types - try: - schema = onnx.defs.get_schema(node.op_type, opset, domain=node.domain or "") - assert schema.outputs and len(schema.outputs) >= len(node.output) - except Exception as e: - # Fallback: if schema lookup fails, propagate first input type - logger.debug( - f"Node {node.name}: Failed to get schema for {node.op_type}: {e}, propagate first input type" - ) + # Check if this node has subgraphs (GRAPH or GRAPHS attributes) + # Common nodes with subgraphs: If, Loop, Scan + subgraphs = [] + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + subgraphs.append(attr.g) + elif attr.type == onnx.AttributeProto.GRAPHS: + subgraphs.extend(attr.graphs) + + # If node has subgraphs, infer types for them first + if subgraphs: + for subgraph in subgraphs: + infer_types_for_graph(subgraph, parent_node=node, is_subgraph=True) + + # For nodes with subgraphs, try to infer output types from subgraph outputs + # This avoids incorrectly matching to control inputs (e.g., condition for If, trip_count for Loop) + output_types = [] + if len(node.output) > 0: + # Use the first subgraph as reference (works for If, Loop, Scan) + first_subgraph = subgraphs[0] + for out_idx, out_name in enumerate(node.output): + if out_idx < len(first_subgraph.output): + subgraph_out = first_subgraph.output[out_idx] + # Typically we only have one subgraph, but If nodes have two subgraphs + # (then_branch and else_branch). In any case, the output types of the + # subgraphs must be identical, so we check just the first one + if ( + subgraph_out.type.HasField("tensor_type") + and subgraph_out.type.tensor_type.elem_type + != onnx.TensorProto.UNDEFINED + ): + output_types.append(subgraph_out.type.tensor_type.elem_type) + else: + output_types.append(onnx.TensorProto.FLOAT) + else: + # Fallback if we can't infer from subgraphs + output_types = None + + # If we couldn't infer from subgraphs, fall through to schema-based inference + if output_types is None or len(output_types) != len(node.output): + output_types = None + else: + # No subgraphs, proceed with normal inference + output_types = None + + # If output_types not set yet, use schema-based inference + if output_types is None: default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT - output_types = [default_type] * len(node.output) - - # Try to infer from schema - input_schemas = [schema.inputs[i].type_str for i in range(len(schema.inputs))] - output_schemas = [schema.outputs[i].type_str for i in range(len(schema.outputs))] - output_types = [None] * len(node.output) - - for output_idx in range(len(node.output)): - # explicit type is set in schema, use it - if "tensor" in output_schemas[output_idx]: - found_type = str_to_tensor_dtype(output_schemas[output_idx]) - output_types[output_idx] = found_type - continue - # sometimes output type is set with a placeholder name despite supporting a single type - # e.g. 
Shape operator is constrained to int64, but the type_str is "T1" - for constraint in schema.type_constraints: - # If output type constraint has only one allowed type, use it directly - if constraint.type_param_str == output_schemas[output_idx]: - if len(constraint.allowed_type_strs) == 1: - found_type = str_to_tensor_dtype(constraint.allowed_type_strs[0]) + # Use ONNX operator schema to determine output types + try: + schema = onnx.defs.get_schema(node.op_type, opset, domain=node.domain or "") + assert schema.outputs and len(schema.outputs) >= len(node.output) + except Exception as e: + # Fallback: if schema lookup fails, propagate first input type + logger.debug( + f"Node {node.name}: Failed to get schema for {node.op_type}: {e}, " + "propagate first input type" + ) + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + output_types = [default_type] * len(node.output) + else: + # Try to infer from schema + input_schemas = [ + schema.inputs[i].type_str for i in range(len(schema.inputs)) + ] + output_schemas = [ + schema.outputs[i].type_str for i in range(len(schema.outputs)) + ] + output_types = [None] * len(node.output) + + for output_idx in range(len(node.output)): + # explicit type is set in schema, use it + if "tensor" in output_schemas[output_idx]: + found_type = str_to_tensor_dtype(output_schemas[output_idx]) + output_types[output_idx] = found_type + continue + # sometimes output type is set with a placeholder name despite supporting a single type + # e.g. Shape operator is constrained to int64, but the type_str is "T1" + for constraint in schema.type_constraints: + # If output type constraint has only one allowed type, use it directly + if constraint.type_param_str == output_schemas[output_idx]: + if len(constraint.allowed_type_strs) == 1: + found_type = str_to_tensor_dtype( + constraint.allowed_type_strs[0] + ) + output_types[output_idx] = found_type + break + else: + # We have a placeholder name "T", "T1", "T2", etc that should + # match one of the input types + try: + input_match_idx = input_schemas.index( + output_schemas[output_idx] + ) + except ValueError: + input_match_idx = None + if input_match_idx is not None: + found_type = input_types[input_match_idx] + else: + found_type = default_type + logger.debug( + f"Node {node.name}: Failed to infer type for output " + f"#{output_idx}, propagate first input type" + ) output_types[output_idx] = found_type - break - else: # we have a placeholder name "T", "T1", "T2", etc that should match one of the input types - try: - input_match_idx = input_schemas.index(output_schemas[output_idx]) - except ValueError: - input_match_idx = None - if input_match_idx is not None: - found_type = input_types[input_match_idx] - else: - found_type = default_type - logger.debug( - f"Node {node.name}: Failed to infer type for output #{output_idx}, propagate first " - "input type" - ) - output_types[output_idx] = found_type # Update output tensor types for out_idx, out_name in enumerate(node.output): From 18d04fe4851aa63eabcc9fe219e711f4b1e50ede Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Wed, 24 Dec 2025 17:15:17 +0200 Subject: [PATCH 04/11] fixes Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/utils.py | 58 +++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index fec0545de..d7597e46b 100644 --- a/modelopt/onnx/utils.py +++ 
b/modelopt/onnx/utils.py @@ -867,14 +867,9 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: if not inp_name: continue inp_type = get_tensor_type(inp_name) - if inp_type is not None: - input_types.append(inp_type) - else: - # In topologically sorted order, this shouldn't happen unless - # the input is from an initializer/input we missed or there's a cycle - logger.debug(f"Warning: Input {inp_name} of node {node.name} has unknown type") - # Use FLOAT as fallback - input_types.append(onnx.TensorProto.FLOAT) + if inp_type is None: + raise ValueError(f"Input {inp_name} of node {node.name} has unknown type") + input_types.append(inp_type) # Infer output types for this node output_types = [] @@ -886,15 +881,9 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: if attr.name == "to": cast_to_type = attr.i break - if cast_to_type is not None: - output_types = [cast_to_type] * len(node.output) - else: - # Fallback: use input type if cast target unknown - output_types = ( - input_types[: len(node.output)] - if input_types - else [onnx.TensorProto.FLOAT] * len(node.output) - ) + if cast_to_type is None: + raise ValueError(f"Cast node {node.name} has unknown target type") + output_types = [cast_to_type] elif node.op_type == "DequantizeLinear": # DequantizeLinear: output type is determined by output_dtype attribute if present, # otherwise use the scale type (input[1]) @@ -906,20 +895,17 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: break if output_dtype is not None: - output_types = [output_dtype] * len(node.output) + output_types = [output_dtype] elif len(node.input) >= 2 and node.input[1]: scale_type = get_tensor_type(node.input[1]) if scale_type is not None: - output_types = [scale_type] * len(node.output) + output_types = [scale_type] else: # Fallback: use first input type or FLOAT - output_types = [ - input_types[0] if input_types else onnx.TensorProto.FLOAT - ] * len(node.output) + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] else: - output_types = [ - input_types[0] if input_types else onnx.TensorProto.FLOAT - ] * len(node.output) + # Fallback: use first input type or FLOAT + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] elif node.op_type == "QuantizeLinear": # QuantizeLinear: output type is determined by output_dtype attribute if present, # otherwise use the zero_point type (input[2]) @@ -935,13 +921,13 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: elif len(node.input) >= 3 and node.input[2]: zero_point_type = get_tensor_type(node.input[2]) if zero_point_type is not None: - output_types = [zero_point_type] * len(node.output) + output_types = [zero_point_type] else: - # Fallback: typically UINT8 or INT8 for quantized types - output_types = [onnx.TensorProto.UINT8] * len(node.output) + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] else: - # Fallback: typically UINT8 or INT8 for quantized types - output_types = [onnx.TensorProto.UINT8] * len(node.output) + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] elif node.op_type == "Constant": # Constant: output type is from the value attribute's tensor data_type const_type = None @@ -950,22 +936,20 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: if attr.t.HasField("data_type"): const_type = attr.t.data_type break - if const_type is not None: - output_types = 
[const_type] * len(node.output) - else: - # Fallback: use FLOAT if type cannot be determined - output_types = [onnx.TensorProto.FLOAT] * len(node.output) + assert const_type is not None + output_types = [const_type] elif node.op_type == "ConstantOfShape": # ConstantOfShape: output type is from the value attribute's tensor data_type # If no value attribute, defaults to FLOAT # Note: Schema allows multiple types, so we need to check the value attribute - const_type = onnx.TensorProto.FLOAT # default + const_type = None for attr in node.attribute: if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: if attr.t.HasField("data_type"): const_type = attr.t.data_type break - output_types = [const_type] * len(node.output) + assert const_type is not None + output_types = [const_type] elif node.op_type == "Split": # Split schema allows multiple outputs, but the schema only specifies one output type output_types = [input_types[0]] * len(node.output) From 2ff3a6d09e2ddce18e29a8df958a89f76f3e17fa Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:38:05 +0200 Subject: [PATCH 05/11] do not clear shapes when inferring only types Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 27 +++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 72ea28cd0..732ec2c9a 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -299,9 +299,9 @@ def _clear_types_and_shapes_recursive( ) -> None: """Recursively clear type/shape information for a graph and all its subgraphs. - This is necessary for control flow operators (Scan, If, Loop) which have subgraphs. - For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph). - For main graph, clear all value_info. + If use_local_type_inference is True, we clear only types, not shapes. + For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated + from the main graph. Args: graph: The ONNX graph to clear types and shapes for. 
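The schema-driven fallback in the utils.py hunks above resolves an output dtype in three steps: an explicit tensor type in the schema, a placeholder constrained to a single allowed type, and a placeholder matched against the input types. A compressed, self-contained sketch of that resolution order is shown below; `schema_output_dtype` and `_dtype_from_type_str` are illustrative names rather than functions from this patch, and the string-to-dtype mapping only covers plain tensor types.

.. code-block:: python

    import onnx
    from onnx import defs, helper


    def _dtype_from_type_str(type_str: str) -> int:
        # "tensor(float16)" -> TensorProto.FLOAT16; covers plain tensor types only.
        return getattr(onnx.TensorProto, type_str[len("tensor("):-1].upper())


    def schema_output_dtype(node: onnx.NodeProto, input_types: list[int], opset: int) -> int:
        schema = defs.get_schema(node.op_type, opset, domain=node.domain or "")
        out_str = schema.outputs[0].type_str

        # 1) The schema pins an explicit tensor type, e.g. NonZero -> "tensor(int64)".
        if out_str.startswith("tensor("):
            return _dtype_from_type_str(out_str)

        # 2) A placeholder ("T", "T1", ...) constrained to exactly one allowed type,
        #    e.g. Shape resolves to int64 even though its output type_str is "T1".
        for constraint in schema.type_constraints:
            if constraint.type_param_str == out_str and len(constraint.allowed_type_strs) == 1:
                return _dtype_from_type_str(constraint.allowed_type_strs[0])

        # 3) A placeholder shared with an input -> propagate that input's dtype.
        input_strs = [inp.type_str for inp in schema.inputs]
        if out_str in input_strs and input_strs.index(out_str) < len(input_types):
            return input_types[input_strs.index(out_str)]

        # Fallback: propagate the first input type (FLOAT if there is none).
        return input_types[0] if input_types else onnx.TensorProto.FLOAT


    shape_node = helper.make_node("Shape", ["x"], ["y"])
    add_node = helper.make_node("Add", ["a", "b"], ["c"])
    print(schema_output_dtype(shape_node, [onnx.TensorProto.FLOAT16], 20))    # 7  (INT64)
    print(schema_output_dtype(add_node, [onnx.TensorProto.FLOAT16] * 2, 20))  # 10 (FLOAT16)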
@@ -318,15 +318,23 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for inp in g.input: if inp.type.HasField("tensor_type"): inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(inp.type.tensor_type.shape.dim): - if d.dim_value: - inp.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_local_type_inference: + for idx, d in enumerate(inp.type.tensor_type.shape.dim): + if d.dim_value: + inp.type.tensor_type.shape.dim[idx].dim_param = "unk" if is_sub: # Identify which tensors are produced by nodes in this subgraph subgraph_outputs = set() for node in g.node: subgraph_outputs.update(node.output) + # Clear type/shape information for intermediates and outputs + for vi in g.value_info: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + if not self.use_local_type_inference: + for idx, d in enumerate(vi.type.tensor_type.shape.dim): + if d.dim_value: + vi.type.tensor_type.shape.dim[idx].dim_param = "unk" # Clear value_info only for intermediates produced by nodes in this subgraph for vi in g.value_info: @@ -345,9 +353,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> # Clear outputs for both main graph and subgraphs for out in g.output: out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(out.type.tensor_type.shape.dim): - if d.dim_value: - out.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_local_type_inference: + for idx, d in enumerate(out.type.tensor_type.shape.dim): + if d.dim_value: + out.type.tensor_type.shape.dim[idx].dim_param = "unk" utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph) From bfc762a518d00dbd51adfa7529555c0c32514d74 Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Thu, 25 Dec 2025 09:50:23 +0200 Subject: [PATCH 06/11] rename use_local_type_inference to use_standalone_type_inference and add to changelog Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- CHANGELOG.rst | 1 + docs/source/guides/8_autocast.rst | 14 +- modelopt/onnx/autocast/__main__.py | 4 +- modelopt/onnx/autocast/convert.py | 16 +- modelopt/onnx/autocast/precisionconverter.py | 20 +- .../onnx/autocast/test_precisionconverter.py | 179 ++++++++++-------- 6 files changed, 123 insertions(+), 111 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 978ac209d..75840754d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for parallel draft heads in Eagle speculative decoding. - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend `` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``. - Add ``examples/llm_qad`` for QAD training with Megatron-LM. +- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. 
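Because the flag is positioned as a fallback for failing shape inference, a natural way to use it from Python is to retry with the standalone pass only when the default path raises. The snippet below is an illustrative pattern, not part of this change; the keyword names other than ``use_standalone_type_inference`` follow the guide's API example and the CLI flags and should be treated as assumptions.

.. code-block:: python

    import onnx

    from modelopt.onnx.autocast.convert import convert_to_mixed_precision

    try:
        converted = convert_to_mixed_precision(onnx_path="model.onnx", low_precision_type="fp16")
    except Exception:
        # Shape inference (or anything depending on it) failed; bypass it with
        # the experimental type-only pass.
        converted = convert_to_mixed_precision(
            onnx_path="model.onnx",
            low_precision_type="fp16",
            use_standalone_type_inference=True,
        )

    onnx.save(converted, "model.fp16.onnx")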
**Deprecations** diff --git a/docs/source/guides/8_autocast.rst b/docs/source/guides/8_autocast.rst index 75f94b57d..0701f2f1f 100644 --- a/docs/source/guides/8_autocast.rst +++ b/docs/source/guides/8_autocast.rst @@ -42,7 +42,7 @@ AutoCast can also be used programmatically through its Python API: trt_plugins=[], # list of TensorRT plugin library paths in .so format max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16) - use_local_type_inference=False, # use local type inference instead of ONNX's infer_shapes (WAR) + use_standalone_type_inference=False, # use standalone type inference instead of ONNX's infer_shapes (WAR) ) # Save the converted model @@ -85,7 +85,7 @@ AutoCast follows these steps to convert a model: - Automatically replaces initializers with lower precision values - Performs type inference to propagate types through the graph - By default, uses ONNX's ``infer_shapes`` which performs both shape and type inference using the ONNX infer_shapes API. - - Use ``use_local_type_inference=True`` to use a local type-only inference implementation (experimental). + - Use ``use_standalone_type_inference=True`` to use a standalone type-only inference implementation (experimental). #. **Validation and Export**: @@ -152,10 +152,10 @@ Best Practices #. **Type Inference Control** - By default, AutoCast uses ONNX's ``infer_shapes`` which performs both shape and type inference. - - Use ``--use_local_type_inference`` to enable a local type-only inference implementation. + - Use ``--use_standalone_type_inference`` to enable a standalone type-only inference implementation. - This is a workaround for cases where shape inference fails for any reason, which allows us to bypass the dependency in ONNX's shape inference logic. - - The local implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape. - - Note: The local type inference may be less robust than ONNX's implementation for edge cases, but avoids unnecessary shape inference overhead and possible failures. + - The standalone implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape. + - Note: The standalone type inference may be less robust than ONNX's implementation for edge cases, but avoids unnecessary shape inference overhead and possible failures. Limitations and Restrictions ---------------------------- @@ -211,8 +211,8 @@ Convert to BF16 with a specific opset: python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22 -Use local type inference instead of ONNX's infer_shapes: +Use standalone type inference instead of ONNX's infer_shapes: .. 
code-block:: bash - python -m modelopt.onnx.autocast --onnx_path model.onnx --use_local_type_inference + python -m modelopt.onnx.autocast --onnx_path model.onnx --use_standalone_type_inference diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py index d2dd74566..da521d524 100644 --- a/modelopt/onnx/autocast/__main__.py +++ b/modelopt/onnx/autocast/__main__.py @@ -186,7 +186,7 @@ def get_parser() -> argparse.ArgumentParser: ), ) parser.add_argument( - "--use_local_type_inference", + "--use_standalone_type_inference", action="store_true", default=False, help=( @@ -228,7 +228,7 @@ def main(argv=None): trt_plugins_precision=args.trt_plugins_precision, max_depth_of_reduction=args.max_depth_of_reduction, opset=args.opset, - use_local_type_inference=args.use_local_type_inference, + use_standalone_type_inference=args.use_standalone_type_inference, ) output_path = args.output_path diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py index 0e8775341..50f2a4348 100644 --- a/modelopt/onnx/autocast/convert.py +++ b/modelopt/onnx/autocast/convert.py @@ -61,7 +61,7 @@ def convert_to_mixed_precision( trt_plugins_precision: list[str] = [], max_depth_of_reduction: int | None = None, opset: int | None = None, - use_local_type_inference: bool = False, + use_standalone_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision. @@ -86,7 +86,7 @@ def convert_to_mixed_precision( opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations require a higher version. - use_local_type_inference: If True, use local type inference implementation instead of ONNX's + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's infer_shapes. This is a workaround (WAR) when only type inference is needed without shape inference. Default: False. @@ -136,7 +136,7 @@ def convert_to_mixed_precision( model = graph_sanitizer.model # Setup internal mappings - if use_local_type_inference: + if use_standalone_type_inference: model = onnx_utils.infer_types(model) else: model = onnx_utils.infer_shapes(model) @@ -171,7 +171,7 @@ def convert_to_mixed_precision( low_precision_type=low_precision_type, init_conversion_max_bytes=init_conversion_max_bytes, custom_ops=graph_sanitizer.custom_ops, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Obtain reference data @@ -204,7 +204,7 @@ def convert_to_f16( op_block_list: list[str] = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, trt_plugins: list[str] | None = [], - use_local_type_inference: bool = False, + use_standalone_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision, using PrecisionConverter. @@ -217,7 +217,7 @@ def convert_to_f16( op_block_list: List of operation types that should remain in FP32. tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library). - use_local_type_inference: If True, use local type inference implementation instead of ONNX's + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's infer_shapes. This is a workaround (WAR) when only type inference is needed without shape inference. Default: False. 
""" @@ -237,7 +237,7 @@ def convert_to_f16( model = sanitizer.model # Setup internal mappings - if use_local_type_inference: + if use_standalone_type_inference: model = onnx_utils.infer_types(model) else: model = onnx_utils.infer_shapes(model) @@ -252,7 +252,7 @@ def convert_to_f16( low_precision_type=low_precision_type, custom_ops=sanitizer.custom_ops, tensor_block_dict=tensor_block_dict, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list] low_precision_nodes = [ diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 732ec2c9a..0c6d97553 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -97,7 +97,7 @@ def __init__( max_ir_version: int | None = None, trt_plugins: list[str] | None = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, - use_local_type_inference: bool = False, + use_standalone_type_inference: bool = False, ) -> None: """Initialize PrecisionConverter. @@ -115,7 +115,7 @@ def __init__( max_ir_version: Max IR version for conversion. trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library). tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. - use_local_type_inference: Use local type inference instead of ONNX's infer_shapes. + use_standalone_type_inference: Use standalone type inference instead of ONNX's infer_shapes. """ self.model = deepcopy(model) self.value_info_map = value_info_map @@ -142,7 +142,7 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins - self.use_local_type_inference = use_local_type_inference + self.use_standalone_type_inference = use_standalone_type_inference # Detect additional ops not supported in low precision according to the model's opset version self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( @@ -161,7 +161,7 @@ def __init__( self._warned_values_clamp_min = False def infer_types(self, **kwargs): - """Infers types (and optionally shapes) based on the use_local_type_inference flag. + """Infers types (and optionally shapes) based on the use_standalone_type_inference flag. Args: **kwargs: Additional arguments passed to infer_shapes when not using local type inference. @@ -169,7 +169,7 @@ def infer_types(self, **kwargs): Returns: onnx.ModelProto: Model with inferred types (and shapes if not using local type inference). """ - if self.use_local_type_inference: + if self.use_standalone_type_inference: return onnx_utils.infer_types(self.model) else: return onnx_utils.infer_shapes(self.model, **kwargs) @@ -299,7 +299,7 @@ def _clear_types_and_shapes_recursive( ) -> None: """Recursively clear type/shape information for a graph and all its subgraphs. - If use_local_type_inference is True, we clear only types, not shapes. + If use_standalone_type_inference is True, we clear only types, not shapes. For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated from the main graph. 
@@ -318,7 +318,7 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for inp in g.input: if inp.type.HasField("tensor_type"): inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - if not self.use_local_type_inference: + if not self.use_standalone_type_inference: for idx, d in enumerate(inp.type.tensor_type.shape.dim): if d.dim_value: inp.type.tensor_type.shape.dim[idx].dim_param = "unk" @@ -331,7 +331,7 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> # Clear type/shape information for intermediates and outputs for vi in g.value_info: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - if not self.use_local_type_inference: + if not self.use_standalone_type_inference: for idx, d in enumerate(vi.type.tensor_type.shape.dim): if d.dim_value: vi.type.tensor_type.shape.dim[idx].dim_param = "unk" @@ -353,7 +353,7 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> # Clear outputs for both main graph and subgraphs for out in g.output: out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - if not self.use_local_type_inference: + if not self.use_standalone_type_inference: for idx, d in enumerate(out.type.tensor_type.shape.dim): if d.dim_value: out.type.tensor_type.shape.dim[idx].dim_param = "unk" @@ -1202,7 +1202,7 @@ def _remove_redundant_casts(self): self.model = self._propagate_types_shapes_custom_ops(self.model) else: self.model = self.infer_types(strict_mode=True) - if not self.use_local_type_inference: + if not self.use_standalone_type_inference: self.model = self.infer_types(strict_mode=True, check_type=True) nodes_to_remove = [] diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index c6f9d5627..bf6fcd4f7 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -33,10 +33,10 @@ def low_precision_onnx_type(low_precision_type_str): def setup_mappings( - model: onnx.ModelProto, use_local_type_inference: bool = False + model: onnx.ModelProto, use_standalone_type_inference: bool = False ) -> tuple[onnx.ModelProto, dict, dict, dict]: # Setup internal mappings - if use_local_type_inference: + if use_standalone_type_inference: model = onnx_utils.infer_types(model) else: model = onnx_utils.infer_shapes(model) @@ -73,8 +73,8 @@ def simple_model(): return model, value_info_map, initializer_map, node_to_init_map -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_graph_converter_init(simple_model, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_graph_converter_init(simple_model, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -82,7 +82,7 @@ def test_graph_converter_init(simple_model, use_local_type_inference): initializer_map, node_to_init_map, keep_io_types=True, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) assert converter.model == model assert converter.value_info_map == value_info_map @@ -92,8 +92,10 @@ def test_graph_converter_init(simple_model, use_local_type_inference): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def 
test_simple_convert(simple_model, keep_io_types, low_precision_type, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_simple_convert( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -102,7 +104,7 @@ def test_simple_convert(simple_model, keep_io_types, low_precision_type, use_loc node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert add node to fp16, keep mul in fp32 @@ -152,9 +154,9 @@ def test_unsupported_precision_type(simple_model, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_convert_no_disabled_nodes( - simple_model, keep_io_types, low_precision_type, use_local_type_inference + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( @@ -164,7 +166,7 @@ def test_convert_no_disabled_nodes( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all nodes to fp16 @@ -190,9 +192,9 @@ def test_convert_no_disabled_nodes( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_get_tensors_to_cast( - simple_model, keep_io_types, low_precision_type, use_local_type_inference + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( @@ -202,7 +204,7 @@ def test_get_tensors_to_cast( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when relu node is in low precision @@ -223,8 +225,10 @@ def test_get_tensors_to_cast( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_keep_io_names(simple_model, keep_io_types, low_precision_type, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_keep_io_names( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -233,7 +237,7 @@ def test_keep_io_names(simple_model, keep_io_types, low_precision_type, use_loca node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all 
nodes to low precision @@ -294,9 +298,9 @@ def model_with_multiple_consumers(): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_convert_with_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type, use_local_type_inference + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -306,7 +310,7 @@ def test_convert_with_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Only gemm1 and add1 are converted to fp32, gemm2 and add2 are fp16 @@ -330,9 +334,9 @@ def test_convert_with_multiple_consumers( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_get_tensors_to_cast_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type, use_local_type_inference + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -342,7 +346,7 @@ def test_get_tensors_to_cast_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when gemm2 and add1 nodes are in low precision @@ -359,9 +363,9 @@ def test_get_tensors_to_cast_multiple_consumers( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_convert_initializers( - model_with_multiple_consumers, low_precision_type, use_local_type_inference + model_with_multiple_consumers, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -370,7 +374,7 @@ def test_convert_initializers( initializer_map, node_to_init_map, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Test successful cast, add1 and add2 share add_init and operate in different precisions @@ -397,7 +401,7 @@ def test_convert_initializers( initializer_map, node_to_init_map, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter2.model.graph.node if n.name == "add1") add2_node = next(n for n in converter2.model.graph.node if n.name == "add2") @@ -421,7 +425,7 @@ def test_convert_initializers( initializer_map, node_to_init_map, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, 
+ use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter3.model.graph.node if n.name == "add1") add2_node = next(n for n in converter3.model.graph.node if n.name == "add2") @@ -442,9 +446,9 @@ def test_convert_initializers( assert f"add_init_{low_precision_type}" in init_names -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_clamping_fp16_initializers_out_of_range( - model_with_multiple_consumers, use_local_type_inference + model_with_multiple_consumers, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers @@ -458,7 +462,7 @@ def test_clamping_fp16_initializers_out_of_range( value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) @@ -479,7 +483,7 @@ def test_clamping_fp16_initializers_out_of_range( value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter2._convert_initializers(low_precision_nodes=[], high_precision_nodes=["add1", "add2"]) @@ -499,7 +503,7 @@ def test_clamping_fp16_initializers_out_of_range( value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter3._convert_initializers(low_precision_nodes=["add1"], high_precision_nodes=["add2"]) @@ -521,9 +525,9 @@ def test_clamping_fp16_initializers_out_of_range( assert np.all(add_init_fp32_array == add_init_out_of_range) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_bf16_no_clamping_initializers_out_of_range( - model_with_multiple_consumers, use_local_type_inference + model_with_multiple_consumers, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers @@ -538,7 +542,7 @@ def test_bf16_no_clamping_initializers_out_of_range( initializer_map, node_to_init_map, low_precision_type="bf16", - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) @@ -608,8 +612,8 @@ def model_with_dynamic_shapes(): return model, value_info_map, initializer_map, node_to_init_map -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_dynamic_model_conversion(model_with_dynamic_shapes, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_dynamic_model_conversion(model_with_dynamic_shapes, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = model_with_dynamic_shapes # Test mixed precision conversion @@ -618,7 +622,7 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes, use_local_type_infe value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) high_precision_nodes = ["matmul"] low_precision_nodes = ["transpose", "concat", "size", "div", 
"concat_dims", "reshape"] @@ -631,8 +635,8 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes, use_local_type_infe #################################################################################################### # Cast cleanup logic #################################################################################################### -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_cast_output_pattern(use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern(use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -658,7 +662,7 @@ def test_cast_output_pattern(use_local_type_inference): value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Setting all nodes to FP16 means that the final graph should have no cast nodes @@ -675,8 +679,8 @@ def test_cast_output_pattern(use_local_type_inference): assert converted_model.graph.output[i].name == model.graph.output[i].name -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_cast_output_pattern_mixed_precision(use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern_mixed_precision(use_standalone_type_inference): x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [3, 4]) x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [3, 4]) y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [3, 4]) @@ -705,7 +709,7 @@ def test_cast_output_pattern_mixed_precision(use_local_type_inference): value_info_map, initializer_map, node_to_init_map, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Network output Y0 has two consumers, one is FP16 and the other is FP32 @@ -719,8 +723,8 @@ def test_cast_output_pattern_mixed_precision(use_local_type_inference): @pytest.mark.parametrize("keep_io_types", [True, False]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_chain_of_casts_pattern(keep_io_types, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_chain_of_casts_pattern(keep_io_types, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4]) @@ -776,7 +780,7 @@ def test_chain_of_casts_pattern(keep_io_types, use_local_type_inference): initializer_map, node_to_init_map, keep_io_types=keep_io_types, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -787,8 +791,8 @@ def test_chain_of_casts_pattern(keep_io_types, use_local_type_inference): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_existing_low_precision_output(low_precision_type, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_existing_low_precision_output(low_precision_type, use_standalone_type_inference): # Create a simple model 
with FP16 output x = helper.make_tensor_value_info("X", low_precision_onnx_type(low_precision_type), [3, 4]) y = helper.make_tensor_value_info("Y", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -807,7 +811,7 @@ def test_existing_low_precision_output(low_precision_type, use_local_type_infere node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -826,8 +830,8 @@ def test_existing_low_precision_output(low_precision_type, use_local_type_infere @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_output_cast_output_pattern(low_precision_type, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_output_cast_output_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -857,7 +861,7 @@ def test_output_cast_output_pattern(low_precision_type, use_local_type_inference node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Setting nodes precision to match I/O type means that the final graph should have no cast nodes @@ -874,8 +878,8 @@ def test_output_cast_output_pattern(low_precision_type, use_local_type_inference @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_cast_output_keep_io_types_pattern(low_precision_type, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_keep_io_types_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -902,7 +906,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type, use_local_type_in node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=[], low_precision_nodes=["add1", "add2"]) @@ -911,8 +915,8 @@ def test_cast_output_keep_io_types_pattern(low_precision_type, use_local_type_in assert converter.model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_unsupported_op_types_model(use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_unsupported_op_types_model(use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) roi = helper.make_tensor_value_info("roi", TensorProto.FLOAT, [3, 4]) scales = helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4]) @@ -939,7 +943,7 @@ def test_unsupported_op_types_model(use_local_type_inference): value_info_map, initializer_map, node_to_init_map, 
- use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=[], low_precision_nodes=["celu", "resize", "nms"]) onnx.checker.check_model(converter.model) @@ -947,8 +951,10 @@ def test_unsupported_op_types_model(use_local_type_inference): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("empty_tensor_target", ["low_precision", "high_precision"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) -def test_empty_tensor_handling(low_precision_type, empty_tensor_target, use_local_type_inference): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_empty_tensor_handling( + low_precision_type, empty_tensor_target, use_standalone_type_inference +): """Test empty tensor handling for both low and high precision node targets.""" # Create model with empty float tensor from Constant layer x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]) @@ -986,7 +992,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target, use_loca node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Test empty tensor detection @@ -1075,9 +1081,9 @@ def model_with_constant_cast_patterns(): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_constant_cast_folding( - model_with_constant_cast_patterns, low_precision_type, use_local_type_inference + model_with_constant_cast_patterns, low_precision_type, use_standalone_type_inference ): """Test constant->cast folding as part of the full conversion process.""" model, value_info_map, initializer_map, node_to_init_map = model_with_constant_cast_patterns @@ -1089,7 +1095,7 @@ def test_constant_cast_folding( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert with some nodes in low precision to trigger cast insertion @@ -1176,9 +1182,11 @@ def model_with_multiple_output_node_casted_to_output(): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_multiple_output_node_casted_to_output( - model_with_multiple_output_node_casted_to_output, low_precision_type, use_local_type_inference + model_with_multiple_output_node_casted_to_output, + low_precision_type, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( model_with_multiple_output_node_casted_to_output @@ -1191,7 +1199,7 @@ def test_multiple_output_node_casted_to_output( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"] @@ -1246,9 +1254,12 @@ def model_with_casted_input_to_output(): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("keep_io_types", [True, False]) 
-@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_casted_input_to_output_model( - model_with_casted_input_to_output, low_precision_type, keep_io_types, use_local_type_inference + model_with_casted_input_to_output, + low_precision_type, + keep_io_types, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output @@ -1262,7 +1273,7 @@ def test_casted_input_to_output_model( min_opset=22 if low_precision_type == "bf16" else 13, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, trt_plugins=[], - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"] @@ -1377,9 +1388,9 @@ def create_model_with_resize_op_tensor_scales(): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_initializer_conversion( - create_model_with_resize_op, keep_io_types, low_precision_type, use_local_type_inference + create_model_with_resize_op, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op @@ -1390,7 +1401,7 @@ def test_resize_op_initializer_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1400,12 +1411,12 @@ def test_resize_op_initializer_conversion( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_tensor_scales_conversion( create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type, - use_local_type_inference, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( create_model_with_resize_op_tensor_scales @@ -1418,7 +1429,7 @@ def test_resize_op_tensor_scales_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1515,9 +1526,9 @@ def model_with_if_subgraph(): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("if_precision", ["low", "high"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_if_subgraph_initializer_conversion( - model_with_if_subgraph, low_precision_type, if_precision, use_local_type_inference + model_with_if_subgraph, low_precision_type, if_precision, use_standalone_type_inference ): """Test that initializers in If subgraphs are 
converted based on parent node precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1529,7 +1540,7 @@ def test_if_subgraph_initializer_conversion( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # Classify the If node based on test parameter @@ -1583,9 +1594,9 @@ def test_if_subgraph_initializer_conversion( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -@pytest.mark.parametrize("use_local_type_inference", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_if_subgraph_mixed_precision_boundary( - model_with_if_subgraph, low_precision_type, use_local_type_inference + model_with_if_subgraph, low_precision_type, use_standalone_type_inference ): """Test that types are correctly handled at If subgraph boundaries in mixed precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1611,7 +1622,7 @@ def test_if_subgraph_mixed_precision_boundary( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, - use_local_type_inference=use_local_type_inference, + use_standalone_type_inference=use_standalone_type_inference, ) # If in low precision, Add in high precision From 0b8e5d91b13496d7d2fe8c12bc72ed7d8e0e8c9a Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Sun, 11 Jan 2026 13:03:45 +0200 Subject: [PATCH 07/11] fix rebase Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 0c6d97553..4eddbcd64 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -328,21 +328,15 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> subgraph_outputs = set() for node in g.node: subgraph_outputs.update(node.output) - # Clear type/shape information for intermediates and outputs - for vi in g.value_info: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - if not self.use_standalone_type_inference: - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" # Clear value_info only for intermediates produced by nodes in this subgraph for vi in g.value_info: if vi.name in subgraph_outputs: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(vi.type.tensor_type.shape.dim): + if d.dim_value: + vi.type.tensor_type.shape.dim[idx].dim_param = "unk" else: for vi in g.value_info: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED From 023cde9861cf02b077056ccf3c007ada3243f638 Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Sun, 11 Jan 2026 13:50:18 +0200 Subject: [PATCH 08/11] utility function infer_types Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/autocast/convert.py | 10 +---- modelopt/onnx/autocast/precisionconverter.py | 40 ++++++++++--------- 
modelopt/onnx/utils.py | 27 ++++++++++++- .../onnx/autocast/test_precisionconverter.py | 5 +-- 4 files changed, 50 insertions(+), 32 deletions(-) diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py index 50f2a4348..73d2bea4d 100644 --- a/modelopt/onnx/autocast/convert.py +++ b/modelopt/onnx/autocast/convert.py @@ -136,10 +136,7 @@ def convert_to_mixed_precision( model = graph_sanitizer.model # Setup internal mappings - if use_standalone_type_inference: - model = onnx_utils.infer_types(model) - else: - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) # Automatically add 'trt' to list of providers if custom ops are detected @@ -237,10 +234,7 @@ def convert_to_f16( model = sanitizer.model # Setup internal mappings - if use_standalone_type_inference: - model = onnx_utils.infer_types(model) - else: - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) precision_converter = PrecisionConverter( diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 4eddbcd64..3a97874a2 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -160,20 +160,6 @@ def __init__( self._warned_values_clamp_max = False self._warned_values_clamp_min = False - def infer_types(self, **kwargs): - """Infers types (and optionally shapes) based on the use_standalone_type_inference flag. - - Args: - **kwargs: Additional arguments passed to infer_shapes when not using local type inference. - - Returns: - onnx.ModelProto: Model with inferred types (and shapes if not using local type inference). 
- """ - if self.use_standalone_type_inference: - return onnx_utils.infer_types(self.model) - else: - return onnx_utils.infer_shapes(self.model, **kwargs) - def convert( self, high_precision_nodes: list[str], @@ -271,10 +257,14 @@ def convert( # Clear type/shape information for intermediates and outputs (including subgraphs) self._clear_types_and_shapes_recursive(self.model.graph) # Populate type information with inferred types - self.model = self.infer_types(strict_mode=True, check_type=False) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=False + ) self._ensure_types_are_defined() # Sanity check: Verify type correctness - self.model = self.infer_types(strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=True + ) # Update value_info_map and initializer_map with casts we added self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( @@ -1195,9 +1185,16 @@ def _remove_redundant_casts(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = self.infer_types(strict_mode=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True + ) if not self.use_standalone_type_inference: - self.model = self.infer_types(strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) nodes_to_remove = [] for node in self.model.graph.node: @@ -1282,7 +1279,12 @@ def _fix_network_output_names(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = self.infer_types(strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( self.model ) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index d7597e46b..49970b302 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -728,9 +728,11 @@ def get_attribute(node: onnx.NodeProto, attr_name: str) -> Any: raise ValueError(f"Attribute {attr_name} not found in node {node.name}") -def infer_types(model: onnx.ModelProto) -> onnx.ModelProto: +def _infer_types_only(model: onnx.ModelProto) -> onnx.ModelProto: """Infers types (but not shapes) of the onnx graph using local implementation. + This is an internal function. Use infer_types() as the public API. + This is a workaround for cases when ONNX's shape inference fails. ONNX's infer_shapes performs both shape and type inference together, but for AutoCast, we only need type inference. @@ -1157,6 +1159,29 @@ def infer_shapes(model: onnx.ModelProto, **kwargs): return onnx.shape_inference.infer_shapes(model, **kwargs) +def infer_types( + model: onnx.ModelProto, use_standalone_type_inference: bool = False, **kwargs +) -> onnx.ModelProto: + """Infers types (and optionally shapes) based on the use_standalone_type_inference flag. + + When use_standalone_type_inference is True, uses a standalone type inference implementation + that only infers types. Otherwise, uses ONNX's infer_shapes which infers both types and shapes. + + Args: + model: ONNX model to infer types/shapes for. 
+ use_standalone_type_inference: If True, use standalone type inference (_infer_types_only). + If False, use ONNX's shape inference (infer_shapes). + **kwargs: Additional arguments passed to infer_shapes when not using standalone type inference. + + Returns: + onnx.ModelProto: Model with inferred types (and shapes if not using standalone type inference). + """ + if use_standalone_type_inference: + return _infer_types_only(model) + else: + return infer_shapes(model, **kwargs) + + def onnx_type_str_to_enum(dtype: str) -> int: """Converts ONNX type in string format to onnx.TensorProto format. diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index bf6fcd4f7..e8777b438 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -36,10 +36,7 @@ def setup_mappings( model: onnx.ModelProto, use_standalone_type_inference: bool = False ) -> tuple[onnx.ModelProto, dict, dict, dict]: # Setup internal mappings - if use_standalone_type_inference: - model = onnx_utils.infer_types(model) - else: - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map From d8472b6a5f8a75925754089d868e64f9bed49b39 Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Sun, 11 Jan 2026 14:53:45 +0200 Subject: [PATCH 09/11] address CR for code cleanup Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/utils.py | 69 +++++------------------------------------- 1 file changed, 7 insertions(+), 62 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 49970b302..ca14251f7 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -769,13 +769,16 @@ def infer_types_for_graph( try: gs_graph = gs.import_onnx(temp_model) gs_graph.toposort() + # Convert back to ONNX to get topologically sorted nodes + sorted_model = gs.export_onnx(gs_graph) + sorted_graph = sorted_model.graph except Exception as e: logger.debug( f"Graphsurgeon toposort failed for {'subgraph' if is_subgraph else 'main graph'}," f"using original order: {e!s}" ) # Fallback: process nodes in original order - gs_graph = None + sorted_graph = graph # Create mappings for quick lookup for this graph initializer_map = {init.name: init for init in graph.initializer} @@ -803,66 +806,8 @@ def get_tensor_type(tensor_name: str) -> int | None: return vi.type.tensor_type.elem_type return None - def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: - """Converts a string representation of a tensor dtype to an onnx.TensorProto.DataType.""" - _str_to_tensor_dtype = { - "float": onnx.TensorProto.FLOAT, - "uint8": onnx.TensorProto.UINT8, - "int8": onnx.TensorProto.INT8, - "uint16": onnx.TensorProto.UINT16, - "int16": onnx.TensorProto.INT16, - "int32": onnx.TensorProto.INT32, - "int64": onnx.TensorProto.INT64, - "string": onnx.TensorProto.STRING, - "bool": onnx.TensorProto.BOOL, - "float16": onnx.TensorProto.FLOAT16, - "double": onnx.TensorProto.DOUBLE, - "uint32": onnx.TensorProto.UINT32, - "uint64": onnx.TensorProto.UINT64, - "complex64": onnx.TensorProto.COMPLEX64, - "complex128": onnx.TensorProto.COMPLEX128, - "bfloat16": onnx.TensorProto.BFLOAT16, - "float8e4m3fn": onnx.TensorProto.FLOAT8E4M3FN, - "float8e4m3fnuz": 
onnx.TensorProto.FLOAT8E4M3FNUZ, - "float8e5m2": onnx.TensorProto.FLOAT8E5M2, - "float8e5m2fnuz": onnx.TensorProto.FLOAT8E5M2FNUZ, - "uint4": onnx.TensorProto.UINT4, - "int4": onnx.TensorProto.INT4, - "float4e2m1": onnx.TensorProto.FLOAT4E2M1, - "float8e8m0": onnx.TensorProto.FLOAT8E8M0, - } - try: - str_sanitized = dtype_str.replace("tensor(", "").replace(")", "") - return _str_to_tensor_dtype[str_sanitized] - except KeyError: - raise ValueError(f"Invalid tensor dtype string: {str_sanitized}") - - # Create mapping from node name to ONNX node for efficient lookup - node_name_to_onnx = {node.name: node for node in graph.node} - - # Get nodes to process (from graphsurgeon if available, otherwise from graph directly) - if gs_graph is not None: - nodes_to_process = gs_graph.nodes - else: - nodes_to_process = graph.node - # Process nodes in topological order (single pass) - for gs_node_or_onnx_node in nodes_to_process: - # Get corresponding ONNX node - if gs_graph is not None: - # From graphsurgeon - node = node_name_to_onnx.get(gs_node_or_onnx_node.name) - else: - # Direct from graph - node = gs_node_or_onnx_node - - if node is None: - if gs_graph is not None: - logger.debug( - f"Could not find ONNX node for graphsurgeon node: {gs_node_or_onnx_node.name}" - ) - continue - + for node in sorted_graph.node: # Get input types for this node input_types = [] for inp_name in node.input: @@ -1029,7 +974,7 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: for output_idx in range(len(node.output)): # explicit type is set in schema, use it if "tensor" in output_schemas[output_idx]: - found_type = str_to_tensor_dtype(output_schemas[output_idx]) + found_type = onnx_type_str_to_enum(output_schemas[output_idx]) output_types[output_idx] = found_type continue # sometimes output type is set with a placeholder name despite supporting a single type @@ -1038,7 +983,7 @@ def str_to_tensor_dtype(dtype_str: str) -> onnx.TensorProto.DataType: # If output type constraint has only one allowed type, use it directly if constraint.type_param_str == output_schemas[output_idx]: if len(constraint.allowed_type_strs) == 1: - found_type = str_to_tensor_dtype( + found_type = onnx_type_str_to_enum( constraint.allowed_type_strs[0] ) output_types[output_idx] = found_type From c1192bc629f9bf83a4211d2578291132863a137c Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Sun, 11 Jan 2026 15:56:54 +0200 Subject: [PATCH 10/11] minor refactoring (cr fixes) Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- modelopt/onnx/utils.py | 12 ++++++------ tests/unit/onnx/autocast/test_precisionconverter.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index ca14251f7..02306792a 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -797,13 +797,12 @@ def infer_types_for_graph( tensor_types[init_name] = init.data_type # Helper function to get tensor type - def get_tensor_type(tensor_name: str) -> int | None: + def get_tensor_type_from_name(tensor_name: str) -> int | None: if tensor_name in tensor_types: return tensor_types[tensor_name] if tensor_name in value_info_map: vi = value_info_map[tensor_name] - if vi.type.HasField("tensor_type"): - return vi.type.tensor_type.elem_type + return _get_tensor_type(vi) return None # Process nodes in topological order (single pass) @@ -811,9 +810,10 @@ def get_tensor_type(tensor_name: str) -> int | None: # Get input types for this node 
input_types = [] for inp_name in node.input: + # an empty tensor name is typically a sign of an optional input, skip it if not inp_name: continue - inp_type = get_tensor_type(inp_name) + inp_type = get_tensor_type_from_name(inp_name) if inp_type is None: raise ValueError(f"Input {inp_name} of node {node.name} has unknown type") input_types.append(inp_type) @@ -844,7 +844,7 @@ def get_tensor_type(tensor_name: str) -> int | None: if output_dtype is not None: output_types = [output_dtype] elif len(node.input) >= 2 and node.input[1]: - scale_type = get_tensor_type(node.input[1]) + scale_type = get_tensor_type_from_name(node.input[1]) if scale_type is not None: output_types = [scale_type] else: @@ -866,7 +866,7 @@ def get_tensor_type(tensor_name: str) -> int | None: if output_dtype is not None: output_types = [output_dtype] * len(node.output) elif len(node.input) >= 3 and node.input[2]: - zero_point_type = get_tensor_type(node.input[2]) + zero_point_type = get_tensor_type_from_name(node.input[2]) if zero_point_type is not None: output_types = [zero_point_type] else: diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index e8777b438..a14991319 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -622,7 +622,7 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes, use_standalone_type use_standalone_type_inference=use_standalone_type_inference, ) high_precision_nodes = ["matmul"] - low_precision_nodes = ["transpose", "concat", "size", "div", "concat_dims", "reshape"] + low_precision_nodes = ["transpose", "concat1", "size", "div", "concat2", "reshape"] converted_model = converter2.convert(high_precision_nodes, low_precision_nodes) # Verify model is valid From 9656fcaec5b1f490b487e5dc4d4111d9018cc52a Mon Sep 17 00:00:00 2001 From: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:49:32 +0200 Subject: [PATCH 11/11] update changelog - push to 0.42 Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- CHANGELOG.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 75840754d..e95055a04 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ NVIDIA Model Optimizer Changelog (Linux) ======================================== +0.42 (TBD) +^^^^^^^^^^^^^^^^^ + +**Bug Fixes** + +**New Features** +- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. + 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ @@ -18,7 +26,6 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for parallel draft heads in Eagle speculative decoding. - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend `` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``. - Add ``examples/llm_qad`` for QAD training with Megatron-LM. -- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. 
This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. **Deprecations**
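
A minimal usage sketch of the ``infer_types`` utility added in this series, assuming the signature introduced in ``modelopt/onnx/utils.py`` above (the model path and the ``strict_mode`` keyword are placeholders; extra keywords are simply forwarded to ONNX's ``infer_shapes`` on the default path):

.. code-block:: python

    import onnx

    from modelopt.onnx import utils as onnx_utils

    model = onnx.load("model.onnx")

    # Experimental path: standalone, type-only inference (no shape inference).
    model = onnx_utils.infer_types(model, use_standalone_type_inference=True)

    # Default path: delegates to ONNX's infer_shapes (types and shapes);
    # additional keyword arguments are forwarded to infer_shapes.
    model = onnx_utils.infer_types(model, use_standalone_type_inference=False, strict_mode=True)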