diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 978ac209d..e95055a04 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ NVIDIA Model Optimizer Changelog (Linux) ======================================== +0.42 (TBD) +^^^^^^^^^^^^^^^^^ + +**Bug Fixes** + +**New Features** +- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. + 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/guides/8_autocast.rst b/docs/source/guides/8_autocast.rst index 4ad39e969..0701f2f1f 100644 --- a/docs/source/guides/8_autocast.rst +++ b/docs/source/guides/8_autocast.rst @@ -42,6 +42,7 @@ AutoCast can also be used programmatically through its Python API: trt_plugins=[], # list of TensorRT plugin library paths in .so format max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16) + use_standalone_type_inference=False, # use standalone type inference instead of ONNX's infer_shapes (WAR) ) # Save the converted model @@ -82,6 +83,9 @@ AutoCast follows these steps to convert a model: - Converts eligible nodes to lower precision - Automatically inserts necessary cast operations - Automatically replaces initializers with lower precision values + - Performs type inference to propagate types through the graph + - By default, uses ONNX's ``infer_shapes`` which performs both shape and type inference using the ONNX infer_shapes API. + - Use ``use_standalone_type_inference=True`` to use a standalone type-only inference implementation (experimental). #. **Validation and Export**: @@ -145,6 +149,14 @@ Best Practices - A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues. - The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19). +#. **Type Inference Control** + + - By default, AutoCast uses ONNX's ``infer_shapes`` which performs both shape and type inference. + - Use ``--use_standalone_type_inference`` to enable a standalone type-only inference implementation. + - This is a workaround for cases where shape inference fails for any reason, which allows us to bypass the dependency in ONNX's shape inference logic. + - The standalone implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape. + - Note: The standalone type inference may be less robust than ONNX's implementation for edge cases, but avoids unnecessary shape inference overhead and possible failures. + Limitations and Restrictions ---------------------------- - AutoCast does not yet support quantized models. @@ -198,3 +210,9 @@ Convert to BF16 with a specific opset: .. code-block:: bash python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22 + +Use standalone type inference instead of ONNX's infer_shapes: + +.. 
code-block:: bash + + python -m modelopt.onnx.autocast --onnx_path model.onnx --use_standalone_type_inference diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py index cabeff733..da521d524 100644 --- a/modelopt/onnx/autocast/__main__.py +++ b/modelopt/onnx/autocast/__main__.py @@ -185,6 +185,16 @@ def get_parser() -> argparse.ArgumentParser: "higher version." ), ) + parser.add_argument( + "--use_standalone_type_inference", + action="store_true", + default=False, + help=( + "Use local type inference implementation instead of ONNX's infer_shapes (experimental)." + "This is a workaround for cases where shape inference fails for any reason." + "Default: False (uses ONNX's infer_shapes which does both shape and type inference)." + ), + ) return parser @@ -218,6 +228,7 @@ def main(argv=None): trt_plugins_precision=args.trt_plugins_precision, max_depth_of_reduction=args.max_depth_of_reduction, opset=args.opset, + use_standalone_type_inference=args.use_standalone_type_inference, ) output_path = args.output_path diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py index 4328c9fc2..73d2bea4d 100644 --- a/modelopt/onnx/autocast/convert.py +++ b/modelopt/onnx/autocast/convert.py @@ -61,6 +61,7 @@ def convert_to_mixed_precision( trt_plugins_precision: list[str] = [], max_depth_of_reduction: int | None = None, opset: int | None = None, + use_standalone_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision. @@ -85,6 +86,9 @@ def convert_to_mixed_precision( opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations require a higher version. + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. Returns: onnx.ModelProto: The converted mixed precision model. @@ -132,7 +136,7 @@ def convert_to_mixed_precision( model = graph_sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) # Automatically add 'trt' to list of providers if custom ops are detected @@ -164,6 +168,7 @@ def convert_to_mixed_precision( low_precision_type=low_precision_type, init_conversion_max_bytes=init_conversion_max_bytes, custom_ops=graph_sanitizer.custom_ops, + use_standalone_type_inference=use_standalone_type_inference, ) # Obtain reference data @@ -196,6 +201,7 @@ def convert_to_f16( op_block_list: list[str] = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, trt_plugins: list[str] | None = [], + use_standalone_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision, using PrecisionConverter. @@ -208,6 +214,9 @@ def convert_to_f16( op_block_list: List of operation types that should remain in FP32. tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library). + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. 
""" assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16" @@ -225,7 +234,7 @@ def convert_to_f16( model = sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) precision_converter = PrecisionConverter( @@ -237,6 +246,7 @@ def convert_to_f16( low_precision_type=low_precision_type, custom_ops=sanitizer.custom_ops, tensor_block_dict=tensor_block_dict, + use_standalone_type_inference=use_standalone_type_inference, ) high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list] low_precision_nodes = [ diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index eae589d8a..3a97874a2 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -97,6 +97,7 @@ def __init__( max_ir_version: int | None = None, trt_plugins: list[str] | None = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, + use_standalone_type_inference: bool = False, ) -> None: """Initialize PrecisionConverter. @@ -114,6 +115,7 @@ def __init__( max_ir_version: Max IR version for conversion. trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library). tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. + use_standalone_type_inference: Use standalone type inference instead of ONNX's infer_shapes. """ self.model = deepcopy(model) self.value_info_map = value_info_map @@ -140,6 +142,7 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins + self.use_standalone_type_inference = use_standalone_type_inference # Detect additional ops not supported in low precision according to the model's opset version self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( @@ -254,10 +257,14 @@ def convert( # Clear type/shape information for intermediates and outputs (including subgraphs) self._clear_types_and_shapes_recursive(self.model.graph) # Populate type information with inferred types - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=False + ) self._ensure_types_are_defined() # Sanity check: Verify type correctness - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=True + ) # Update value_info_map and initializer_map with casts we added self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( @@ -282,9 +289,9 @@ def _clear_types_and_shapes_recursive( ) -> None: """Recursively clear type/shape information for a graph and all its subgraphs. - This is necessary for control flow operators (Scan, If, Loop) which have subgraphs. - For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph). - For main graph, clear all value_info. + If use_standalone_type_inference is True, we clear only types, not shapes. + For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated + from the main graph. 
Args: graph: The ONNX graph to clear types and shapes for. @@ -301,9 +308,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for inp in g.input: if inp.type.HasField("tensor_type"): inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(inp.type.tensor_type.shape.dim): - if d.dim_value: - inp.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(inp.type.tensor_type.shape.dim): + if d.dim_value: + inp.type.tensor_type.shape.dim[idx].dim_param = "unk" if is_sub: # Identify which tensors are produced by nodes in this subgraph @@ -315,9 +323,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for vi in g.value_info: if vi.name in subgraph_outputs: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(vi.type.tensor_type.shape.dim): + if d.dim_value: + vi.type.tensor_type.shape.dim[idx].dim_param = "unk" else: for vi in g.value_info: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED @@ -328,9 +337,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> # Clear outputs for both main graph and subgraphs for out in g.output: out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(out.type.tensor_type.shape.dim): - if d.dim_value: - out.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(out.type.tensor_type.shape.dim): + if d.dim_value: + out.type.tensor_type.shape.dim[idx].dim_param = "unk" utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph) @@ -1175,8 +1185,16 @@ def _remove_redundant_casts(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True) - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True + ) + if not self.use_standalone_type_inference: + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) nodes_to_remove = [] for node in self.model.graph.node: @@ -1261,7 +1279,12 @@ def _fix_network_output_names(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( self.model ) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index a6b37758e..02306792a 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -728,6 +728,366 @@ def get_attribute(node: onnx.NodeProto, attr_name: str) -> Any: raise ValueError(f"Attribute {attr_name} not found in node {node.name}") +def _infer_types_only(model: onnx.ModelProto) -> onnx.ModelProto: + """Infers types (but not shapes) of the onnx graph using local implementation. + + This is an internal function. Use infer_types() as the public API. 
+ + This is a workaround for cases when ONNX's shape inference fails. + ONNX's infer_shapes performs both shape and type inference together, but for AutoCast, we only + need type inference. + + Args: + model: ONNX model to infer types for. + + Returns: + onnx.ModelProto: Model with inferred types updated in value_info and outputs. + """ + from modelopt.onnx.autocast import utils as autocast_utils + + # Get opset version + opset = get_opset_version(model) + + # Process each graph (main graph and all subgraphs) recursively + def infer_types_for_graph( + graph: onnx.GraphProto, parent_node: onnx.NodeProto = None, is_subgraph: bool = False + ) -> None: + """Infer types for a single graph (main or subgraph). + + Args: + graph: The graph to infer types for. + parent_node: The parent node containing this subgraph (None for main graph). + is_subgraph: Whether this is a subgraph (True) or the main graph (False). + """ + # Use graphsurgeon to topologically sort nodes for efficient single-pass traversal + # Create a temporary model with just this graph for graphsurgeon + temp_model = onnx.ModelProto() + temp_model.graph.CopyFrom(graph) + temp_model.opset_import.add().version = opset + temp_model.ir_version = model.ir_version + + try: + gs_graph = gs.import_onnx(temp_model) + gs_graph.toposort() + # Convert back to ONNX to get topologically sorted nodes + sorted_model = gs.export_onnx(gs_graph) + sorted_graph = sorted_model.graph + except Exception as e: + logger.debug( + f"Graphsurgeon toposort failed for {'subgraph' if is_subgraph else 'main graph'}," + f"using original order: {e!s}" + ) + # Fallback: process nodes in original order + sorted_graph = graph + + # Create mappings for quick lookup for this graph + initializer_map = {init.name: init for init in graph.initializer} + value_info_map = {vi.name: vi for vi in graph.value_info} + output_names = {out.name for out in graph.output} + + # Map tensor names to their inferred types (scoped to this graph) + tensor_types = {} + + # Initialize types from inputs and initializers + for inp in graph.input: + if inp.type.HasField("tensor_type"): + tensor_types[inp.name] = inp.type.tensor_type.elem_type + + for init_name, init in initializer_map.items(): + tensor_types[init_name] = init.data_type + + # Helper function to get tensor type + def get_tensor_type_from_name(tensor_name: str) -> int | None: + if tensor_name in tensor_types: + return tensor_types[tensor_name] + if tensor_name in value_info_map: + vi = value_info_map[tensor_name] + return _get_tensor_type(vi) + return None + + # Process nodes in topological order (single pass) + for node in sorted_graph.node: + # Get input types for this node + input_types = [] + for inp_name in node.input: + # an empty tensor name is typically a sign of an optional input, skip it + if not inp_name: + continue + inp_type = get_tensor_type_from_name(inp_name) + if inp_type is None: + raise ValueError(f"Input {inp_name} of node {node.name} has unknown type") + input_types.append(inp_type) + + # Infer output types for this node + output_types = [] + + if node.op_type == "Cast": + # Cast node: output type is the 'to' attribute + cast_to_type = None + for attr in node.attribute: + if attr.name == "to": + cast_to_type = attr.i + break + if cast_to_type is None: + raise ValueError(f"Cast node {node.name} has unknown target type") + output_types = [cast_to_type] + elif node.op_type == "DequantizeLinear": + # DequantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the scale type 
(input[1]) + # inputs: [data, scale, zero_point (optional)] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] + elif len(node.input) >= 2 and node.input[1]: + scale_type = get_tensor_type_from_name(node.input[1]) + if scale_type is not None: + output_types = [scale_type] + else: + # Fallback: use first input type or FLOAT + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] + else: + # Fallback: use first input type or FLOAT + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] + elif node.op_type == "QuantizeLinear": + # QuantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the zero_point type (input[2]) + # inputs: [data, scale, zero_point] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] * len(node.output) + elif len(node.input) >= 3 and node.input[2]: + zero_point_type = get_tensor_type_from_name(node.input[2]) + if zero_point_type is not None: + output_types = [zero_point_type] + else: + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] + else: + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] + elif node.op_type == "Constant": + # Constant: output type is from the value attribute's tensor data_type + const_type = None + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + assert const_type is not None + output_types = [const_type] + elif node.op_type == "ConstantOfShape": + # ConstantOfShape: output type is from the value attribute's tensor data_type + # If no value attribute, defaults to FLOAT + # Note: Schema allows multiple types, so we need to check the value attribute + const_type = None + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + assert const_type is not None + output_types = [const_type] + elif node.op_type == "Split": + # Split schema allows multiple outputs, but the schema only specifies one output type + output_types = [input_types[0]] * len(node.output) + else: + # Check if this node has subgraphs (GRAPH or GRAPHS attributes) + # Common nodes with subgraphs: If, Loop, Scan + subgraphs = [] + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + subgraphs.append(attr.g) + elif attr.type == onnx.AttributeProto.GRAPHS: + subgraphs.extend(attr.graphs) + + # If node has subgraphs, infer types for them first + if subgraphs: + for subgraph in subgraphs: + infer_types_for_graph(subgraph, parent_node=node, is_subgraph=True) + + # For nodes with subgraphs, try to infer output types from subgraph outputs + # This avoids incorrectly matching to control inputs (e.g., condition for If, trip_count for Loop) + output_types = [] + if len(node.output) > 0: + # Use the first subgraph as reference (works for If, Loop, Scan) + first_subgraph = subgraphs[0] + for out_idx, out_name in enumerate(node.output): + if out_idx < len(first_subgraph.output): + subgraph_out = first_subgraph.output[out_idx] + # Typically we only have one subgraph, but If nodes have 
two subgraphs + # (then_branch and else_branch). In any case, the output types of the + # subgraphs must be identical, so we check just the first one + if ( + subgraph_out.type.HasField("tensor_type") + and subgraph_out.type.tensor_type.elem_type + != onnx.TensorProto.UNDEFINED + ): + output_types.append(subgraph_out.type.tensor_type.elem_type) + else: + output_types.append(onnx.TensorProto.FLOAT) + else: + # Fallback if we can't infer from subgraphs + output_types = None + + # If we couldn't infer from subgraphs, fall through to schema-based inference + if output_types is None or len(output_types) != len(node.output): + output_types = None + else: + # No subgraphs, proceed with normal inference + output_types = None + + # If output_types not set yet, use schema-based inference + if output_types is None: + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + # Use ONNX operator schema to determine output types + try: + schema = onnx.defs.get_schema(node.op_type, opset, domain=node.domain or "") + assert schema.outputs and len(schema.outputs) >= len(node.output) + except Exception as e: + # Fallback: if schema lookup fails, propagate first input type + logger.debug( + f"Node {node.name}: Failed to get schema for {node.op_type}: {e}, " + "propagate first input type" + ) + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + output_types = [default_type] * len(node.output) + else: + # Try to infer from schema + input_schemas = [ + schema.inputs[i].type_str for i in range(len(schema.inputs)) + ] + output_schemas = [ + schema.outputs[i].type_str for i in range(len(schema.outputs)) + ] + output_types = [None] * len(node.output) + + for output_idx in range(len(node.output)): + # explicit type is set in schema, use it + if "tensor" in output_schemas[output_idx]: + found_type = onnx_type_str_to_enum(output_schemas[output_idx]) + output_types[output_idx] = found_type + continue + # sometimes output type is set with a placeholder name despite supporting a single type + # e.g. 
Shape operator is constrained to int64, but the type_str is "T1" + for constraint in schema.type_constraints: + # If output type constraint has only one allowed type, use it directly + if constraint.type_param_str == output_schemas[output_idx]: + if len(constraint.allowed_type_strs) == 1: + found_type = onnx_type_str_to_enum( + constraint.allowed_type_strs[0] + ) + output_types[output_idx] = found_type + break + else: + # We have a placeholder name "T", "T1", "T2", etc that should + # match one of the input types + try: + input_match_idx = input_schemas.index( + output_schemas[output_idx] + ) + except ValueError: + input_match_idx = None + if input_match_idx is not None: + found_type = input_types[input_match_idx] + else: + found_type = default_type + logger.debug( + f"Node {node.name}: Failed to infer type for output " + f"#{output_idx}, propagate first input type" + ) + output_types[output_idx] = found_type + + # Update output tensor types + for out_idx, out_name in enumerate(node.output): + if not out_name or out_idx >= len(output_types): + continue + + output_type = output_types[out_idx] + tensor_types[out_name] = output_type + + # Update value_info if it exists + if out_name in value_info_map: + value_info_map[out_name].type.tensor_type.elem_type = output_type + elif out_name not in output_names: + # Create new value_info for intermediate tensor + new_vi = graph.value_info.add() + new_vi.name = out_name + new_vi.type.tensor_type.elem_type = output_type + value_info_map[out_name] = new_vi + + # Update output types for this graph + for out in graph.output: + if out.name in tensor_types: + out.type.tensor_type.elem_type = tensor_types[out.name] + + # Process main graph and all subgraphs recursively + autocast_utils.walk_subgraphs_recursive(model.graph, infer_types_for_graph, is_subgraph=False) + infer_types_verification(model) + return model + + +def infer_types_verification(model: onnx.ModelProto) -> onnx.ModelProto: + """Verify that all reachable tensors have a defined type. + + This is necessary because some nodes may be removed during the inference process, + leaving unreachable value_info entries. + """ + reachable_tensors = set() + + # Add graph inputs as reachable + for inp in model.graph.input: + reachable_tensors.add(inp.name) + + # Add initializers as reachable + for init in model.graph.initializer: + reachable_tensors.add(init.name) + + # Traverse nodes to find all reachable tensor outputs + for node in model.graph.node: + # A node is reachable if any of its inputs are reachable + # (or if it has no inputs - rare but possible) + node_is_reachable = not node.input or any( + inp in reachable_tensors for inp in node.input if inp + ) + + if node_is_reachable: + # All outputs of a reachable node are reachable + for out in node.output: + if out: # Skip empty output names + reachable_tensors.add(out) + + is_undefined = False + # Check value_info for reachable tensors + for vi in model.graph.value_info: + if vi.name in reachable_tensors: + if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error( + f"Infer types verification failed. Value info {vi.name} has undefined type" + ) + is_undefined = True + + # Graph outputs should always be reachable + for out in model.graph.output: + if out.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error(f"Infer types verification failed. Output {out.name} has undefined type") + is_undefined = True + if is_undefined: + raise ValueError( + "Infer types verification failed. 
Undefined types found in the model - see logs for details." + ) + return model + + def infer_shapes(model: onnx.ModelProto, **kwargs): """Infers shapes of the onnx graph, handles large models.""" if model.ByteSize() > (2 * (1024**3)): # 2GB limit @@ -744,6 +1104,29 @@ def infer_shapes(model: onnx.ModelProto, **kwargs): return onnx.shape_inference.infer_shapes(model, **kwargs) +def infer_types( + model: onnx.ModelProto, use_standalone_type_inference: bool = False, **kwargs +) -> onnx.ModelProto: + """Infers types (and optionally shapes) based on the use_standalone_type_inference flag. + + When use_standalone_type_inference is True, uses a standalone type inference implementation + that only infers types. Otherwise, uses ONNX's infer_shapes which infers both types and shapes. + + Args: + model: ONNX model to infer types/shapes for. + use_standalone_type_inference: If True, use standalone type inference (_infer_types_only). + If False, use ONNX's shape inference (infer_shapes). + **kwargs: Additional arguments passed to infer_shapes when not using standalone type inference. + + Returns: + onnx.ModelProto: Model with inferred types (and shapes if not using standalone type inference). + """ + if use_standalone_type_inference: + return _infer_types_only(model) + else: + return infer_shapes(model, **kwargs) + + def onnx_type_str_to_enum(dtype: str) -> int: """Converts ONNX type in string format to onnx.TensorProto format. diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index 4fb02a230..a14991319 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -32,6 +32,15 @@ def low_precision_onnx_type(low_precision_type_str): return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16 +def setup_mappings( + model: onnx.ModelProto, use_standalone_type_inference: bool = False +) -> tuple[onnx.ModelProto, dict, dict, dict]: + # Setup internal mappings + model = onnx_utils.infer_types(model, use_standalone_type_inference) + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + return model, value_info_map, initializer_map, node_to_init_map + + #################################################################################################### # Testing with a basic GEMM->Add->Relu graph #################################################################################################### @@ -56,16 +65,21 @@ def simple_model(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_graph_converter_init(simple_model): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_graph_converter_init(simple_model, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=True + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + use_standalone_type_inference=use_standalone_type_inference, ) assert converter.model == model assert converter.value_info_map == value_info_map @@ -75,7 +89,10 @@ def test_graph_converter_init(simple_model): 
@pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_simple_convert(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_simple_convert( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -84,6 +101,7 @@ def test_simple_convert(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert add node to fp16, keep mul in fp32 @@ -133,7 +151,10 @@ def test_unsupported_precision_type(simple_model, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_convert_no_disabled_nodes( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -142,6 +163,7 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all nodes to fp16 @@ -167,7 +189,10 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_get_tensors_to_cast( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -176,6 +201,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when relu node is in low precision @@ -196,7 +222,10 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_keep_io_names(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_keep_io_names( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -205,6 +234,7 @@ def test_keep_io_names(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all nodes to low precision @@ -258,16 +288,16 @@ def model_with_multiple_consumers(): model.ir_version = 10 onnx.checker.check_model(model) - 
model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_convert_with_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -277,6 +307,7 @@ def test_convert_with_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Only gemm1 and add1 are converted to fp32, gemm2 and add2 are fp16 @@ -300,8 +331,9 @@ def test_convert_with_multiple_consumers( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_get_tensors_to_cast_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -311,6 +343,7 @@ def test_get_tensors_to_cast_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when gemm2 and add1 nodes are in low precision @@ -327,7 +360,10 @@ def test_get_tensors_to_cast_multiple_consumers( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_initializers(model_with_multiple_consumers, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_convert_initializers( + model_with_multiple_consumers, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( model, @@ -335,6 +371,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test successful cast, add1 and add2 share add_init and operate in different precisions @@ -361,6 +398,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter2.model.graph.node if n.name == "add1") add2_node = next(n for n in converter2.model.graph.node if n.name == "add2") @@ -384,6 +422,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter3.model.graph.node if n.name == "add1") add2_node = next(n for n in 
converter3.model.graph.node if n.name == "add2") @@ -404,7 +443,10 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) assert f"add_init_{low_precision_type}" in init_names -def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_clamping_fp16_initializers_out_of_range( + model_with_multiple_consumers, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, node is converted to FP16 @@ -412,7 +454,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): add_init = numpy_helper.from_array(add_init_out_of_range, name="add_init") model.graph.initializer[1].CopyFrom(add_init) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) # Verify initializer is clamped @@ -427,7 +475,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert add_init_converted_array[0, 1] == np.finfo(np.float16).max # Initializer is out of FP16 range, node is kept in FP32 - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter2._convert_initializers(low_precision_nodes=[], high_precision_nodes=["add1", "add2"]) # Verify initializer is not clamped @@ -441,7 +495,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_converted_array == add_init_out_of_range) # Initializer is out of FP16 range, one consumer is converted to FP16, the other is kept in FP32 - converter3 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter3 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter3._convert_initializers(low_precision_nodes=["add1"], high_precision_nodes=["add2"]) # Verify initializer is duplicated, and the FP16 copy is clamped @@ -462,7 +522,10 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_fp32_array == add_init_out_of_range) -def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_bf16_no_clamping_initializers_out_of_range( + model_with_multiple_consumers, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, but that does not affect BF16 conversion @@ -476,6 +539,7 @@ def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumer initializer_map, node_to_init_map, low_precision_type="bf16", + use_standalone_type_inference=use_standalone_type_inference, ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) @@ -511,13 +575,13 @@ def model_with_dynamic_shapes(): matmul_node = 
helper.make_node("MatMul", ["X", "weight"], ["matmul_out"], name="matmul") transpose_node = helper.make_node("Transpose", ["Y"], ["transpose_out"], name="transpose") concat_node = helper.make_node( - "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat", axis=0 + "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat1", axis=0 ) size_y = helper.make_node("Size", ["concat_out"], ["total_size"], name="size") const_4 = numpy_helper.from_array(np.array([4], dtype=np.int64), name="const_4") first_dim = helper.make_node("Div", ["total_size", "const_4"], ["first_dim"], name="div") concat_dims_node = helper.make_node( - "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat", axis=0 + "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat2", axis=0 ) reshape_node = helper.make_node("Reshape", ["concat_out", "final_shape"], ["Z"], name="reshape") @@ -540,20 +604,25 @@ def model_with_dynamic_shapes(): model = helper.make_model(graph, producer_name="model_dynamic") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_dynamic_model_conversion(model_with_dynamic_shapes): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_dynamic_model_conversion(model_with_dynamic_shapes, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = model_with_dynamic_shapes # Test mixed precision conversion - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) high_precision_nodes = ["matmul"] - low_precision_nodes = ["transpose", "concat", "size", "div", "concat_dims", "reshape"] + low_precision_nodes = ["transpose", "concat1", "size", "div", "concat2", "reshape"] converted_model = converter2.convert(high_precision_nodes, low_precision_nodes) # Verify model is valid @@ -563,7 +632,8 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes): #################################################################################################### # Cast cleanup logic #################################################################################################### -def test_cast_output_pattern(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern(use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -583,10 +653,14 @@ def test_cast_output_pattern(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + 
use_standalone_type_inference=use_standalone_type_inference, + ) # Setting all nodes to FP16 means that the final graph should have no cast nodes converted_model = converter.convert( @@ -602,7 +676,8 @@ def test_cast_output_pattern(): assert converted_model.graph.output[i].name == model.graph.output[i].name -def test_cast_output_pattern_mixed_precision(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern_mixed_precision(use_standalone_type_inference): x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [3, 4]) x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [3, 4]) y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [3, 4]) @@ -625,10 +700,14 @@ def test_cast_output_pattern_mixed_precision(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) # Network output Y0 has two consumers, one is FP16 and the other is FP32 converted_model = converter.convert( @@ -641,7 +720,8 @@ def test_cast_output_pattern_mixed_precision(): @pytest.mark.parametrize("keep_io_types", [True, False]) -def test_chain_of_casts_pattern(keep_io_types): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_chain_of_casts_pattern(keep_io_types, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4]) @@ -690,11 +770,14 @@ def test_chain_of_casts_pattern(keep_io_types): model = helper.make_model(graph, producer_name="model_cast_chain") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=keep_io_types + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=keep_io_types, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -705,7 +788,8 @@ def test_chain_of_casts_pattern(keep_io_types): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_existing_low_precision_output(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_existing_low_precision_output(low_precision_type, use_standalone_type_inference): # Create a simple model with FP16 output x = helper.make_tensor_value_info("X", low_precision_onnx_type(low_precision_type), [3, 4]) y = helper.make_tensor_value_info("Y", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -715,8 +799,7 @@ def test_existing_low_precision_output(low_precision_type): model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, 
value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -725,6 +808,7 @@ def test_existing_low_precision_output(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -743,7 +827,8 @@ def test_existing_low_precision_output(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_output_cast_output_pattern(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_output_cast_output_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -764,9 +849,8 @@ def test_output_cast_output_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_output_cast_output") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -774,6 +858,7 @@ def test_output_cast_output_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Setting nodes precision to match I/O type means that the final graph should have no cast nodes @@ -790,7 +875,8 @@ def test_output_cast_output_pattern(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_cast_output_keep_io_types_pattern(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_keep_io_types_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -809,9 +895,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_cast_output_keep_io_types") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -819,6 +903,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=[], low_precision_nodes=["add1", "add2"]) @@ -827,7 +912,8 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): assert converter.model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT -def test_unsupported_op_types_model(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_unsupported_op_types_model(use_standalone_type_inference): x = helper.make_tensor_value_info("X", 
TensorProto.FLOAT, [3, 4]) roi = helper.make_tensor_value_info("roi", TensorProto.FLOAT, [3, 4]) scales = helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4]) @@ -848,17 +934,24 @@ def test_unsupported_op_types_model(): [], ) model = helper.make_model(graph, producer_name="model_celu") - model = onnx.shape_inference.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter.convert(high_precision_nodes=[], low_precision_nodes=["celu", "resize", "nms"]) onnx.checker.check_model(converter.model) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("empty_tensor_target", ["low_precision", "high_precision"]) -def test_empty_tensor_handling(low_precision_type, empty_tensor_target): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_empty_tensor_handling( + low_precision_type, empty_tensor_target, use_standalone_type_inference +): """Test empty tensor handling for both low and high precision node targets.""" # Create model with empty float tensor from Constant layer x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]) @@ -888,8 +981,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -897,6 +989,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test empty tensor detection @@ -979,14 +1072,16 @@ def model_with_constant_cast_patterns(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_constant_cast_folding( + model_with_constant_cast_patterns, low_precision_type, use_standalone_type_inference +): """Test constant->cast folding as part of the full conversion process.""" model, value_info_map, initializer_map, node_to_init_map = model_with_constant_cast_patterns @@ -997,6 +1092,7 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_ node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert with some nodes in low precision to trigger cast insertion @@ -1077,15 +1173,17 @@ def model_with_multiple_output_node_casted_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = 
onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_multiple_output_node_casted_to_output( - model_with_multiple_output_node_casted_to_output, low_precision_type + model_with_multiple_output_node_casted_to_output, + low_precision_type, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( model_with_multiple_output_node_casted_to_output @@ -1098,6 +1196,7 @@ def test_multiple_output_node_casted_to_output( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"] @@ -1145,16 +1244,19 @@ def model_with_casted_input_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("keep_io_types", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_casted_input_to_output_model( - model_with_casted_input_to_output, low_precision_type, keep_io_types + model_with_casted_input_to_output, + low_precision_type, + keep_io_types, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output @@ -1168,6 +1270,7 @@ def test_casted_input_to_output_model( min_opset=22 if low_precision_type == "bf16" else 13, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, trt_plugins=[], + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"] @@ -1218,8 +1321,7 @@ def create_model_with_resize_op(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @@ -1276,16 +1378,16 @@ def create_model_with_resize_op_tensor_scales(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_initializer_conversion( - create_model_with_resize_op, keep_io_types, low_precision_type + create_model_with_resize_op, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = 
create_model_with_resize_op @@ -1296,6 +1398,7 @@ def test_resize_op_initializer_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1305,8 +1408,12 @@ def test_resize_op_initializer_conversion( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_tensor_scales_conversion( - create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type + create_model_with_resize_op_tensor_scales, + keep_io_types, + low_precision_type, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( create_model_with_resize_op_tensor_scales @@ -1319,6 +1426,7 @@ def test_resize_op_tensor_scales_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1409,15 +1517,15 @@ def model_with_if_subgraph(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("if_precision", ["low", "high"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_if_subgraph_initializer_conversion( - model_with_if_subgraph, low_precision_type, if_precision + model_with_if_subgraph, low_precision_type, if_precision, use_standalone_type_inference ): """Test that initializers in If subgraphs are converted based on parent node precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1429,6 +1537,7 @@ def test_if_subgraph_initializer_conversion( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Classify the If node based on test parameter @@ -1482,7 +1591,10 @@ def test_if_subgraph_initializer_conversion( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_if_subgraph_mixed_precision_boundary( + model_with_if_subgraph, low_precision_type, use_standalone_type_inference +): """Test that types are correctly handled at If subgraph boundaries in mixed precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1498,7 +1610,7 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis model.graph.output.append(output_tensor) # Refresh mappings - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -1507,6 +1619,7 @@ def 
test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # If in low precision, Add in high precision
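
A quick end-to-end illustration of the new flag from Python rather than the CLI (a minimal sketch: the ``onnx_path`` keyword and the ``convert.convert_to_mixed_precision`` import path are assumed from the module layout and CLI flags in this diff, not confirmed signatures):

    import onnx

    from modelopt.onnx import utils as onnx_utils
    from modelopt.onnx.autocast.convert import convert_to_mixed_precision

    # High-level conversion, mirroring the CLI example in the docs above.
    # NOTE: "onnx_path" is assumed here to mirror the CLI --onnx_path flag; check the
    # full convert_to_mixed_precision signature before relying on it.
    converted = convert_to_mixed_precision(
        onnx_path="model.onnx",
        low_precision_type="fp16",
        use_standalone_type_inference=True,  # new experimental type-only inference WAR
    )
    onnx.save(converted, "model_fp16.onnx")

    # Lower-level switch: infer_types() dispatches to the standalone type-only pass when
    # the flag is True, and falls back to ONNX's infer_shapes() (types + shapes) otherwise.
    model = onnx.load("model.onnx")
    model = onnx_utils.infer_types(model, use_standalone_type_inference=True)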