diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index b0831a00f9..bd020c83a6 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -38,8 +38,7 @@ def __init__(self, target_version: int, fallback: bool = False) -> None: self.target_version = target_version self.fallback = fallback self.convert_pass = ir.passes.Sequential( - common_passes.InlinePass(), - _ConvertVersionPassRequiresInline( + _ConvertVersionPass( target_version=target_version, fallback=fallback, ), @@ -52,7 +51,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return self.convert_pass(model) -class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): +class _ConvertVersionPass(ir.passes.InPlacePass): """Convert the model to the specified ONNX opset version. This pass leverages the onnxscript version converter to convert the model. If @@ -73,12 +72,6 @@ def __init__(self, target_version: int, fallback: bool) -> None: self.fallback = fallback def call(self, model: ir.Model) -> ir.passes.PassResult: - if model.functions: - raise ValueError( - "The model contains functions. The version conversion pass does not support " - "functions. Please use `common_passes.InlinePass` to inline the " - f"functions before applying this pass ({self.__class__.__name__})." - ) if "" in model.graph.opset_imports: onnx_opset_version = model.graph.opset_imports[""] if onnx_opset_version == self.target_version: diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 7eb4425941..7470df0ead 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -274,10 +274,10 @@ def visit_attribute(self, attr: ir.Attr) -> None: if attr.is_ref(): return if attr.type == ir.AttributeType.GRAPH: - self.visit_graph(attr.as_graph()) + self.visit_graph_or_function(attr.as_graph()) elif attr.type == ir.AttributeType.GRAPHS: for graph in attr.as_graphs(): - self.visit_graph(graph) + self.visit_graph_or_function(graph) def visit_node( self, @@ -303,8 +303,8 @@ def visit_node( self._default_metadata_merger.copy_merged_metadata([node], replacement.new_nodes) self.replace_node(node, replacement, root) - def visit_graph(self, graph: ir.Graph) -> None: - for node in graph: + def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) -> None: + for node in graph_or_function: if node.domain != "": continue node_version = node.version or self._default_onnx_opset @@ -321,7 +321,7 @@ def visit_graph(self, graph: ir.Graph) -> None: ) for from_version in range(node_version, self._target_version): try: - self.visit_node(node, graph, from_version, up_conversion=True) + self.visit_node(node, graph_or_function, from_version, up_conversion=True) except VersionConverterError as e: logger.warning( "Skipping version conversion for node %s due to exception: %s", @@ -331,7 +331,9 @@ def visit_graph(self, graph: ir.Graph) -> None: def visit_model(self, model: ir.Model) -> None: self._default_onnx_opset = _get_onnx_opset_version(model) - self.visit_graph(model.graph) + self.visit_graph_or_function(model.graph) + for function in model.functions.values(): + self.visit_graph_or_function(function) _set_onnx_opset_version(model, self._target_version) diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index cb5a449dbd..2b615a8f7f 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -208,40 +208,114 @@ def test_version_convert_gridsample_cubic(self): self.assertEqual(model.graph.node(4).version, 20) self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic") - def test_version_convert_inline(self): + def test_version_convert_function_nodes(self): + """Test that version converter processes nodes inside model functions.""" model = ir.from_onnx_text( """ - - agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output) + + agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output) { - shape_a = Constant() - reshape_x = Reshape (input_x, shape_a) - shape_b = Constant() - reshape_y = Reshape (input_x, shape_b) - gridsample = GridSample (reshape_x, reshape_y) - output = foo(gridsample) + output = pkg.custom.dft_func (input_x) } - - foo (x) => (dft) { - dft = DFT (x) + + dft_func (x) => (result) { + shape_a = Constant() + reshape_x = Reshape (x, shape_a) + dft = DFT (reshape_x) + shape_c = Constant() + result = Reshape (dft, shape_c) } """ ) + # Verify the function exists with correct initial state + self.assertEqual(len(model.functions), 1) + func = model.functions[("pkg.custom", "dft_func", "")] + self.assertEqual(len(func), 5) # 5 nodes in the function + target_version = 20 version_converter.convert_version(model, target_version=target_version) self.assertEqual(model.opset_imports[""], target_version) - self.assertEqual(model.graph.node(0).op_type, "Constant") - self.assertEqual(model.graph.node(0).version, 20) - self.assertEqual(model.graph.node(1).op_type, "Reshape") - self.assertEqual(model.graph.node(1).version, 20) - self.assertEqual(model.graph.node(4).op_type, "GridSample") - self.assertEqual(model.graph.node(4).version, 20) - self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") - self.assertEqual(model.graph.node(6).op_type, "DFT") - self.assertEqual(model.graph.node(6).version, 20) - self.assertEqual(len(model.graph.node(6).inputs), 3) + # Verify that nodes inside the function were version-converted + func = model.functions[("pkg.custom", "dft_func", "")] + self.assertEqual(func[0].op_type, "Constant") + self.assertEqual(func[0].version, 20) + self.assertEqual(func[1].op_type, "Reshape") + self.assertEqual(func[1].version, 20) + # After DFT adapter, a new Constant node is inserted for dft_length + self.assertEqual(func[2].op_type, "Constant") + self.assertEqual(func[2].version, 20) + self.assertEqual(func[3].op_type, "DFT") + self.assertEqual(func[3].version, 20) + self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input + + def test_version_convert_function_with_control_flow_subgraph(self): + """Test that version converter processes subgraphs inside control flow nodes in functions.""" + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output) + { + output = pkg.custom.conditional_dft (input_x, cond) + } + + + conditional_dft (x, cond) => (result) { + result = If (cond) (out) { + shape_a = Constant() + reshape_x = Reshape (x, shape_a) + dft = DFT (reshape_x) + shape_c = Constant() + out = Reshape (dft, shape_c) + }, else_branch: graph = else_graph () => (out) { + shape_c = Constant() + out = Reshape (x, shape_c) + }> + } + """ + ) + # Verify the function exists with correct initial state + self.assertEqual(len(model.functions), 1) + func = model.functions[("pkg.custom", "conditional_dft", "")] + self.assertEqual(len(func), 1) # 1 node (If) in the function + + # Verify the If node has subgraphs + if_node = func[0] + self.assertEqual(if_node.op_type, "If") + then_branch = if_node.attributes["then_branch"].as_graph() + else_branch = if_node.attributes["else_branch"].as_graph() + self.assertEqual(len(then_branch), 5) # 5 nodes in then_branch + self.assertEqual(len(else_branch), 2) # 2 nodes in else_branch + + target_version = 20 + # Use internal API to test function version conversion without inlining + version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) + + # Verify nodes inside the function's If node subgraphs were version-converted + func = model.functions[("pkg.custom", "conditional_dft", "")] + if_node = func[0] + self.assertEqual(if_node.op_type, "If") + self.assertEqual(if_node.version, 20) + + # Check then_branch subgraph nodes + then_branch = if_node.attributes["then_branch"].as_graph() + # After DFT adapter, a new Constant node is inserted for dft_length + self.assertEqual(len(then_branch), 6) # 5 + 1 new Constant for DFT + dft_node = None + for node in then_branch: + self.assertEqual(node.version, 20) + if node.op_type == "DFT": + dft_node = node + self.assertIsNotNone(dft_node) + self.assertEqual(len(dft_node.inputs), 3) # DFT 19->20 adds dft_length input + + # Check else_branch subgraph nodes + else_branch = if_node.attributes["else_branch"].as_graph() + self.assertEqual(len(else_branch), 2) + for node in else_branch: + self.assertEqual(node.version, 20) class VersionConverter20to21Test(unittest.TestCase):