Skip to content

Commit ef02a98

Browse files
committed
add code rendering
1 parent 7e19bf5 commit ef02a98

File tree

4 files changed

+61
-89
lines changed

4 files changed

+61
-89
lines changed

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 44 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,13 @@
22
import os
33
from textwrap import dedent
44
import numpy as np
5+
import onnx
6+
import onnx.helper as oh
7+
import onnx.numpy_helper as onh
58
from onnx import ModelProto, TensorProto, load
69
from onnx.defs import onnx_opset_version
710
from onnx.reference import ReferenceEvaluator
811
from onnx.reference.op_run import OpRun
9-
from onnx.helper import (
10-
make_tensor_value_info,
11-
make_node,
12-
make_graph,
13-
make_model,
14-
make_opsetid,
15-
)
1612
from onnx.checker import check_model
1713
from onnx_array_api.ext_test_case import ExtTestCase
1814
from onnx_array_api.light_api import start
@@ -23,27 +19,25 @@
2319

2420
class TestTranslateClassic(ExtTestCase):
2521
def test_check_code(self):
26-
opset_imports = [
27-
make_opsetid("", 19),
28-
]
22+
opset_imports = [oh.make_opsetid("", 19)]
2923
inputs = []
3024
outputs = []
3125
nodes = []
3226
initializers = []
3327
sparse_initializers = []
3428
functions = []
35-
inputs.append(make_tensor_value_info("X", TensorProto.FLOAT, shape=[]))
36-
nodes.append(make_node("Exp", ["X"], ["Y"]))
37-
outputs.append(make_tensor_value_info("Y", TensorProto.FLOAT, shape=[]))
38-
graph = make_graph(
29+
inputs.append(oh.make_tensor_value_info("X", TensorProto.FLOAT, shape=[]))
30+
nodes.append(oh.make_node("Exp", ["X"], ["Y"]))
31+
outputs.append(oh.make_tensor_value_info("Y", TensorProto.FLOAT, shape=[]))
32+
graph = oh.make_graph(
3933
nodes,
4034
"onename",
4135
inputs,
4236
outputs,
4337
initializers,
4438
sparse_initializer=sparse_initializers,
4539
)
46-
model = make_model(graph, functions=functions, opset_imports=opset_imports)
40+
model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
4741
check_model(model)
4842

4943
def test_exp(self):
@@ -60,32 +54,32 @@ def test_exp(self):
6054
expected = dedent(
6155
"""
6256
opset_imports = [
63-
make_opsetid('', 19),
57+
oh.make_opsetid('', 19),
6458
]
6559
inputs = []
6660
outputs = []
6761
nodes = []
6862
initializers = []
6963
sparse_initializers = []
7064
functions = []
71-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
65+
inputs.append(oh.make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
7266
nodes.append(
7367
make_node_extended(
7468
'Exp',
7569
['X'],
7670
['Y']
7771
)
7872
)
79-
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
80-
graph = make_graph(
73+
outputs.append(oh.make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
74+
graph = oh.make_graph(
8175
nodes,
8276
'light_api',
8377
inputs,
8478
outputs,
8579
initializers,
8680
sparse_initializer=sparse_initializers,
8781
)
88-
model = make_model(
82+
model = oh.make_model(
8983
graph,
9084
functions=functions,
9185
opset_imports=opset_imports
@@ -130,7 +124,7 @@ def test_transpose(self):
130124
expected = dedent(
131125
"""
132126
opset_imports = [
133-
make_opsetid('', 19),
127+
oh.make_opsetid('', 19),
134128
]
135129
inputs = []
136130
outputs = []
@@ -139,12 +133,12 @@ def test_transpose(self):
139133
sparse_initializers = []
140134
functions = []
141135
initializers.append(
142-
from_array(
136+
onh.from_array(
143137
np.array([-1, 1], dtype=np.int64),
144138
name='r'
145139
)
146140
)
147-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
141+
inputs.append(oh.make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
148142
nodes.append(
149143
make_node_extended(
150144
'Reshape',
@@ -160,16 +154,16 @@ def test_transpose(self):
160154
perm=[1, 0]
161155
)
162156
)
163-
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
164-
graph = make_graph(
157+
outputs.append(oh.make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
158+
graph = oh.make_graph(
165159
nodes,
166160
'light_api',
167161
inputs,
168162
outputs,
169163
initializers,
170164
sparse_initializer=sparse_initializers,
171165
)
172-
model = make_model(
166+
model = oh.make_model(
173167
graph,
174168
functions=functions,
175169
opset_imports=opset_imports
@@ -199,7 +193,7 @@ def test_transpose_short(self):
199193
expected = dedent(
200194
"""
201195
opset_imports = [
202-
make_opsetid('', 19),
196+
oh.make_opsetid('', 19),
203197
]
204198
inputs = []
205199
outputs = []
@@ -208,12 +202,12 @@ def test_transpose_short(self):
208202
sparse_initializers = []
209203
functions = []
210204
initializers.append(
211-
from_array(
205+
onh.from_array(
212206
np.array([-1, 1], dtype=np.int64),
213207
name='r'
214208
)
215209
)
216-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
210+
inputs.append(oh.make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
217211
nodes.append(
218212
make_node_extended(
219213
'Reshape',
@@ -229,16 +223,16 @@ def test_transpose_short(self):
229223
perm=[1, 0]
230224
)
231225
)
232-
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
233-
graph = make_graph(
226+
outputs.append(oh.make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
227+
graph = oh.make_graph(
234228
nodes,
235229
'light_api',
236230
inputs,
237231
outputs,
238232
initializers,
239233
sparse_initializer=sparse_initializers,
240234
)
241-
model = make_model(
235+
model = oh.make_model(
242236
graph,
243237
functions=functions,
244238
opset_imports=opset_imports
@@ -270,16 +264,16 @@ def test_topk_reverse(self):
270264
expected = dedent(
271265
"""
272266
opset_imports = [
273-
make_opsetid('', 19),
267+
oh.make_opsetid('', 19),
274268
]
275269
inputs = []
276270
outputs = []
277271
nodes = []
278272
initializers = []
279273
sparse_initializers = []
280274
functions = []
281-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
282-
inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
275+
inputs.append(oh.make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
276+
inputs.append(oh.make_tensor_value_info('K', TensorProto.INT64, shape=[]))
283277
nodes.append(
284278
make_node_extended(
285279
'TopK',
@@ -290,17 +284,17 @@ def test_topk_reverse(self):
290284
sorted=1
291285
)
292286
)
293-
outputs.append(make_tensor_value_info('Values', TensorProto.FLOAT, shape=[]))
294-
outputs.append(make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[]))
295-
graph = make_graph(
287+
outputs.append(oh.make_tensor_value_info('Values', TensorProto.FLOAT, shape=[]))
288+
outputs.append(oh.make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[]))
289+
graph = oh.make_graph(
296290
nodes,
297291
'light_api',
298292
inputs,
299293
outputs,
300294
initializers,
301295
sparse_initializer=sparse_initializers,
302296
)
303-
model = make_model(
297+
model = oh.make_model(
304298
graph,
305299
functions=functions,
306300
opset_imports=opset_imports
@@ -338,8 +332,8 @@ def test_aionnxml(self):
338332
expected = dedent(
339333
"""
340334
opset_imports = [
341-
make_opsetid('', 19),
342-
make_opsetid('ai.onnx.ml', 3),
335+
oh.make_opsetid('', 19),
336+
oh.make_opsetid('ai.onnx.ml', 3),
343337
]
344338
inputs = []
345339
outputs = []
@@ -348,12 +342,12 @@ def test_aionnxml(self):
348342
sparse_initializers = []
349343
functions = []
350344
initializers.append(
351-
from_array(
345+
onh.from_array(
352346
np.array([-1, 1], dtype=np.int64),
353347
name='r'
354348
)
355349
)
356-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
350+
inputs.append(oh.make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
357351
nodes.append(
358352
make_node_extended(
359353
'Reshape',
@@ -370,16 +364,16 @@ def test_aionnxml(self):
370364
norm='MAX'
371365
)
372366
)
373-
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
374-
graph = make_graph(
367+
outputs.append(oh.make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
368+
graph = oh.make_graph(
375369
nodes,
376370
'light_api',
377371
inputs,
378372
outputs,
379373
initializers,
380374
sparse_initializer=sparse_initializers,
381375
)
382-
model = make_model(
376+
model = oh.make_model(
383377
graph,
384378
functions=functions,
385379
opset_imports=opset_imports
@@ -402,9 +396,6 @@ def _run(cls, code):
402396
f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}"
403397
) from e
404398

405-
import onnx
406-
import onnx.helper
407-
import onnx.numpy_helper
408399
import onnx_array_api.translate_api.make_helper
409400
import ml_dtypes
410401

@@ -430,10 +421,11 @@ def from_array_extended(tensor, name=None):
430421
return t
431422

432423
globs = onnx.__dict__.copy()
433-
globs.update(onnx.helper.__dict__)
434-
globs.update(onnx.numpy_helper.__dict__)
435424
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
436-
globs.update(ml_dtypes.__dict__)
425+
globs["np"] = np
426+
globs["oh"] = oh
427+
globs["onh"] = onh
428+
globs["ml_dtypes"] = ml_dtypes
437429
globs["from_array_extended"] = from_array_extended
438430
locs = {}
439431
try:

onnx_array_api/tools/replace_constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def replace_initializer_by_constant_of_shape(
5252
)
5353
)
5454
dtype = cst.dtype
55-
assert op_type != "Constant"
5655
new_nodes.append(
5756
make_node(
5857
op_type,

onnx_array_api/translate_api/base_emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
131131
sdtype = repl.get(str(v.dtype), str(str(v.dtype)))
132132
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
133133
return [], (
134-
f"from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
134+
f"onh.from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
135135
f"name={value[0].name!r})"
136136
)
137137
if isinstance(v, (int, float, list)):

0 commit comments

Comments
 (0)