77
88import numpy as np
99import onnx
10- import onnx .helper # noqa: TID251
11- from onnx .defs import OpSchema
1210
1311from onnxscript import ir , tensor
1412
2018# python values into ONNX TensorProto, while the runtime converts python values into
2119# ONNXScript runtime's value-representation (based on Tensor).
2220
23-
24- # Utilities to convert a python value to TensorProto (for use by the script converter)
25-
26-
27- def pyvalue_to_onnx_tensor (tensor_name : str , pyvalue ):
28- return ir .serde .serialize_tensor (ir .tensor (pyvalue , name = tensor_name ))
29-
30-
3121_REPEATED_ATTRIBUTE_TYPES = frozenset (
3222 {
33- onnx . AttributeProto .FLOATS ,
34- onnx . AttributeProto .INTS ,
35- onnx . AttributeProto .STRINGS ,
36- onnx . AttributeProto .TENSORS ,
37- onnx . AttributeProto .GRAPHS ,
38- onnx . AttributeProto .SPARSE_TENSORS ,
39- onnx . AttributeProto .TYPE_PROTOS ,
23+ ir . AttributeType .FLOATS ,
24+ ir . AttributeType .INTS ,
25+ ir . AttributeType .STRINGS ,
26+ ir . AttributeType .TENSORS ,
27+ ir . AttributeType .GRAPHS ,
28+ ir . AttributeType .SPARSE_TENSORS ,
29+ ir . AttributeType .TYPE_PROTOS ,
4030 }
4131)
4232
@@ -45,33 +35,28 @@ def pyvalue_to_onnx_attribute(
4535 key : str ,
4636 value : Any ,
4737 name_generator : Callable [[], str ],
48- attr_type : onnx . AttributeProto .AttributeType | None = None ,
49- ) -> onnx . AttributeProto :
38+ attr_type : ir .AttributeType | None = None ,
39+ ) -> ir . Attr :
5040 """Helper function to create an ONNX AttributeProto.
5141
52- This is a refinement of onnx.helper.make_attribute that works with ONNX Script
53- conventions for allowed types for attribute-values. In particular, it allows
54- * Empty lists as attribute values, provided the attribute type is specified
42+ * Empty lists can be attribute values, provided the attribute type is specified
5543 and is a list type.
5644 * Scalar-values like 1.0 as well as lists like [1, -1] to be specified
5745 when the attribute type is TensorProto by automatically converting the value
5846 into a 0-D or 1-D tensor respectively.
5947 """
48+ # TODO(justinchuby): Remove this function and use onnx-ir directly.
6049 if isinstance (value , list ) and not value :
6150 # Empty list value:
6251 if attr_type is None :
6352 raise ValueError ("Attribute type must be specified for empty list value." )
6453 if attr_type not in _REPEATED_ATTRIBUTE_TYPES :
6554 raise ValueError ("Empty list value is only allowed for repeated attribute types." )
66- return onnx .AttributeProto (name = key , type = attr_type )
67- elif attr_type == onnx .AttributeProto .TENSOR and not isinstance (value , onnx .TensorProto ):
68- return onnx .AttributeProto (
69- name = key , type = attr_type , t = pyvalue_to_onnx_tensor (name_generator (), value )
70- )
55+ return ir .Attr (name = key , type = attr_type , value = [])
56+ elif attr_type == ir .AttributeType .TENSOR and not isinstance (value , onnx .TensorProto ):
57+ return ir .AttrTensor (name = key , value = ir .tensor (value , name = name_generator ()))
7158 else :
72- # When the value is a subgraph, ONNX IR will complain that some values are
73- # not found from the scope.
74- return onnx .helper .make_attribute (key , value ) # noqa: TID251
59+ return ir .convenience .convert_attribute (key , value , attr_type = attr_type )
7560
7661
7762# Utilities to convert python values into onnxscript tensors.
@@ -126,7 +111,7 @@ def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
126111def cast_inputs (
127112 get_type_info : Callable [[Any ], Any ],
128113 cast : Callable [[Any , Any ], Any ],
129- op_schema : OpSchema | None ,
114+ op_signature : ir . schemas . OpSignature | None ,
130115 args ,
131116) -> tuple [Any , ...]:
132117 """Uses schema specification to support a limited form of auto-casting.
@@ -140,12 +125,13 @@ def cast_inputs(
140125 This is used by the converter in a static-mode, as well as by the eager-mode
141126 execution in a dynamic-mode.
142127 """
143- if op_schema is None :
128+ if op_signature is None :
144129 # Either an error or a custom op.
145130 # No checks/casts in this case.
146131 return tuple (cast (x , None ) for x in args )
147132
148- expected_inputs = op_schema .inputs
133+ # Filter to get only input parameters (not AttributeParameters)
134+ expected_inputs = op_signature .inputs
149135 # We make two passes. In the first pass, we identify known type-bindings for
150136 # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
151137 # In the second pass, we use these bindings to cast scalar-values to
@@ -156,17 +142,17 @@ def cast_inputs(
156142 for i , x in enumerate (args ):
157143 if i < len (expected_inputs ):
158144 expected = expected_inputs [i ]
159- elif expected_inputs [- 1 ].option == OpSchema . FormalParameterOption . Variadic :
145+ elif expected_inputs [- 1 ].variadic :
160146 expected = expected_inputs [- 1 ]
161- if not expected .is_homogeneous :
147+ if not expected .homogeneous :
162148 args_typevars .append ((x , None ))
163149 continue
164150 else :
165151 raise ValueError (
166152 f"Number of actual parameters { len (args )} "
167153 f"exceeds number of formal parameters { len (expected_inputs )} ."
168154 )
169- typevar = expected .type_str
155+ typevar = expected .type_constraint . name
170156 if "(" not in typevar :
171157 # typevar is an identifier, like "T"
172158 typeinfo = get_type_info (x )
@@ -177,18 +163,18 @@ def cast_inputs(
177163 return tuple (cast_args )
178164
179165
180- def dynamic_cast_inputs (op_schema : OpSchema , args ):
166+ def dynamic_cast_inputs (op_signature : ir . schemas . OpSignature , args ):
181167 """Used for autocast during eager-mode execution."""
182168
183169 def get_type_info (x ):
184170 return x .dtype if isinstance (x , tensor .Tensor ) else None
185171
186- return cast_inputs (get_type_info , cast_pyvalue_to_os_tensor , op_schema , args )
172+ return cast_inputs (get_type_info , cast_pyvalue_to_os_tensor , op_signature , args )
187173
188174
189175def static_cast_inputs (
190176 converter_ : converter .Converter ,
191- op_schema : Optional [OpSchema ],
177+ op_signature : Optional [ir . schemas . OpSignature ],
192178 args : Sequence [Optional [ir .Value ]],
193179) -> tuple [str , ...]:
194180 """Used for autocast during script-translation.
@@ -212,4 +198,4 @@ def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
212198 return converter_ .emit1 ([x_cast ], "CastLike" , [x , y ])
213199 return x
214200
215- return cast_inputs (get_type_info , cast_like , op_schema , args )
201+ return cast_inputs (get_type_info , cast_like , op_signature , args )
0 commit comments