From 5a24d70bc1243327bc698b2304f8e59aa2045c43 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Wed, 27 May 2026 16:28:13 -0500 Subject: [PATCH] refactor(compression): hoist numpy dtype map into tensor_type Add a tensor_type module that holds the single mapping from a TFLite TensorType to a numpy dtype, and convert view.py to use it. The mapping was inlined in view.py; centralizing it gives the compression tooling one place to maintain as more callers need to read tensor buffers as numpy arrays. tensor_type.to_numpy() raises ValueError for types with no clean numpy equivalent (STRING, RESOURCE, VARIANT, BFLOAT16, and the sub-byte integer types) instead of silently returning a wrong dtype. Only types with an unambiguous little-endian numpy representation are mapped. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 21 +++++++ .../lite/micro/compression/tensor_type.py | 61 +++++++++++++++++++ .../micro/compression/tensor_type_test.py | 45 ++++++++++++++ tensorflow/lite/micro/compression/view.py | 17 +----- 4 files changed, 129 insertions(+), 15 deletions(-) create mode 100644 tensorflow/lite/micro/compression/tensor_type.py create mode 100644 tensorflow/lite/micro/compression/tensor_type_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 408fc33912e..39e35fa7797 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -247,6 +247,26 @@ py_test( ], ) +tflm_py_library( + name = "tensor_type", + srcs = ["tensor_type.py"], + deps = [ + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + +tflm_py_test( + name = "tensor_type_test", + size = "small", + srcs = ["tensor_type_test.py"], + deps = [ + ":tensor_type", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ @@ -255,6 +275,7 @@ tflm_py_binary( target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":metadata_py", + ":tensor_type", "//tensorflow/lite/python:schema_py", requirement("absl_py"), requirement("bitarray"), diff --git a/tensorflow/lite/micro/compression/tensor_type.py b/tensorflow/lite/micro/compression/tensor_type.py new file mode 100644 index 00000000000..90df0a356fc --- /dev/null +++ b/tensorflow/lite/micro/compression/tensor_type.py @@ -0,0 +1,61 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Single source of truth for mapping a TFLite TensorType to a numpy dtype. + +Compression tooling reads tensor buffer bytes as numpy arrays, so it needs to +know the element type. Only the TensorTypes with a clean numpy equivalent are +mapped; anything else raises rather than silently guessing a type. +""" + +import numpy as np + +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# TFLite buffers are little-endian, so the dtypes are pinned to little-endian +# byte order to keep np.frombuffer correct on any host. +_TO_NUMPY = { + tflite.TensorType.FLOAT16: np.dtype(" name, for readable error messages. +_NAMES = { + value: name + for name, value in vars(tflite.TensorType).items() + if not name.startswith("_") +} + + +def to_numpy(tensor_type: int) -> np.dtype: + """Return the little-endian numpy dtype for a TFLite TensorType. + + Raises: + ValueError: if the type has no clean numpy equivalent (e.g. STRING, + RESOURCE, VARIANT, BFLOAT16, or the sub-byte INT4/UINT4/INT2 types). + """ + try: + return _TO_NUMPY[tensor_type] + except KeyError: + name = _NAMES.get(tensor_type, "?") + raise ValueError( + f"no numpy dtype for TFLite TensorType {name} ({tensor_type})") diff --git a/tensorflow/lite/micro/compression/tensor_type_test.py b/tensorflow/lite/micro/compression/tensor_type_test.py new file mode 100644 index 00000000000..58c3446eb0a --- /dev/null +++ b/tensorflow/lite/micro/compression/tensor_type_test.py @@ -0,0 +1,45 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import tensor_type + + +class TensorTypeTest(unittest.TestCase): + + def test_maps_known_types_to_little_endian_dtypes(self): + self.assertEqual(tensor_type.to_numpy(tflite.TensorType.INT8), + np.dtype(" np.ndarray: model_tensor = model_subgraph.tensors[coordinates.tensor_index] value_buffer = self.model.buffers[metadata.valueBuffer] values = np.frombuffer(bytes(value_buffer.data), - dtype=_NP_DTYPES[model_tensor.type]) + dtype=tensor_type.to_numpy(model_tensor.type)) values_per_table = 2**metadata.indexBitwidth tables = len(values) // values_per_table values = values.reshape((tables, values_per_table))