Skip to content

Commit a5516d3

Browse files
committed
Further polish the base class implementation.
1 parent cabaade commit a5516d3

6 files changed

Lines changed: 120 additions & 148 deletions

File tree

api/common/feeder.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,99 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from __future__ import print_function
16-
1715
import collections
1816
import numpy as np
1917

20-
from . import paddle_api_benchmark as paddle_api
21-
from . import tensorflow_api_benchmark as tensorflow_api
18+
19+
def _convert_paddle_dtype(dtype, to_string=True):
    """Convert a paddle ``VarDesc.VarType`` dtype to a name string or numpy dtype.

    Args:
        dtype: a ``paddle.fluid.core.VarDesc.VarType`` enum value.
        to_string (bool): if True, return the dtype name as a string
            (e.g. "float32"); otherwise return the matching ``np.dtype``.

    Returns:
        str or np.dtype: the converted dtype representation.

    Raises:
        TypeError: if ``dtype`` is not a ``VarDesc.VarType``.
        ValueError: if ``dtype`` has no known numpy equivalent.
    """
    # Imported lazily so this module stays importable without paddle installed.
    import paddle

    var_type = paddle.fluid.core.VarDesc.VarType
    if not isinstance(dtype, var_type):
        raise TypeError("dtype is not of type fluid.core.VarDesc.VarType")

    # VarType -> (name string, numpy dtype). The VarType enum is hashable, and
    # the isinstance check above guarantees a plain enum key lookup is valid.
    #
    # FIX: the original elif chain tested VarType.INT16 twice; the second,
    # unreachable branch returned "uint16". That dead branch is dropped here
    # (framework.proto defines no UINT16 VarType to map it to — if a uint16
    # mapping is ever needed, it must come from a real enum member).
    #
    # FIX: np.bool_ replaces np.bool — the np.bool alias was removed in
    # numpy 1.24 and raised AttributeError at runtime.
    _mapping = {
        var_type.FP32: ("float32", np.float32),
        var_type.FP64: ("float64", np.float64),
        var_type.FP16: ("float16", np.float16),
        var_type.INT32: ("int32", np.int32),
        var_type.INT16: ("int16", np.int16),
        var_type.INT64: ("int64", np.int64),
        var_type.BOOL: ("bool", np.bool_),
        var_type.UINT8: ("uint8", np.uint8),
        var_type.INT8: ("int8", np.int8),
    }
    if dtype not in _mapping:
        raise ValueError("Unsupported dtype %s" % dtype)
    dtype_str, np_dtype = _mapping[dtype]
    return dtype_str if to_string else np.dtype(np_dtype)
50+
51+
52+
def _convert_tensorflow_dtype(dtype, to_string=True):
    """Convert a tensorflow ``DType`` to a name string or numpy dtype.

    Args:
        dtype: a ``tf.DType`` (or anything comparable to one via ``==``).
        to_string (bool): if True, return the dtype name as a string
            (e.g. "float32"); otherwise return the matching ``np.dtype``.

    Returns:
        str or np.dtype: the converted dtype representation.

    Raises:
        ValueError: if ``dtype`` matches no supported tensorflow dtype.
    """
    # Imported lazily so this module stays importable without tensorflow.
    import tensorflow as tf

    # Supported tf dtype -> (name string, numpy dtype).
    # FIX: np.bool_ replaces np.bool — the np.bool alias was removed in
    # numpy 1.24 and raised AttributeError at runtime.
    _supported = (
        (tf.float16, "float16", np.float16),
        (tf.float32, "float32", np.float32),
        (tf.float64, "float64", np.float64),
        (tf.int8, "int8", np.int8),
        (tf.uint8, "uint8", np.uint8),
        (tf.uint16, "uint16", np.uint16),
        (tf.uint32, "uint32", np.uint32),
        (tf.uint64, "uint64", np.uint64),
        (tf.int16, "int16", np.int16),
        (tf.int32, "int32", np.int32),
        (tf.int64, "int64", np.int64),
        (tf.bool, "bool", np.bool_),
    )
    # Linear scan with == (not a dict lookup) to preserve the original
    # semantics: tf.DType.__eq__ also matches equivalent numpy dtypes.
    for tf_dtype, dtype_str, np_dtype in _supported:
        if dtype == tf_dtype:
            return dtype_str if to_string else np.dtype(np_dtype)

    # Deliberately unsupported: tf.bfloat16, tf.complex64, tf.complex128,
    # tf.string, the quantized types (tf.qint8, tf.quint8, tf.qint16,
    # tf.quint16, tf.qint32), tf.resource and tf.variant — none has a
    # faithful numpy string/dtype mapping here.
    raise ValueError("Unsupported dtype %s" % dtype)
22108

23109

24110
def copy_feed_spec(feed_spec):
@@ -132,7 +218,7 @@ def to_paddle(self, feed_vars=None):
132218

133219
# Check shape and dtype
134220
var_shape = var.shape
135-
var_dtype = paddle_api.convert_dtype(
221+
var_dtype = _convert_paddle_dtype(
136222
var.dtype, to_string=True)
137223
value = check_shape_and_dtype(var_shape, var_dtype, value)
138224

@@ -173,7 +259,7 @@ def _to_other(self, target_framework, feed_vars=None):
173259
var = feed_list[i]
174260
var_shape = var.shape
175261
if target_framework == "tensorflow":
176-
var_dtype = tensorflow_api.convert_dtype(
262+
var_dtype = _convert_tensorflow_dtype(
177263
var.dtype, to_string=True)
178264
value = check_shape_and_dtype(var_shape, var_dtype, value)
179265

api/common/paddle_api_benchmark.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

api/common/paddle_op_benchmark.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import abc
16+
import six
1517
import sys
1618
import json
1719
import time
@@ -279,6 +281,10 @@ def generate_gradients(self, targets, inputs):
279281
class PaddleOpBenchmarkBase(BenchmarkBase):
280282
def __init__(self, testing_mode):
281283
super(PaddleOpBenchmarkBase, self).__init__(testing_mode)
284+
if self._testing_mode == "static":
285+
paddle.enable_static()
286+
else:
287+
paddle.disable_static()
282288

283289
def variable(self, name, shape, dtype, value=None, stop_gradient=False):
284290
return self._helper.variable(name, shape, dtype, value, stop_gradient)
@@ -558,6 +564,24 @@ def _get_output_stats(self, use_gpu, config, runtimes, walltimes=None):
558564
return stats
559565

560566

567+
@six.add_metaclass(abc.ABCMeta)
class PaddleAPIBenchmarkBase(PaddleOpBenchmarkBase):
    """Static-graph paddle API benchmark base.

    Backward-compatibility adapter: legacy subclasses implement the old
    ``build_program`` hook, and ``build_graph`` bridges it to the newer
    ``PaddleOpBenchmarkBase`` feed/fetch interface.
    """

    def __init__(self):
        # "static" makes the parent __init__ call paddle.enable_static().
        super(PaddleAPIBenchmarkBase, self).__init__("static")
        self.scope = None       # NOTE(review): presumably an execution scope set by the runner — confirm
        self.feed_vars = None   # populated by the subclass in build_program()
        self.fetch_vars = None  # populated by the subclass in build_program()

    @abc.abstractmethod
    def build_program(self, config=None):
        """Build the benchmarked program; must set feed_vars and fetch_vars."""
        pass

    def build_graph(self, config=None):
        """Delegate to build_program, then expose its feed/fetch variables
        under the attribute names the base class expects."""
        self.build_program(config)
        self.feed_list = self.feed_vars
        self.fetch_list = self.fetch_vars
583+
584+
561585
class PaddleDynamicAPIBenchmarkBase(PaddleOpBenchmarkBase):
562586
def __init__(self):
563587
super(PaddleDynamicAPIBenchmarkBase, self).__init__("dynamic")

api/common/tensorflow_api_benchmark.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -95,62 +95,6 @@ def __exit__(self, exception_type, exception_value, traceback):
9595
return self
9696

9797

98-
def convert_dtype(dtype, to_string=True):
99-
def _trans(to_string, dtype_str, np_dtype):
100-
dtype = dtype_str if to_string else np.dtype(np_dtype)
101-
return dtype
102-
103-
if dtype == tf.float16:
104-
# tf.float16: 16-bit half-precision floating-point.
105-
return _trans(to_string, "float16", np.float16)
106-
elif dtype == tf.float32:
107-
# tf.float32: 32-bit single-precision floating-point.
108-
return _trans(to_string, "float32", np.float32)
109-
elif dtype == tf.float64:
110-
# tf.float64: 64-bit double-precision floating-point.
111-
return _trans(to_string, "float64", np.float64)
112-
elif dtype == tf.int8:
113-
# tf.int8: 8-bit signed integer.
114-
return _trans(to_string, "int8", np.int8)
115-
elif dtype == tf.uint8:
116-
# tf.uint8: 8-bit unsigned integer.
117-
return _trans(to_string, "uint8", np.uint8)
118-
elif dtype == tf.uint16:
119-
# tf.uint16: 16-bit unsigned integer.
120-
return _trans(to_string, "uint16", np.uint16)
121-
elif dtype == tf.uint32:
122-
# tf.uint32: 32-bit unsigned integer.
123-
return _trans(to_string, "uint32", np.uint32)
124-
elif dtype == tf.uint64:
125-
# tf.uint64: 64-bit unsigned integer.
126-
return _trans(to_string, "uint64", np.uint64)
127-
elif dtype == tf.int16:
128-
# tf.int16: 16-bit signed integer.
129-
return _trans(to_string, "int16", np.int16)
130-
elif dtype == tf.int32:
131-
# tf.int32: 32-bit signed integer.
132-
return _trans(to_string, "int32", np.int32)
133-
elif dtype == tf.int64:
134-
# tf.int64: 64-bit signed integer.
135-
return _trans(to_string, "int64", np.int64)
136-
elif dtype == tf.bool:
137-
# tf.bool: Boolean.
138-
return _trans(to_string, "bool", np.bool)
139-
else:
140-
# tf.bfloat16: 16-bit truncated floating-point.
141-
# tf.complex64: 64-bit single-precision complex.
142-
# tf.complex128: 128-bit double-precision complex.
143-
# tf.string: String.
144-
# tf.qint8: Quantized 8-bit signed integer.
145-
# tf.quint8: Quantized 8-bit unsigned integer.
146-
# tf.qint16: Quantized 16-bit signed integer.
147-
# tf.quint16: Quantized 16-bit unsigned integer.
148-
# tf.qint32: Quantized 32-bit signed integer.
149-
# tf.resource: Handle to a mutable resource.
150-
# tf.variant: Values of arbitrary types.
151-
raise ValueError("Unsupported dtype %s" % dtype)
152-
153-
15498
@six.add_metaclass(abc.ABCMeta)
15599
class TensorflowAPIBenchmarkBase(object):
156100
def __init__(self):

api/tests_v2/common_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
package_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
3232
sys.path.append(package_path)
3333

34-
from common.paddle_api_benchmark import PaddleAPIBenchmarkBase
34+
from common.paddle_op_benchmark import PaddleAPIBenchmarkBase
3535
from common.tensorflow_api_benchmark import TensorflowAPIBenchmarkBase
3636
from common.api_param import APIConfig
3737
from common.main import test_main, test_main_without_json

ci/scripts/run_test.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ function run_api(){
9292
fail_name=()
9393
for name in ${API_NAMES[@]}
9494
do
95-
[[ "$name" == "common_import" ]] || continue
95+
if [[ "$name" == "common_import" ]]; then
96+
continue
97+
fi
9698
for device_type in "GPU" "CPU"
9799
do
98100
[ $device_type == "GPU" ] && device_limit="" || device_limit="env CUDA_VISIBLE_DEVICES="

0 commit comments

Comments
 (0)