diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index bd020c83a6..c29fb4c989 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -37,18 +37,26 @@ def __init__(self, target_version: int, fallback: bool = False) -> None: super().__init__() self.target_version = target_version self.fallback = fallback - self.convert_pass = ir.passes.Sequential( - _ConvertVersionPass( - target_version=target_version, - fallback=fallback, - ), + self._convert_pass = _ConvertVersionPass( + target_version=target_version, + fallback=fallback, + ) + self._cleanup_passes = ir.passes.Sequential( common_passes.RemoveUnusedNodesPass(), common_passes.RemoveUnusedFunctionsPass(), common_passes.RemoveUnusedOpsetsPass(), ) def call(self, model: ir.Model) -> ir.passes.PassResult: - return self.convert_pass(model) + # Run the conversion pass outside of Sequential so that errors + # (e.g. VersionConverterError) propagate directly without being + # wrapped in PassError. + result = self._convert_pass(model) + cleanup_result = self._cleanup_passes(result) + return ir.passes.PassResult( + cleanup_result.model, + result.modified or cleanup_result.modified, + ) class _ConvertVersionPass(ir.passes.InPlacePass): diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 7281318619..0cb1179ef1 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -310,6 +310,11 @@ def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) -> node_version = node.version or self._default_onnx_opset if node_version is None: raise VersionConverterError(f"Node {node} has no version.") + # RefAttr is not supported by adapters for now. + if any(attr.is_ref() for attr in node.attributes.values()): + raise VersionConverterError( + f"Node '{node!r}' has ref attribute, which is not supported by version converter." + ) # Iterate each node from current node version -> target version # and updating node based on the correct adapter # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1] diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index bf481313f4..c920746d7b 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -455,6 +455,63 @@ def test_metadata_is_copied_to_multiple_replacement_nodes(self): f"Node {i} ({node.op_type}) should have metadata copied", ) + def test_version_convert_raises_on_function_node_with_ref_attribute(self): + """Test that version conversion raises when a function contains a node with a ref attribute.""" + # Build a function with a LeakyRelu node that uses a RefAttr for 'alpha' + func_input = ir.Value(name="x") + ref_attr = ir.RefAttr("alpha", "alpha", ir.AttributeType.FLOAT) + func_output = ir.Value(name="result") + leaky_relu_node = ir.Node( + domain="", + op_type="LeakyRelu", + inputs=[func_input], + outputs=[func_output], + attributes=[ref_attr], + version=18, + ) + func_graph = ir.Graph( + inputs=[func_input], + outputs=[func_output], + nodes=[leaky_relu_node], + opset_imports={"": 18}, + ) + func_attr_param = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.01) + function = ir.Function( + domain="pkg.custom", + name="leaky_relu_func", + graph=func_graph, + attributes=[func_attr_param], + ) + + # Build a main graph that calls the function + main_input = ir.Value(name="input_x") + main_output = ir.Value(name="output") + call_node = ir.Node( + domain="pkg.custom", + op_type="leaky_relu_func", + inputs=[main_input], + outputs=[main_output], + version=18, + ) + main_graph = ir.Graph( + inputs=[main_input], + outputs=[main_output], + nodes=[call_node], + opset_imports={"": 18, "pkg.custom": 1}, + ) + model = ir.Model( + main_graph, + ir_version=8, + functions=[function], + ) + + target_version = 20 + with self.assertRaisesRegex( + version_converter._version_converter.VersionConverterError, # pylint: disable=protected-access + "has ref attribute, which is not supported by version converter", + ): + version_converter.convert_version(model, target_version=target_version) + class VersionConverter25to26Test(unittest.TestCase): @pytest.mark.xfail(strict=True, reason="Version upgrade beyond 25 not yet supported.")