8 changes: 8 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,14 @@
NVIDIA Model Optimizer Changelog (Linux)
========================================

0.42 (TBD)
^^^^^^^^^^^^^^^^^

**Bug Fixes**

**New Features**
- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^

18 changes: 18 additions & 0 deletions docs/source/guides/8_autocast.rst
@@ -42,6 +42,7 @@ AutoCast can also be used programmatically through its Python API:
trt_plugins=[], # list of TensorRT plugin library paths in .so format
max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision
opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16)
use_standalone_type_inference=False, # use standalone type inference instead of ONNX's infer_shapes (WAR)
)

# Save the converted model
@@ -82,6 +83,9 @@ AutoCast follows these steps to convert a model:
- Converts eligible nodes to lower precision
- Automatically inserts necessary cast operations
- Automatically replaces initializers with lower precision values
- Performs type inference to propagate types through the graph
- By default, uses ONNX's ``infer_shapes``, which performs both shape and type inference.
- Use ``use_standalone_type_inference=True`` to run a standalone, type-only inference implementation instead (experimental).

#. **Validation and Export**:

@@ -145,6 +149,14 @@ Best Practices
- A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues.
- The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19).

#. **Type Inference Control**

- By default, AutoCast uses ONNX's ``infer_shapes``, which performs both shape and type inference.
- Use ``--use_standalone_type_inference`` to enable a standalone, type-only inference implementation (see the example below).
- This is a workaround for cases where shape inference fails, since it removes the dependency on ONNX's shape inference logic.
- The standalone implementation uses graphsurgeon for topological sorting and handles special operators such as ``Cast``, ``QuantizeLinear``, ``DequantizeLinear``, ``Constant``, and ``ConstantOfShape``.
- Note: the standalone type inference may be less robust than ONNX's implementation for edge cases, but it avoids unnecessary shape inference overhead and possible failures.
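
The same option is available from the Python API. Here is a minimal sketch, assuming ``convert_to_mixed_precision`` is imported from ``modelopt.onnx.autocast.convert`` and accepts the model path as ``onnx_path`` (file names are placeholders; other arguments keep their defaults):

.. code-block:: python

    import onnx

    from modelopt.onnx.autocast.convert import convert_to_mixed_precision

    # Convert to FP16, replacing ONNX's combined shape/type inference with the
    # experimental standalone, type-only inference pass.
    converted_model = convert_to_mixed_precision(
        onnx_path="model.onnx",
        low_precision_type="fp16",
        use_standalone_type_inference=True,
    )

    onnx.save(converted_model, "model.fp16.onnx")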

Limitations and Restrictions
----------------------------
- AutoCast does not yet support quantized models.
@@ -198,3 +210,9 @@ Convert to BF16 with a specific opset:
.. code-block:: bash

python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22

Use standalone type inference instead of ONNX's infer_shapes:

.. code-block:: bash

python -m modelopt.onnx.autocast --onnx_path model.onnx --use_standalone_type_inference
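
Combine the flag with other options, for example converting to BF16 while using the standalone type inference (an illustrative combination of the options shown above):

.. code-block:: bash

    python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --use_standalone_type_inference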
11 changes: 11 additions & 0 deletions modelopt/onnx/autocast/__main__.py
@@ -185,6 +185,16 @@ def get_parser() -> argparse.ArgumentParser:
"higher version."
),
)
parser.add_argument(
"--use_standalone_type_inference",
action="store_true",
default=False,
help=(
"Use local type inference implementation instead of ONNX's infer_shapes (experimental)."
"This is a workaround for cases where shape inference fails for any reason."
"Default: False (uses ONNX's infer_shapes which does both shape and type inference)."
),
)

return parser

@@ -218,6 +228,7 @@ def main(argv=None):
trt_plugins_precision=args.trt_plugins_precision,
max_depth_of_reduction=args.max_depth_of_reduction,
opset=args.opset,
use_standalone_type_inference=args.use_standalone_type_inference,
)

output_path = args.output_path
14 changes: 12 additions & 2 deletions modelopt/onnx/autocast/convert.py
@@ -61,6 +61,7 @@ def convert_to_mixed_precision(
trt_plugins_precision: list[str] = [],
max_depth_of_reduction: int | None = None,
opset: int | None = None,
use_standalone_type_inference: bool = False,
) -> onnx.ModelProto:
"""Convert model to mixed precision.

@@ -85,6 +86,9 @@
opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
(22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
require a higher version.
use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
infer_shapes. This is a workaround (WAR) when only type inference is
needed without shape inference. Default: False.

Returns:
onnx.ModelProto: The converted mixed precision model.
@@ -132,7 +136,7 @@
model = graph_sanitizer.model

# Setup internal mappings
model = onnx_utils.infer_shapes(model)
model = onnx_utils.infer_types(model, use_standalone_type_inference)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

# Automatically add 'trt' to list of providers if custom ops are detected
@@ -164,6 +168,7 @@
low_precision_type=low_precision_type,
init_conversion_max_bytes=init_conversion_max_bytes,
custom_ops=graph_sanitizer.custom_ops,
use_standalone_type_inference=use_standalone_type_inference,
)

# Obtain reference data
@@ -196,6 +201,7 @@ def convert_to_f16(
op_block_list: list[str] = [],
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
trt_plugins: list[str] | None = [],
use_standalone_type_inference: bool = False,
) -> onnx.ModelProto:
"""Convert model to mixed precision, using PrecisionConverter.

@@ -208,6 +214,9 @@
op_block_list: List of operation types that should remain in FP32.
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
infer_shapes. This is a workaround (WAR) when only type inference is
needed without shape inference. Default: False.
"""
assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"

@@ -225,7 +234,7 @@
model = sanitizer.model

# Setup internal mappings
model = onnx_utils.infer_shapes(model)
model = onnx_utils.infer_types(model, use_standalone_type_inference)
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

precision_converter = PrecisionConverter(
@@ -237,6 +246,7 @@
low_precision_type=low_precision_type,
custom_ops=sanitizer.custom_ops,
tensor_block_dict=tensor_block_dict,
use_standalone_type_inference=use_standalone_type_inference,
)
high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
low_precision_nodes = [
57 changes: 40 additions & 17 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -97,6 +97,7 @@ def __init__(
max_ir_version: int | None = None,
trt_plugins: list[str] | None = [],
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
use_standalone_type_inference: bool = False,
) -> None:
"""Initialize PrecisionConverter.

@@ -114,6 +115,7 @@
max_ir_version: Max IR version for conversion.
trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
use_standalone_type_inference: Use standalone type inference instead of ONNX's infer_shapes.
"""
self.model = deepcopy(model)
self.value_info_map = value_info_map
@@ -140,6 +142,7 @@
self.min_opset = min_opset
self.max_ir_version = max_ir_version
self.trt_plugins = trt_plugins
self.use_standalone_type_inference = use_standalone_type_inference

# Detect additional ops not supported in low precision according to the model's opset version
self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
@@ -254,10 +257,14 @@ def convert(
# Clear type/shape information for intermediates and outputs (including subgraphs)
self._clear_types_and_shapes_recursive(self.model.graph)
# Populate type information with inferred types
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
self.model = onnx_utils.infer_types(
self.model, self.use_standalone_type_inference, strict_mode=True, check_type=False
)
self._ensure_types_are_defined()
# Sanity check: Verify type correctness
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
self.model = onnx_utils.infer_types(
self.model, self.use_standalone_type_inference, strict_mode=True, check_type=True
)

# Update value_info_map and initializer_map with casts we added
self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings(
Expand All @@ -282,9 +289,9 @@ def _clear_types_and_shapes_recursive(
) -> None:
"""Recursively clear type/shape information for a graph and all its subgraphs.

This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph).
For main graph, clear all value_info.
If use_standalone_type_inference is True, we clear only types, not shapes.
For subgraphs, input types/shapes are cleared so that they are re-propagated from the main graph.

Args:
graph: The ONNX graph to clear types and shapes for.
@@ -301,9 +308,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
for inp in g.input:
if inp.type.HasField("tensor_type"):
inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(inp.type.tensor_type.shape.dim):
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
if not self.use_standalone_type_inference:
for idx, d in enumerate(inp.type.tensor_type.shape.dim):
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
Comment on lines +311 to +314

Contributor: Similarly for this, can we create a util function?

Contributor Author: Leaving it to next refactor, see #717 (comment)
if is_sub:
# Identify which tensors are produced by nodes in this subgraph
@@ -315,9 +323,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
for vi in g.value_info:
if vi.name in subgraph_outputs:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
if not self.use_standalone_type_inference:
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
else:
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
@@ -328,9 +337,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
# Clear outputs for both main graph and subgraphs
for out in g.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(out.type.tensor_type.shape.dim):
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
if not self.use_standalone_type_inference:
for idx, d in enumerate(out.type.tensor_type.shape.dim):
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"

utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)

@@ -1175,8 +1185,16 @@ def _remove_redundant_casts(self):
if self.custom_ops:
self.model = self._propagate_types_shapes_custom_ops(self.model)
else:
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
self.model = onnx_utils.infer_types(
self.model, self.use_standalone_type_inference, strict_mode=True
)
if not self.use_standalone_type_inference:
self.model = onnx_utils.infer_types(
self.model,
self.use_standalone_type_inference,
strict_mode=True,
check_type=True,
)

nodes_to_remove = []
for node in self.model.graph.node:
@@ -1261,7 +1279,12 @@ def _fix_network_output_names(self):
if self.custom_ops:
self.model = self._propagate_types_shapes_custom_ops(self.model)
else:
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
self.model = onnx_utils.infer_types(
self.model,
self.use_standalone_type_inference,
strict_mode=True,
check_type=True,
)
self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings(
self.model
)