Skip to content

Commit 7971fff

Browse files
authored
[5694695][AutoCast] Preserve outer scope variable types in subgraphs (#717)
## What does this PR do? **Type of change:** Bug fix **Overview:** When clearing type information for shape inference, preserve value_info for outer scope variables in subgraphs. Previously, all value_info entries were cleared indiscriminately, causing shape inference failures when subgraph nodes referenced outer scope variables. ## Testing pytest tests/unit/onnx/autocast/test_precisionconverter.py::test_if_subgraph_outer_scope_type_preservation ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: N/A - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent ecda7b0 commit 7971fff

2 files changed

Lines changed: 98 additions & 6 deletions

File tree

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ def _clear_types_and_shapes_recursive(
283283
"""Recursively clear type/shape information for a graph and all its subgraphs.
284284
285285
This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
286+
For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph).
287+
For main graph, clear all value_info.
286288
287289
Args:
288290
graph: The ONNX graph to clear types and shapes for.
@@ -303,13 +305,27 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
303305
if d.dim_value:
304306
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
305307

306-
# Clear type/shape information for intermediates and outputs
307-
for vi in g.value_info:
308-
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
309-
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
310-
if d.dim_value:
311-
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
308+
if is_sub:
309+
# Identify which tensors are produced by nodes in this subgraph
310+
subgraph_outputs = set()
311+
for node in g.node:
312+
subgraph_outputs.update(node.output)
313+
314+
# Clear value_info only for intermediates produced by nodes in this subgraph
315+
for vi in g.value_info:
316+
if vi.name in subgraph_outputs:
317+
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
318+
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
319+
if d.dim_value:
320+
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
321+
else:
322+
for vi in g.value_info:
323+
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
324+
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
325+
if d.dim_value:
326+
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
312327

328+
# Clear outputs for both main graph and subgraphs
313329
for out in g.output:
314330
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
315331
for idx, d in enumerate(out.type.tensor_type.shape.dim):

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,3 +1520,79 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis
15201520
# Verify a cast was inserted between If output and Add input
15211521
cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"]
15221522
assert len(cast_nodes) > 0, "Should have cast nodes for mixed precision"
1523+
1524+
1525+
@pytest.fixture
def model_with_if_outer_scope_reference():
    """Minimal model whose If branches read the outer-scope tensor "X".

    Each branch is a single Identity node consuming "X", which is declared
    only in the main graph, and each branch carries a value_info entry for
    "X". The type-clearing logic must leave that entry intact; without it,
    shape inference on the subgraph cannot determine X's element type.
    """

    def _make_branch(prefix):
        # One Identity node reading outer-scope "X", producing "<prefix>_y".
        branch_out = helper.make_tensor_value_info(f"{prefix}_y", TensorProto.FLOAT, [2, 4])
        identity = helper.make_node(
            "Identity", ["X"], [f"{prefix}_y"], name=f"{prefix}_identity"
        )
        branch = helper.make_graph([identity], f"{prefix}_branch", [], [branch_out])
        # Record X's type in the subgraph - this is what must be preserved.
        branch.value_info.extend(
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])]
        )
        return branch

    # Main graph I/O: X and a scalar boolean condition in, Y out.
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])
    condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, [])
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])

    if_node = helper.make_node(
        "If",
        ["condition"],
        ["Y"],
        name="if_node",
        then_branch=_make_branch("then"),
        else_branch=_make_branch("else"),
    )
    main_graph = helper.make_graph([if_node], "model_with_outer_scope", [x, condition], [y])

    model = helper.make_model(main_graph, producer_name="model_with_outer_scope")
    model.opset_import[0].version = 20
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    return model, value_info_map, initializer_map, node_to_init_map
1562+
1563+
1564+
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_if_subgraph_outer_scope_type_preservation(
    model_with_if_outer_scope_reference, low_precision_type
):
    """Test that outer scope variable types are preserved in If subgraphs during conversion.

    Without preserving X's value_info in subgraphs, shape inference fails with
    "Element type of input 0 unknown".
    """
    model, value_info_map, initializer_map, node_to_init_map = model_with_if_outer_scope_reference

    converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=True,
        low_precision_type=low_precision_type,
    )

    # Keep the If node in high precision; no nodes forced to low precision.
    converted_model = converter.convert(high_precision_nodes=["if_node"], low_precision_nodes=[])
    onnx.checker.check_model(converted_model)

    # X's value_info must survive conversion in both branches, with a defined type.
    if_node = next(n for n in converted_model.graph.node if n.op_type == "If")
    for attr_name, label in (("then_branch", "then"), ("else_branch", "else")):
        branch = next(attr.g for attr in if_node.attribute if attr.name == attr_name)
        x_entries = [vi for vi in branch.value_info if vi.name == "X"]
        assert len(x_entries) > 0, f"X value_info should be preserved in {label} branch"
        assert x_entries[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED

0 commit comments

Comments
 (0)