|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from __future__ import print_function |
16 | | - |
17 | 15 | import collections |
18 | 16 | import numpy as np |
19 | 17 |
|
20 | | -from . import paddle_api_benchmark as paddle_api |
21 | | -from . import tensorflow_api_benchmark as tensorflow_api |
| 18 | + |
| 19 | +def _convert_paddle_dtype(dtype, to_string=True): |
| 20 | + import paddle |
| 21 | + |
| 22 | + def _trans(to_string, dtype_str, np_dtype): |
| 23 | + dtype = dtype_str if to_string else np.dtype(np_dtype) |
| 24 | + return dtype |
| 25 | + |
| 26 | + if not isinstance(dtype, paddle.fluid.core.VarDesc.VarType): |
| 27 | + raise TypeError("dtype is not of type fluid.core.VarDesc.VarType") |
| 28 | + if dtype == paddle.fluid.core.VarDesc.VarType.FP32: |
| 29 | + return _trans(to_string, "float32", np.float32) |
| 30 | + elif dtype == paddle.fluid.core.VarDesc.VarType.FP64: |
| 31 | + return _trans(to_string, "float64", np.float64) |
| 32 | + elif dtype == paddle.fluid.core.VarDesc.VarType.FP16: |
| 33 | + return _trans(to_string, "float16", np.float16) |
| 34 | + elif dtype == paddle.fluid.core.VarDesc.VarType.INT32: |
| 35 | + return _trans(to_string, "int32", np.int32) |
| 36 | + elif dtype == paddle.fluid.core.VarDesc.VarType.INT16: |
| 37 | + return _trans(to_string, "int16", np.int16) |
| 38 | + elif dtype == paddle.fluid.core.VarDesc.VarType.INT64: |
| 39 | + return _trans(to_string, "int64", np.int64) |
| 40 | + elif dtype == paddle.fluid.core.VarDesc.VarType.BOOL: |
| 41 | + return _trans(to_string, "bool", np.bool) |
| 42 | + elif dtype == paddle.fluid.core.VarDesc.VarType.INT16: |
| 43 | + return _trans(to_string, "uint16", np.uint16) |
| 44 | + elif dtype == paddle.fluid.core.VarDesc.VarType.UINT8: |
| 45 | + return _trans(to_string, "uint8", np.uint8) |
| 46 | + elif dtype == paddle.fluid.core.VarDesc.VarType.INT8: |
| 47 | + return _trans(to_string, "int8", np.int8) |
| 48 | + else: |
| 49 | + raise ValueError("Unsupported dtype %s" % dtype) |
| 50 | + |
| 51 | + |
| 52 | +def _convert_tensorflow_dtype(dtype, to_string=True): |
| 53 | + import tensorflow as tf |
| 54 | + |
| 55 | + def _trans(to_string, dtype_str, np_dtype): |
| 56 | + dtype = dtype_str if to_string else np.dtype(np_dtype) |
| 57 | + return dtype |
| 58 | + |
| 59 | + if dtype == tf.float16: |
| 60 | + # tf.float16: 16-bit half-precision floating-point. |
| 61 | + return _trans(to_string, "float16", np.float16) |
| 62 | + elif dtype == tf.float32: |
| 63 | + # tf.float32: 32-bit single-precision floating-point. |
| 64 | + return _trans(to_string, "float32", np.float32) |
| 65 | + elif dtype == tf.float64: |
| 66 | + # tf.float64: 64-bit double-precision floating-point. |
| 67 | + return _trans(to_string, "float64", np.float64) |
| 68 | + elif dtype == tf.int8: |
| 69 | + # tf.int8: 8-bit signed integer. |
| 70 | + return _trans(to_string, "int8", np.int8) |
| 71 | + elif dtype == tf.uint8: |
| 72 | + # tf.uint8: 8-bit unsigned integer. |
| 73 | + return _trans(to_string, "uint8", np.uint8) |
| 74 | + elif dtype == tf.uint16: |
| 75 | + # tf.uint16: 16-bit unsigned integer. |
| 76 | + return _trans(to_string, "uint16", np.uint16) |
| 77 | + elif dtype == tf.uint32: |
| 78 | + # tf.uint32: 32-bit unsigned integer. |
| 79 | + return _trans(to_string, "uint32", np.uint32) |
| 80 | + elif dtype == tf.uint64: |
| 81 | + # tf.uint64: 64-bit unsigned integer. |
| 82 | + return _trans(to_string, "uint64", np.uint64) |
| 83 | + elif dtype == tf.int16: |
| 84 | + # tf.int16: 16-bit signed integer. |
| 85 | + return _trans(to_string, "int16", np.int16) |
| 86 | + elif dtype == tf.int32: |
| 87 | + # tf.int32: 32-bit signed integer. |
| 88 | + return _trans(to_string, "int32", np.int32) |
| 89 | + elif dtype == tf.int64: |
| 90 | + # tf.int64: 64-bit signed integer. |
| 91 | + return _trans(to_string, "int64", np.int64) |
| 92 | + elif dtype == tf.bool: |
| 93 | + # tf.bool: Boolean. |
| 94 | + return _trans(to_string, "bool", np.bool) |
| 95 | + else: |
| 96 | + # tf.bfloat16: 16-bit truncated floating-point. |
| 97 | + # tf.complex64: 64-bit single-precision complex. |
| 98 | + # tf.complex128: 128-bit double-precision complex. |
| 99 | + # tf.string: String. |
| 100 | + # tf.qint8: Quantized 8-bit signed integer. |
| 101 | + # tf.quint8: Quantized 8-bit unsigned integer. |
| 102 | + # tf.qint16: Quantized 16-bit signed integer. |
| 103 | + # tf.quint16: Quantized 16-bit unsigned integer. |
| 104 | + # tf.qint32: Quantized 32-bit signed integer. |
| 105 | + # tf.resource: Handle to a mutable resource. |
| 106 | + # tf.variant: Values of arbitrary types. |
| 107 | + raise ValueError("Unsupported dtype %s" % dtype) |
22 | 108 |
|
23 | 109 |
|
24 | 110 | def copy_feed_spec(feed_spec): |
@@ -132,7 +218,7 @@ def to_paddle(self, feed_vars=None): |
132 | 218 |
|
133 | 219 | # Check shape and dtype |
134 | 220 | var_shape = var.shape |
135 | | - var_dtype = paddle_api.convert_dtype( |
| 221 | + var_dtype = _convert_paddle_dtype( |
136 | 222 | var.dtype, to_string=True) |
137 | 223 | value = check_shape_and_dtype(var_shape, var_dtype, value) |
138 | 224 |
|
@@ -173,7 +259,7 @@ def _to_other(self, target_framework, feed_vars=None): |
173 | 259 | var = feed_list[i] |
174 | 260 | var_shape = var.shape |
175 | 261 | if target_framework == "tensorflow": |
176 | | - var_dtype = tensorflow_api.convert_dtype( |
| 262 | + var_dtype = _convert_tensorflow_dtype( |
177 | 263 | var.dtype, to_string=True) |
178 | 264 | value = check_shape_and_dtype(var_shape, var_dtype, value) |
179 | 265 |
|
|
0 commit comments