Skip to content

Commit bc37c7b

Browse files
committed
fix
1 parent c03d391 commit bc37c7b

File tree

8 files changed

+116
-72
lines changed

8 files changed

+116
-72
lines changed

_unittests/ut_light_api/test_backend_export.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from difflib import unified_diff
55
import packaging.version as pv
66
import numpy
7+
import ml_dtypes
78
from numpy.testing import assert_allclose
89
from onnx.defs import onnx_opset_version
910
import onnx.backend.base
1011
import onnx.backend.test
1112
import onnx.shape_inference
1213
import onnx.version_converter
14+
import onnx.helper as oh
15+
import onnx.numpy_helper as onh
1316
from onnx import ModelProto, TensorProto, __version__ as onnx_version
1417
from onnx.helper import (
1518
make_function,
@@ -94,6 +97,10 @@ def run(
9497

9598
locs = {
9699
"np": numpy,
100+
"ml_dtypes": ml_dtypes,
101+
"onnx": onnx,
102+
"oh": oh,
103+
"onh": onh,
97104
"to_array": to_array,
98105
"to_array_extended": to_array_extended,
99106
"from_array": from_array,

_unittests/ut_translate_api/test_translate.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from textwrap import dedent
33
import numpy as np
4-
from onnx import ModelProto, TensorProto
4+
import onnx
55
from onnx.defs import onnx_opset_version
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
@@ -20,7 +20,7 @@ def test_event_type(self):
2020

2121
def test_exp(self):
2222
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23-
self.assertIsInstance(onx, ModelProto)
23+
self.assertIsInstance(onx, onnx.ModelProto)
2424
self.assertIn("Exp", str(onx))
2525
ref = ReferenceEvaluator(onx)
2626
a = np.arange(10).astype(np.float32)
@@ -32,25 +32,25 @@ def test_exp(self):
3232
"""
3333
(
3434
start(opset=19)
35-
.vin('X', elem_type=TensorProto.FLOAT)
35+
.vin('X', elem_type=onnx.TensorProto.FLOAT)
3636
.bring('X')
3737
.Exp()
3838
.rename('Y')
3939
.bring('Y')
40-
.vout(elem_type=TensorProto.FLOAT)
40+
.vout(elem_type=onnx.TensorProto.FLOAT)
4141
.to_onnx()
4242
)"""
4343
).strip("\n")
4444
self.assertEqual(expected, code)
4545

4646
onx2 = (
4747
start(opset=19)
48-
.vin("X", elem_type=TensorProto.FLOAT)
48+
.vin("X", elem_type=onnx.TensorProto.FLOAT)
4949
.bring("X")
5050
.Exp()
5151
.rename("Y")
5252
.bring("Y")
53-
.vout(elem_type=TensorProto.FLOAT)
53+
.vout(elem_type=onnx.TensorProto.FLOAT)
5454
.to_onnx()
5555
)
5656
ref = ReferenceEvaluator(onx2)
@@ -68,7 +68,7 @@ def test_transpose(self):
6868
.vout()
6969
.to_onnx()
7070
)
71-
self.assertIsInstance(onx, ModelProto)
71+
self.assertIsInstance(onx, onnx.ModelProto)
7272
self.assertIn("Transpose", str(onx))
7373
ref = ReferenceEvaluator(onx)
7474
a = np.arange(10).astype(np.float32)
@@ -82,15 +82,15 @@ def test_transpose(self):
8282
start(opset=19)
8383
.cst(np.array([-1, 1], dtype=np.int64))
8484
.rename('r')
85-
.vin('X', elem_type=TensorProto.FLOAT)
85+
.vin('X', elem_type=onnx.TensorProto.FLOAT)
8686
.bring('X', 'r')
8787
.Reshape()
8888
.rename('r0_0')
8989
.bring('r0_0')
9090
.Transpose(perm=[1, 0])
9191
.rename('Y')
9292
.bring('Y')
93-
.vout(elem_type=TensorProto.FLOAT)
93+
.vout(elem_type=onnx.TensorProto.FLOAT)
9494
.to_onnx()
9595
)"""
9696
).strip("\n")
@@ -107,7 +107,7 @@ def test_topk_reverse(self):
107107
.vout()
108108
.to_onnx()
109109
)
110-
self.assertIsInstance(onx, ModelProto)
110+
self.assertIsInstance(onx, onnx.ModelProto)
111111
ref = ReferenceEvaluator(onx)
112112
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
113113
k = np.array([2], dtype=np.int64)
@@ -120,15 +120,15 @@ def test_topk_reverse(self):
120120
"""
121121
(
122122
start(opset=19)
123-
.vin('X', elem_type=TensorProto.FLOAT)
124-
.vin('K', elem_type=TensorProto.INT64)
123+
.vin('X', elem_type=onnx.TensorProto.FLOAT)
124+
.vin('K', elem_type=onnx.TensorProto.INT64)
125125
.bring('X', 'K')
126126
.TopK(axis=-1, largest=0, sorted=1)
127127
.rename('Values', 'Indices')
128128
.bring('Values')
129-
.vout(elem_type=TensorProto.FLOAT)
129+
.vout(elem_type=onnx.TensorProto.FLOAT)
130130
.bring('Indices')
131-
.vout(elem_type=TensorProto.FLOAT)
131+
.vout(elem_type=onnx.TensorProto.FLOAT)
132132
.to_onnx()
133133
)"""
134134
).strip("\n")
@@ -152,7 +152,7 @@ def test_export_if(self):
152152
.to_onnx()
153153
)
154154

155-
self.assertIsInstance(onx, ModelProto)
155+
self.assertIsInstance(onx, onnx.ModelProto)
156156
ref = ReferenceEvaluator(onx)
157157
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
158158
k = np.array([2], dtype=np.int64)
@@ -162,19 +162,19 @@ def test_export_if(self):
162162
code = translate(onx)
163163
selse = (
164164
"g().cst(np.array([0], dtype=np.int64)).rename('Z')."
165-
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
165+
"bring('Z').vout(elem_type=onnx.TensorProto.FLOAT)"
166166
)
167167
sthen = (
168168
"g().cst(np.array([1], dtype=np.int64)).rename('Z')."
169-
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
169+
"bring('Z').vout(elem_type=onnx.TensorProto.FLOAT)"
170170
)
171171
expected = dedent(
172172
f"""
173173
(
174174
start(opset=19)
175175
.cst(np.array([0.0], dtype=np.float32))
176176
.rename('r')
177-
.vin('X', elem_type=TensorProto.FLOAT)
177+
.vin('X', elem_type=onnx.TensorProto.FLOAT)
178178
.bring('X')
179179
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
180180
.rename('Xs')
@@ -185,7 +185,7 @@ def test_export_if(self):
185185
.If(else_branch={selse}, then_branch={sthen})
186186
.rename('W')
187187
.bring('W')
188-
.vout(elem_type=TensorProto.FLOAT)
188+
.vout(elem_type=onnx.TensorProto.FLOAT)
189189
.to_onnx()
190190
)"""
191191
).strip("\n")
@@ -210,15 +210,15 @@ def test_aionnxml(self):
210210
start(opset=19, opsets={'ai.onnx.ml': 3})
211211
.cst(np.array([-1, 1], dtype=np.int64))
212212
.rename('r')
213-
.vin('X', elem_type=TensorProto.FLOAT)
213+
.vin('X', elem_type=onnx.TensorProto.FLOAT)
214214
.bring('X', 'r')
215215
.Reshape()
216216
.rename('USE')
217217
.bring('USE')
218218
.ai.onnx.ml.Normalizer(norm='MAX')
219219
.rename('Y')
220220
.bring('Y')
221-
.vout(elem_type=TensorProto.FLOAT)
221+
.vout(elem_type=onnx.TensorProto.FLOAT)
222222
.to_onnx()
223223
)"""
224224
).strip("\n")

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from textwrap import dedent
33
import numpy as np
44
import onnx.helper as oh
5-
from onnx import ModelProto, TensorProto
5+
import onnx
66
from onnx.checker import check_model
77
from onnx.defs import onnx_opset_version
88
from onnx.reference import ReferenceEvaluator
@@ -22,7 +22,7 @@ def setUp(self):
2222

2323
def test_exp(self):
2424
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
25-
self.assertIsInstance(onx, ModelProto)
25+
self.assertIsInstance(onx, onnx.ModelProto)
2626
self.assertIn("Exp", str(onx))
2727
ref = ReferenceEvaluator(onx)
2828
a = np.arange(10).astype(np.float32)
@@ -42,9 +42,9 @@ def light_api(
4242
return Y
4343
4444
g = GraphBuilder({'': 19}, ir_version=10)
45-
g.make_tensor_input("X", TensorProto.FLOAT, ())
45+
g.make_tensor_input("X", onnx.TensorProto.FLOAT, ())
4646
light_api(g.op, "X")
47-
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
47+
g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__)
4848
model = g.to_onnx()
4949
"""
5050
)
@@ -62,10 +62,10 @@ def light_api(
6262
return Y
6363

6464
g2 = GraphBuilder({"": 19})
65-
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
65+
g2.make_tensor_input("X", onnx.TensorProto.FLOAT, ("A",))
6666
light_api(g2.op, "X")
6767
g2.make_tensor_output(
68-
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
68+
"Y", onnx.TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
6969
)
7070
onx2 = g2.to_onnx()
7171

@@ -99,9 +99,9 @@ def light_api(
9999
return Y
100100
101101
g = GraphBuilder({'': 19}, ir_version=10)
102-
g.make_tensor_input("X", TensorProto.FLOAT, ())
102+
g.make_tensor_input("X", onnx.TensorProto.FLOAT, ())
103103
light_api(g.op, "X")
104-
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
104+
g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__)
105105
model = g.to_onnx()
106106
"""
107107
)
@@ -122,16 +122,16 @@ def light_api(
122122
return Y
123123

124124
g = GraphBuilder({"": 21})
125-
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
125+
X = g.make_tensor_input("X", onnx.TensorProto.FLOAT, ())
126126
light_api(g.op, X)
127-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
127+
g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ())
128128
model = g.to_onnx()
129129
self.assertNotEmpty(model)
130130
check_model(model)
131131

132132
def test_exp_f(self):
133133
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
134-
self.assertIsInstance(onx, ModelProto)
134+
self.assertIsInstance(onx, onnx.ModelProto)
135135
self.assertIn("Exp", str(onx))
136136
ref = ReferenceEvaluator(onx)
137137
a = np.arange(10).astype(np.float32)
@@ -155,9 +155,9 @@ def light_api(
155155
156156
def mm() -> "ModelProto":
157157
g = GraphBuilder({'': 19}, ir_version=10)
158-
g.make_tensor_input("X", TensorProto.FLOAT, ())
158+
g.make_tensor_input("X", onnx.TensorProto.FLOAT, ())
159159
light_api(g.op, "X")
160-
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
160+
g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__)
161161
model = g.to_onnx()
162162
return model
163163
@@ -179,10 +179,10 @@ def light_api(
179179
return Y
180180

181181
g2 = GraphBuilder({"": 19})
182-
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
182+
g2.make_tensor_input("X", onnx.TensorProto.FLOAT, ("A",))
183183
light_api(g2.op, "X")
184184
g2.make_tensor_output(
185-
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
185+
"Y", onnx.TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
186186
)
187187
onx2 = g2.to_onnx()
188188

@@ -216,11 +216,11 @@ def test_local_function(self):
216216
],
217217
"example",
218218
[
219-
oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
220-
oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
221-
oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
219+
oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None]),
220+
oh.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [None, None]),
221+
oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [None, None]),
222222
],
223-
[oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
223+
[oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, None)],
224224
)
225225

226226
onnx_model = oh.make_model(
@@ -262,11 +262,11 @@ def make_custom_LinearRegression(g: "GraphBuilder"):
262262
263263
def mm() -> "ModelProto":
264264
g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
265-
g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
266-
g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
267-
g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
265+
g.make_tensor_input("X", onnx.TensorProto.FLOAT, ('', ''))
266+
g.make_tensor_input("A", onnx.TensorProto.FLOAT, ('', ''))
267+
g.make_tensor_input("B", onnx.TensorProto.FLOAT, ('', ''))
268268
example(g.op, "X", "A", "B")
269-
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
269+
g.make_tensor_output("Y", onnx.TensorProto.FLOAT, ()__SUFFIX__)
270270
make_custom_LinearRegression(g)
271271
model = g.to_onnx()
272272
return model

0 commit comments

Comments
 (0)