Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
self.target_version = target_version
self.fallback = fallback
self.convert_pass = ir.passes.Sequential(
common_passes.InlinePass(),
_ConvertVersionPassRequiresInline(
_ConvertVersionPass(
target_version=target_version,
fallback=fallback,
),
Expand All @@ -52,7 +51,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
return self.convert_pass(model)


class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
class _ConvertVersionPass(ir.passes.InPlacePass):
"""Convert the model to the specified ONNX opset version.

This pass leverages the onnxscript version converter to convert the model. If
Expand All @@ -73,12 +72,6 @@ def __init__(self, target_version: int, fallback: bool) -> None:
self.fallback = fallback

def call(self, model: ir.Model) -> ir.passes.PassResult:
if model.functions:
raise ValueError(
"The model contains functions. The version conversion pass does not support "
"functions. Please use `common_passes.InlinePass` to inline the "
f"functions before applying this pass ({self.__class__.__name__})."
)
if "" in model.graph.opset_imports:
onnx_opset_version = model.graph.opset_imports[""]
if onnx_opset_version == self.target_version:
Expand Down
14 changes: 8 additions & 6 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ def visit_attribute(self, attr: ir.Attr) -> None:
if attr.is_ref():
return
if attr.type == ir.AttributeType.GRAPH:
self.visit_graph(attr.as_graph())
self.visit_graph_or_function(attr.as_graph())
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.as_graphs():
self.visit_graph(graph)
self.visit_graph_or_function(graph)

def visit_node(
self,
Expand All @@ -303,8 +303,8 @@ def visit_node(
self._default_metadata_merger.copy_merged_metadata([node], replacement.new_nodes)
self.replace_node(node, replacement, root)

def visit_graph(self, graph: ir.Graph) -> None:
for node in graph:
def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) -> None:
for node in graph_or_function:
if node.domain != "":
continue
node_version = node.version or self._default_onnx_opset
Expand All @@ -321,7 +321,7 @@ def visit_graph(self, graph: ir.Graph) -> None:
)
for from_version in range(node_version, self._target_version):
try:
self.visit_node(node, graph, from_version, up_conversion=True)
self.visit_node(node, graph_or_function, from_version, up_conversion=True)
except VersionConverterError as e:
logger.warning(
"Skipping version conversion for node %s due to exception: %s",
Expand All @@ -331,7 +331,9 @@ def visit_graph(self, graph: ir.Graph) -> None:

def visit_model(self, model: ir.Model) -> None:
self._default_onnx_opset = _get_onnx_opset_version(model)
self.visit_graph(model.graph)
self.visit_graph_or_function(model.graph)
for function in model.functions.values():
self.visit_graph_or_function(function)
_set_onnx_opset_version(model, self._target_version)


Expand Down
118 changes: 96 additions & 22 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,40 +208,114 @@ def test_version_convert_gridsample_cubic(self):
self.assertEqual(model.graph.node(4).version, 20)
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")

def test_version_convert_inline(self):
def test_version_convert_function_nodes(self):
"""Test that version converter processes nodes inside model functions."""
model = ir.from_onnx_text(
"""
<ir_version: 8, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
{
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
reshape_x = Reshape (input_x, shape_a)
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
reshape_y = Reshape (input_x, shape_b)
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
output = foo(gridsample)
output = pkg.custom.dft_func (input_x)
}

<opset_import: [ "" : 18]>
foo (x) => (dft) {
dft = DFT <axis = 2, onesided = 1> (x)
<domain: "pkg.custom", opset_import: [ "" : 18]>
dft_func (x) => (result) {
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
reshape_x = Reshape (x, shape_a)
dft = DFT <axis = 2, onesided = 1> (reshape_x)
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
result = Reshape (dft, shape_c)
}
"""
)
# Verify the function exists with correct initial state
self.assertEqual(len(model.functions), 1)
func = model.functions[("pkg.custom", "dft_func", "")]
self.assertEqual(len(func), 5) # 5 nodes in the function

target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 20)
self.assertEqual(model.graph.node(1).op_type, "Reshape")
self.assertEqual(model.graph.node(1).version, 20)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).version, 20)
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
self.assertEqual(model.graph.node(6).op_type, "DFT")
self.assertEqual(model.graph.node(6).version, 20)
self.assertEqual(len(model.graph.node(6).inputs), 3)
# Verify that nodes inside the function were version-converted
func = model.functions[("pkg.custom", "dft_func", "")]
self.assertEqual(func[0].op_type, "Constant")
self.assertEqual(func[0].version, 20)
self.assertEqual(func[1].op_type, "Reshape")
self.assertEqual(func[1].version, 20)
# After DFT adapter, a new Constant node is inserted for dft_length
self.assertEqual(func[2].op_type, "Constant")
self.assertEqual(func[2].version, 20)
self.assertEqual(func[3].op_type, "DFT")
self.assertEqual(func[3].version, 20)
self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input

def test_version_convert_function_with_control_flow_subgraph(self):
"""Test that version converter processes subgraphs inside control flow nodes in functions."""
model = ir.from_onnx_text(
"""
<ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output)
{
output = pkg.custom.conditional_dft (input_x, cond)
}

<domain: "pkg.custom", opset_import: [ "" : 18]>
conditional_dft (x, cond) => (result) {
result = If (cond) <then_branch: graph = then_graph () => (out) {
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
reshape_x = Reshape (x, shape_a)
dft = DFT <axis = 2, onesided = 1> (reshape_x)
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
out = Reshape (dft, shape_c)
}, else_branch: graph = else_graph () => (out) {
shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
out = Reshape (x, shape_c)
}>
}
"""
)
# Verify the function exists with correct initial state
self.assertEqual(len(model.functions), 1)
func = model.functions[("pkg.custom", "conditional_dft", "")]
self.assertEqual(len(func), 1) # 1 node (If) in the function

# Verify the If node has subgraphs
if_node = func[0]
self.assertEqual(if_node.op_type, "If")
then_branch = if_node.attributes["then_branch"].as_graph()
else_branch = if_node.attributes["else_branch"].as_graph()
self.assertEqual(len(then_branch), 5) # 5 nodes in then_branch
self.assertEqual(len(else_branch), 2) # 2 nodes in else_branch

target_version = 20
# Use internal API to test function version conversion without inlining
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

# Verify nodes inside the function's If node subgraphs were version-converted
func = model.functions[("pkg.custom", "conditional_dft", "")]
if_node = func[0]
self.assertEqual(if_node.op_type, "If")
self.assertEqual(if_node.version, 20)

# Check then_branch subgraph nodes
then_branch = if_node.attributes["then_branch"].as_graph()
# After DFT adapter, a new Constant node is inserted for dft_length
self.assertEqual(len(then_branch), 6) # 5 + 1 new Constant for DFT
dft_node = None
for node in then_branch:
self.assertEqual(node.version, 20)
if node.op_type == "DFT":
dft_node = node
self.assertIsNotNone(dft_node)
self.assertEqual(len(dft_node.inputs), 3) # DFT 19->20 adds dft_length input

# Check else_branch subgraph nodes
else_branch = if_node.attributes["else_branch"].as_graph()
self.assertEqual(len(else_branch), 2)
for node in else_branch:
self.assertEqual(node.version, 20)


class VersionConverter20to21Test(unittest.TestCase):
Expand Down
Loading