Merged
28 changes: 22 additions & 6 deletions modelopt/onnx/autocast/precisionconverter.py

@@ -283,6 +283,8 @@ def _clear_types_and_shapes_recursive(
"""Recursively clear type/shape information for a graph and all its subgraphs.

This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph).
For main graph, clear all value_info.

Args:
graph: The ONNX graph to clear types and shapes for.
Expand All @@ -303,13 +305,27 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"

# Clear type/shape information for intermediates and outputs
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
if is_sub:
# Identify which tensors are produced by nodes in this subgraph
subgraph_outputs = set()
for node in g.node:
subgraph_outputs.update(node.output)

# Clear value_info only for intermediates produced by nodes in this subgraph
for vi in g.value_info:
if vi.name in subgraph_outputs:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
Inline review thread:

Contributor:
@jricker2 do you mean that replacing lines 317~320 with vi.type.ClearField("tensor_type") solves the ORT with CUDA EP issue that you're observing?

jricker2 (Jan 8, 2026):
Yes, exactly: just commenting those lines out actually resolves the issue (lines 321-326, the non-subgraph block).

I also found that if I cleared the value_info from the original fp32 graph, ran shape inference to re-populate it, and then fed that graph into the model optimizer, I had no issues. It seems to be an issue with how the graph was produced (exported by torch, I believe). I don't own the creation/export of the original torch model, and the workaround is straightforward, so I decided not to look much further into it (also because I found debugging value_info-related issues to be very time consuming; tools like DL Designer are not very helpful for this).

Anyhow, from what I can tell, having a clear value_info before shape inference is the best way to go, as opposed to pre-filling it with generic shapes/types.

Edit: so as not to take anything away from this PR: there are no subgraphs in the model I had this issue with; I just saw that this touched similar parts of the code I was looking into, so I figured I'd ask. I don't want to hold this up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the confirmation. Please let us know when you have an IP-free repro so we can check whether the suggested WAR is enough to fix this issue.
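A minimal sketch of the workaround jricker2 describes above, assuming hypothetical file paths: clear the exported model's value_info outright, let ONNX shape inference repopulate it, and only then hand the model to the optimizer.

    import onnx
    from onnx import shape_inference

    # Hypothetical path; the point is the clear-then-reinfer order.
    model = onnx.load("exported_fp32_model.onnx")

    # Drop the exporter's (possibly inconsistent) value_info entries ...
    del model.graph.value_info[:]

    # ... and let shape inference repopulate them from scratch.
    model = shape_inference.infer_shapes(model)

    # The cleaned model is then fed to the precision converter as usual.
    onnx.save(model, "exported_fp32_model.cleaned.onnx")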

+                        for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                            if d.dim_value:
+                                vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            else:
+                for vi in g.value_info:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                    for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                        if d.dim_value:
+                            vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
             # Clear outputs for both main graph and subgraphs
             for out in g.output:
                 out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
                 for idx, d in enumerate(out.type.tensor_type.shape.dim):
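For context on the new subgraph branch: a value_info name that no node in the subgraph produces must refer to an outer scope tensor (or a subgraph input), so its type entry is left intact. A standalone sketch of that classification, with a hypothetical model path:

    import onnx

    # Hypothetical: load a model and grab one If branch (an onnx.GraphProto).
    model = onnx.load("model_with_if.onnx")
    if_node = next(n for n in model.graph.node if n.op_type == "If")
    branch = next(a.g for a in if_node.attribute if a.name == "then_branch")

    # Names produced by nodes inside this subgraph.
    produced = {out for node in branch.node for out in node.output}

    # Entries the converter may clear (locally produced intermediates) ...
    clearable = [vi.name for vi in branch.value_info if vi.name in produced]

    # ... versus outer-scope references whose types must stay intact.
    preserved = [vi.name for vi in branch.value_info if vi.name not in produced]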
76 changes: 76 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py

@@ -1520,3 +1520,79 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis
     # Verify a cast was inserted between If output and Add input
     cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"]
     assert len(cast_nodes) > 0, "Should have cast nodes for mixed precision"
+
+
+@pytest.fixture
+def model_with_if_outer_scope_reference():
+    """Create a minimal model where If subgraphs reference outer scope variables.
+
+    This tests that subgraph value_info for outer scope variables is preserved during type clearing.
+    """
+    # Main graph inputs/outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])
+    condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, [])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
+
+    # Create "then" branch: Identity on X from outer scope
+    then_y = helper.make_tensor_value_info("then_y", TensorProto.FLOAT, [2, 4])
+    then_identity = helper.make_node("Identity", ["X"], ["then_y"], name="then_identity")
+    then_graph = helper.make_graph([then_identity], "then_branch", [], [then_y])
+    # Add X to value_info - this is what needs to be preserved
+    then_graph.value_info.extend([helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])])
+
+    # Create "else" branch: Identity on X from outer scope
+    else_y = helper.make_tensor_value_info("else_y", TensorProto.FLOAT, [2, 4])
+    else_identity = helper.make_node("Identity", ["X"], ["else_y"], name="else_identity")
+    else_graph = helper.make_graph([else_identity], "else_branch", [], [else_y])
+    else_graph.value_info.extend([helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])])
+
+    # Create If node and main graph
+    if_node = helper.make_node(
+        "If", ["condition"], ["Y"], name="if_node", then_branch=then_graph, else_branch=else_graph
+    )
+    main_graph = helper.make_graph([if_node], "model_with_outer_scope", [x, condition], [y])
+
+    model = helper.make_model(main_graph, producer_name="model_with_outer_scope")
+    model.opset_import[0].version = 20
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_if_subgraph_outer_scope_type_preservation(
+    model_with_if_outer_scope_reference, low_precision_type
+):
+    """Test that outer scope variable types are preserved in If subgraphs during conversion.
+
+    Without preserving X's value_info in subgraphs, shape inference fails with
+    "Element type of input 0 unknown".
+    """
+    model, value_info_map, initializer_map, node_to_init_map = model_with_if_outer_scope_reference
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=True,
+        low_precision_type=low_precision_type,
+    )
+
+    converted_model = converter.convert(high_precision_nodes=["if_node"], low_precision_nodes=[])
+    onnx.checker.check_model(converted_model)
+
+    # Verify X's value_info is preserved in both subgraphs
+    if_node = next(n for n in converted_model.graph.node if n.op_type == "If")
+    then_branch = next(attr.g for attr in if_node.attribute if attr.name == "then_branch")
+    else_branch = next(attr.g for attr in if_node.attribute if attr.name == "else_branch")
+
+    then_x_info = [vi for vi in then_branch.value_info if vi.name == "X"]
+    else_x_info = [vi for vi in else_branch.value_info if vi.name == "X"]
+
+    assert len(then_x_info) > 0, "X value_info should be preserved in then branch"
+    assert len(else_x_info) > 0, "X value_info should be preserved in else branch"
+    assert then_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED
+    assert else_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED
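To run just the new parametrized test locally, an invocation along these lines should work (path taken from the diff header above; the selection expression is one convenient choice, not the only one):

    import pytest

    # Selects test_if_subgraph_outer_scope_type_preservation for both
    # the fp16 and bf16 parametrizations.
    pytest.main([
        "tests/unit/onnx/autocast/test_precisionconverter.py",
        "-k", "outer_scope_type_preservation",
        "-v",
    ])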