Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ def visit_graph_or_function(self, graph_or_function: ir.Graph | ir.Function) ->
node_version = node.version or self._default_onnx_opset
if node_version is None:
raise VersionConverterError(f"Node {node} has no version.")
# RefAttr is not supported by adapters for now.
if any(attr.is_ref() for attr in node.attributes.values()):
raise VersionConverterError(
f"Node '{node!r}' has ref attribute, which is not supported by version converter."
)
# Iterate each node from current node version -> target version
# and updating node based on the correct adapter
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
Expand Down
67 changes: 67 additions & 0 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,73 @@ def test_metadata_is_copied_to_multiple_replacement_nodes(self):
f"Node {i} ({node.op_type}) should have metadata copied",
)

def test_version_convert_raises_on_function_node_with_ref_attribute(self):
"""Test that version conversion raises when a function contains a node with a ref attribute."""
# Build a function with a LeakyRelu node that uses a RefAttr for 'alpha'
func_input = ir.Value(name="x")
ref_attr = ir.RefAttr("alpha", "alpha", ir.AttributeType.FLOAT)
func_output = ir.Value(name="result")
leaky_relu_node = ir.Node(
domain="",
op_type="LeakyRelu",
inputs=[func_input],
outputs=[func_output],
attributes=[ref_attr],
version=18,
)
func_graph = ir.Graph(
inputs=[func_input],
outputs=[func_output],
nodes=[leaky_relu_node],
opset_imports={"": 18},
)
func_attr_param = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.01)
function = ir.Function(
domain="pkg.custom",
name="leaky_relu_func",
graph=func_graph,
attributes=[func_attr_param],
)

# Build a main graph that calls the function
main_input = ir.Value(name="input_x")
main_output = ir.Value(name="output")
call_node = ir.Node(
domain="pkg.custom",
op_type="leaky_relu_func",
inputs=[main_input],
outputs=[main_output],
version=18,
)
main_graph = ir.Graph(
inputs=[main_input],
outputs=[main_output],
nodes=[call_node],
opset_imports={"": 18, "pkg.custom": 1},
)
model = ir.Model(
main_graph,
ir_version=8,
functions=[function],
)

target_version = 20
with self.assertRaises(
(
version_converter._version_converter.VersionConverterError, # pylint: disable=protected-access
ir.passes.PassError,
)
) as ctx:
version_converter.convert_version(model, target_version=target_version)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is an issue that PassError is raised. I realized this in the shape inferencer as well. I think we should find a way to move the main logic out of the pass, and make the pass a simple wrapper around the logic so VersionConverterError can be properly raised by convert_version

# Check the error message, unwrapping PassError if needed
error = ctx.exception
if isinstance(error, ir.passes.PassError) and error.__cause__ is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use assertRegex, after the error is properly raised

error = error.__cause__
self.assertIn(
"has ref attribute, which is not supported by version converter",
str(error),
)


class VersionConverter25to26Test(unittest.TestCase):
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 25 not yet supported.")
Expand Down
Loading