Skip to content

Commit d32e7ae

Browse files
committed
Submit add_n categorical test script and json
1 parent e4b5294 commit d32e7ae

4 files changed

Lines changed: 154 additions & 12 deletions

File tree

api/tests/add_n.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,16 @@ def build_graph(self, config):
3434
@benchmark_registry.register("add_n")
3535
class TorchAddN(PytorchOpBenchmarkBase):
3636
def build_graph(self, config):
37-
input_list = []
38-
input_0 = self.variable(
39-
name='input_' + str(0),
40-
shape=config.inputs_shape[0],
41-
dtype=config.inputs_dtype[0])
42-
result = input_0
43-
input_list.append(input_0)
44-
for i in range(1, len(config.inputs_shape)):
37+
inputs = []
38+
for i in range(len(config.inputs_shape)):
4539
input_i = self.variable(
4640
name='input_' + str(i),
4741
shape=config.inputs_shape[i],
4842
dtype=config.inputs_dtype[i])
49-
result = torch.add(result, input_i)
50-
input_list.append(input_i)
51-
52-
self.feed_list = input_list
43+
inputs.append(input_i)
44+
inputs = torch.stack(inputs, dim=0)
45+
result = torch.sum(inputs, axis=0)
46+
self.feed_list = inputs
5347
self.fetch_list = [result]
5448

5549

api/tests/categorical.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from common_import import *
2+
3+
4+
@benchmark_registry.register("categorical")
5+
class CategoricalConfig(APIConfig):
6+
def __init__(self):
7+
super(CategoricalConfig, self).__init__("categorical")
8+
self.feed_spec = {"range": [-5, -0.1]}
9+
10+
11+
@benchmark_registry.register("categorical")
12+
class PaddleCategorical(PaddleOpBenchmarkBase):
13+
def build_graph(self, config):
14+
logits = self.variable(
15+
name="logits",
16+
shape=config.logits_shape,
17+
dtype=config.logits_dtype)
18+
result = paddle.distribution.Categorical(logits)
19+
counts = result.sample([100])
20+
self.feed_list = [logits]
21+
self.fetch_list = [counts]
22+
23+
24+
@benchmark_registry.register("categorical")
25+
class TorchCategorical(PytorchOpBenchmarkBase):
26+
def build_graph(self, config):
27+
logits = self.variable(
28+
name="logits",
29+
shape=config.logits_shape,
30+
dtype=config.logits_dtype)
31+
result = torch.distributions.categorical.Categorical(
32+
logits=torch.tensor(logits))
33+
counts = result.sample([100])
34+
self.feed_list = [logits]
35+
self.fetch_list = [counts]
36+
37+
38+
@benchmark_registry.register("categorical")
39+
class TFCategoricall(TensorflowOpBenchmarkBase):
40+
def build_graph(self, config):
41+
logits = self.variable(
42+
name='logits',
43+
shape=config.logits_shape,
44+
dtype=config.logits_dtype)
45+
counts = tf.random.categorical(logits, 100)
46+
self.feed_list = [logits]
47+
self.fetch_list = [counts]

api/tests_v2/configs/add_n.json

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,95 @@
8989
"type": "list<Variable>"
9090
}
9191
}
92+
}, {
93+
"op": "add_n",
94+
"param_info": {
95+
"inputs": {
96+
"inputs0": {
97+
"dtype": "float16",
98+
"shape": "[1L]",
99+
"type": "Variable"
100+
},
101+
"inputs1": {
102+
"dtype": "float16",
103+
"shape": "[1L]",
104+
"type": "Variable"
105+
},
106+
"type": "list<Variable>"
107+
}
108+
}
109+
}, {
110+
"op": "add_n",
111+
"param_info": {
112+
"inputs": {
113+
"inputs0": {
114+
"dtype": "float16",
115+
"shape": "[1L]",
116+
"type": "Variable"
117+
},
118+
"inputs1": {
119+
"dtype": "float16",
120+
"shape": "[1L]",
121+
"type": "Variable"
122+
},
123+
"inputs2": {
124+
"dtype": "float16",
125+
"shape": "[1L]",
126+
"type": "Variable"
127+
},
128+
"inputs3": {
129+
"dtype": "float16",
130+
"shape": "[1L]",
131+
"type": "Variable"
132+
},
133+
"inputs4": {
134+
"dtype": "float16",
135+
"shape": "[1L]",
136+
"type": "Variable"
137+
},
138+
"inputs5": {
139+
"dtype": "float16",
140+
"shape": "[1L]",
141+
"type": "Variable"
142+
},
143+
"inputs6": {
144+
"dtype": "float16",
145+
"shape": "[1L]",
146+
"type": "Variable"
147+
},
148+
"inputs7": {
149+
"dtype": "float16",
150+
"shape": "[1L]",
151+
"type": "Variable"
152+
},
153+
"type": "list<Variable>"
154+
}
155+
}
156+
}, {
157+
"op": "add_n",
158+
"param_info": {
159+
"inputs": {
160+
"inputs0": {
161+
"dtype": "float16",
162+
"shape": "[-1L, 256L]",
163+
"type": "Variable"
164+
},
165+
"inputs1": {
166+
"dtype": "float16",
167+
"shape": "[-1L, 256L]",
168+
"type": "Variable"
169+
},
170+
"inputs2": {
171+
"dtype": "float16",
172+
"shape": "[-1L, 256L]",
173+
"type": "Variable"
174+
},
175+
"inputs3": {
176+
"dtype": "float16",
177+
"shape": "[-1L, 256L]",
178+
"type": "Variable"
179+
},
180+
"type": "list<Variable>"
181+
}
182+
}
92183
}]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[{
2+
"op": "categorical",
3+
"param_info": {
4+
"logits": {
5+
"dtype": "float32",
6+
"shape": "[524288L, 23L]",
7+
"type": "Variable"
8+
}
9+
}
10+
}]

0 commit comments

Comments
 (0)