diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index b358fd12adc..812cf7092fd 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -125,7 +125,7 @@ py_test( ":runtime", requirement("numpy"), requirement("tensorflow"), - "//tensorflow/lite/micro/compression", + "//tensorflow/lite/micro/compression:model_editor", ], ) diff --git a/python/tflite_micro/_runtime.cc b/python/tflite_micro/_runtime.cc index 246545fd016..53825f14f0d 100644 --- a/python/tflite_micro/_runtime.cc +++ b/python/tflite_micro/_runtime.cc @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) { .def(py::init([](const py::bytes& data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - tflite::InterpreterConfig config) { - return std::unique_ptr( - new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size, - num_resource_variables, config)); + tflite::InterpreterConfig config, + size_t alt_decompression_memory_size) { + return std::unique_ptr(new InterpreterWrapper( + data.ptr(), registerers_by_name, arena_size, num_resource_variables, + config, alt_decompression_memory_size)); })) .def("PrintAllocations", &InterpreterWrapper::PrintAllocations) .def("Invoke", &InterpreterWrapper::Invoke) diff --git a/python/tflite_micro/interpreter_wrapper.cc b/python/tflite_micro/interpreter_wrapper.cc index 669589890ad..c74ab84736b 100644 --- a/python/tflite_micro/interpreter_wrapper.cc +++ b/python/tflite_micro/interpreter_wrapper.cc @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() { InterpreterWrapper::InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, - size_t arena_size, int num_resource_variables, InterpreterConfig config) { + size_t arena_size, int num_resource_variables, InterpreterConfig config, + size_t alt_decompression_memory_size) + : memory_arena_(new uint8_t[arena_size]), + alt_decompression_memory_(alt_decompression_memory_size > 0 + ? new uint8_t[alt_decompression_memory_size] + : nullptr), + alt_decompression_region_{alt_decompression_memory_.get(), + alt_decompression_memory_size} { interpreter_ = nullptr; // `model_data` is used as a raw pointer beyond the scope of this @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper( "--//:with_compression=true to enable compression support."); } - memory_arena_ = std::unique_ptr(new uint8_t[arena_size]); for (const std::string& registerer : registerers_by_name) { if (!AddCustomOpRegistererByName(registerer.c_str(), &python_ops_resolver_)) { @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper( interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_, resource_variables_); + if (alt_decompression_memory_size > 0) { + TfLiteStatus status = + interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1); + if (status != kTfLiteOk) { + ThrowRuntimeError("TFLM failed to set decompression memory"); + } + } + TfLiteStatus status = interpreter_->AllocateTensors(); if (status != kTfLiteOk) { ThrowRuntimeError("TFLM failed to allocate tensors"); diff --git a/python/tflite_micro/interpreter_wrapper.h b/python/tflite_micro/interpreter_wrapper.h index 9bb31b067fe..d3a156b337a 100644 --- a/python/tflite_micro/interpreter_wrapper.h +++ b/python/tflite_micro/interpreter_wrapper.h @@ -19,6 +19,7 @@ limitations under the License. #include "python/tflite_micro/python_ops_resolver.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_context.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/recording_micro_allocator.h" @@ -40,7 +41,8 @@ class InterpreterWrapper { InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - InterpreterConfig config = InterpreterConfig::kAllocationRecording); + InterpreterConfig config = InterpreterConfig::kAllocationRecording, + size_t alt_decompression_memory_size = 0); ~InterpreterWrapper(); void PrintAllocations(); @@ -57,6 +59,8 @@ class InterpreterWrapper { tflite::RecordingMicroAllocator* recording_allocator_ = nullptr; const PyObject* model_; std::unique_ptr memory_arena_; + std::unique_ptr alt_decompression_memory_; + tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_; tflite::PythonOpsResolver python_ops_resolver_; tflite::MicroInterpreter* interpreter_; }; diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index d895f8c4993..7052972b4a6 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -100,6 +100,7 @@ def __init__( custom_op_registerers, arena_size, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): if model_data is None: raise ValueError("Model must not be None") @@ -122,6 +123,7 @@ def __init__( arena_size, num_resource_variables, _ENUM_TRANSLATOR[intrepreter_config], + alt_decompression_memory_size, ) @classmethod @@ -131,6 +133,7 @@ def from_file( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model .tflite filepath. @@ -140,6 +143,9 @@ def from_file( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -155,6 +161,7 @@ def from_file( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) @classmethod @@ -164,6 +171,7 @@ def from_bytes( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model in byte array. @@ -173,6 +181,9 @@ def from_bytes( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -183,6 +194,7 @@ def from_bytes( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) def print_allocations(self): diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index 3692dd0a43a..01c598374ce 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -12,84 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test compression metadata detection when compression is disabled.""" +"""Test legacy compression metadata detection when compression is disabled.""" import os import numpy as np import tensorflow as tf from tflite_micro.python.tflite_micro import runtime -from tflite_micro.tensorflow.lite.micro import compression +from tflite_micro.tensorflow.lite.micro.compression import model_editor -class CompressionDetectionTest(tf.test.TestCase): - """Test compression metadata detection when compression is disabled.""" +def _create_test_model(): + """Create a simple quantized model for testing.""" + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), + tf.keras.layers.Dense(5, activation='softmax') + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - def _create_test_model(self): - """Create a simple quantized model for testing.""" - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), - tf.keras.layers.Dense(5, activation='softmax') - ]) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] - # Convert to quantized TFLite - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] + def representative_dataset(): + for _ in range(10): + yield [np.random.randn(1, 5).astype(np.float32)] - def representative_dataset(): - for _ in range(10): - yield [np.random.randn(1, 5).astype(np.float32)] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model = converter.convert() + return bytes(tflite_model) if isinstance(tflite_model, + bytearray) else tflite_model - tflite_model = converter.convert() - return bytes(tflite_model) if isinstance(tflite_model, - bytearray) else tflite_model + +def _inject_compression_metadata(model_data): + """Inject raw COMPRESSION_METADATA into a model's flatbuffer metadata. + + This simulates a legacy-compressed model (one that uses the + COMPRESSION_METADATA metadata entry and kernel-level decompression) without + going through compress(), which now produces DECODE-based output. + """ + model = model_editor.read(model_data) + model.metadata["COMPRESSION_METADATA"] = b"\x00" + return bytes(model.build()) + + +class LegacyCompressionDetectionTest(tf.test.TestCase): + """Test that legacy COMPRESSION_METADATA is rejected without the flag.""" def test_regular_model_loads_successfully(self): """Non-compressed models should load without issues.""" - model_data = self._create_test_model() + model_data = _create_test_model() interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) - def test_compressed_model_raises_runtime_error(self): - """Compressed models should raise RuntimeError when compression is disabled.""" - # Create and compress a model - model_data = self._create_test_model() + def test_legacy_compressed_model_raises_runtime_error(self): + """Models with COMPRESSION_METADATA should raise RuntimeError.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - - compressed_model = compression.compress(model_data, spec) - if isinstance(compressed_model, bytearray): - compressed_model = bytes(compressed_model) - - # Should raise RuntimeError with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(compressed_model) - - def test_can_load_regular_after_compressed_failure(self): - """Verify we can still load regular models after compressed model fails.""" - model_data = self._create_test_model() + runtime.Interpreter.from_bytes(legacy_model) - # First try compressed model (should fail) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - compressed_model = compression.compress(model_data, spec) + def test_can_load_regular_after_legacy_failure(self): + """Verify regular models still load after a legacy-compressed failure.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(bytes(compressed_model)) + runtime.Interpreter.from_bytes(legacy_model) - # Then load regular model (should succeed) interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) if __name__ == '__main__': - # Set TF environment variables to suppress warnings + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + # Disable oneDNN to avoid non-deterministic floating point results os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main() diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 408fc33912e..ef1880f5abd 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -123,14 +123,15 @@ py_library( "compress.py", ], deps = [ - ":metadata_py", - ":model_facade", + ":compressor", + ":decode_insert", + ":huffman", + ":lut", + ":model_editor", + ":pruning", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), - requirement("flatbuffers"), - requirement("bitarray"), - requirement("numpy"), ], ) @@ -159,33 +160,54 @@ py_test( target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress", - ":metadata_py", - ":model_facade", + ":compressor", + ":decode_insert", + ":model_editor", ":spec", - ":test_models", "//tensorflow/lite/python:schema_py", - requirement("bitarray"), requirement("numpy"), ], ) -tflm_py_library( - name = "model_facade", - srcs = ["model_facade.py"], +tflm_py_test( + name = "compression_integration_test", + size = "small", + srcs = ["compression_integration_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ + ":compress_lib", + ":decode_insert", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), + requirement("numpy"), ], ) -py_test( - name = "model_facade_test", +tflm_py_test( + name = "proprietary_integration_test", size = "small", - srcs = ["model_facade_test.py"], + srcs = ["proprietary_integration_test.py"], + tags = [ + "manual", + "noasan", + "nomsan", + "noubsan", + ], target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ - ":model_facade", - ":test_models", + ":compress_lib", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), ], ) @@ -227,8 +249,8 @@ py_test( ) tflm_py_library( - name = "test_models", - srcs = ["test_models.py"], + name = "model_editor", + srcs = ["model_editor.py"], deps = [ "//tensorflow/lite/python:schema_py", requirement("flatbuffers"), @@ -236,14 +258,123 @@ tflm_py_library( ], ) -py_test( - name = "test_models_test", +tflm_py_test( + name = "model_editor_test", size = "small", - srcs = ["test_models_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, + srcs = ["model_editor_test.py"], + deps = [ + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + +tflm_py_library( + name = "decode", + srcs = ["decode.py"], +) + +tflm_py_test( + name = "decode_test", + size = "small", + srcs = ["decode_test.py"], + deps = [ + ":decode", + ], +) + +tflm_py_library( + name = "compressor", + srcs = ["compressor.py"], + deps = [ + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "lut", + srcs = ["lut.py"], deps = [ - ":test_models", + ":compressor", + ":decode", + ":model_editor", + ":spec", + requirement("bitarray"), + requirement("numpy"), + ], +) + +tflm_py_test( + name = "lut_test", + size = "small", + srcs = ["lut_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":lut", + ":model_editor", + ":spec", "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + +tflm_py_library( + name = "huffman", + srcs = ["huffman.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "pruning", + srcs = ["pruning.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "decode_insert", + srcs = ["decode_insert.py"], + deps = [ + ":compressor", + ":model_editor", + "//tensorflow/lite/python:schema_py", + ], +) + +tflm_py_test( + name = "decode_insert_test", + size = "small", + srcs = ["decode_insert_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":decode_insert", + ":lut", + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), ], ) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index bd67bf5637b..96b55d94fd7 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -16,22 +16,22 @@ See USAGE. """ -import bitarray -import bitarray.util -from dataclasses import dataclass, field import os import sys import tempfile -from typing import ByteString, Iterable, Optional +import warnings +from typing import ByteString, Iterable, Type import absl.app import absl.flags -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import huffman +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import pruning from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper USAGE = f"""\ @@ -49,219 +49,48 @@ {spec.EXAMPLE_YAML_SPEC} --- -The only compression method currently implemented is "lut", i.e., -Look-Up-Table. This method requires the tensor in the input model to have a -small number of unique values, fewer than or equal to 2**index_bitwidth. LUT -compression collects these values into a lookup table, and rewrites the tensor -as bitwidth-wide integer indices into that lookup table. Presumably, the input -model has been trained or preprocessed in a way that the tensor values -are binned into a meaningful, limited set. -""" - -# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of -# its own containing compression metadata, stored at the buffer index stored at -# the following key in the .tflite flatbuffer's metadata map. -TFLITE_METADATA_KEY = "COMPRESSION_METADATA" - - -class CompressionError(Exception): - """Raised when compression fails for the reason documented in the message.""" - - def __init__(self, message, wrapped_exception=None): - super().__init__(f"{message}: {str(wrapped_exception)}") - self.original_exception = wrapped_exception - - -class _MetadataBuilder: - """Builder for the compression metadata flatbuffer.""" - - def __init__(self): - self._metadata = schema.MetadataT() - self._metadata.subgraphs = [] - - def compile(self) -> bytearray: - """Packs the metadata into a binary array and returns it. - """ - builder = flatbuffers.Builder(1 * 2**10) - root = self._metadata.Pack(builder) - builder.Finish(root) - return builder.Output() - - def subgraph(self, index: int): - """Return subgraph at index, adding subgraphs if necessary. - """ - while len(self._metadata.subgraphs) <= index: - self._add_subgraph() - return self._metadata.subgraphs[index] - - def add_lut_tensor(self, subgraph_id: int): - """Add LUT tensor to the given subgraph and return it. - """ - tensor = schema.LutTensorT() - self.subgraph(subgraph_id).lutTensors.append(tensor) - return tensor - - def _add_subgraph(self): - subgraph = schema.SubgraphT() - subgraph.lutTensors = [] - self._metadata.subgraphs.append(subgraph) - return subgraph +Supported compression methods: + lut: Look-Up-Table compression. Requires the tensor to have a small number of + unique values, fewer than or equal to 2**index_bitwidth. LUT compression + collects these values into a lookup table, and rewrites the tensor as + bitwidth-wide integer indices into that lookup table. -@dataclass -class _LutCompressedArray: - compression_axis: Optional[int] = None - lookup_tables: list[np.ndarray] = field(default_factory=list) - indices: np.ndarray = field(default_factory=lambda: np.array([])) + huffman: Huffman compression using Xtensa-format decode tables. (Not yet + implemented.) - @property - def index_bitwidth(self) -> int: - """Returns the number of bits required to encode the indices.""" - if self.indices is None: - raise ValueError - - max_index = int(np.max(self.indices)) - return max_index.bit_length() or 1 - - -def _lut_compress_array(tensor: np.ndarray, - axis: Optional[int]) -> _LutCompressedArray: - """Compresses the given tensor using lookup tables. - - Args: - tensor (np.ndarray): The tensor to be compressed. - - axis (Optional[int]): The axis along which to compress the tensor. If an - axis is given, a lookup table is created for each slice along the - axis. If axis is None, a single lookup table is used for the entire - tensor. - - Compressing a tensor with a lookup table per slice along a - particular axis is analogous to quantizing a tensor with different - quantization parameters per slice along a particular axis (dimension). - - Returns: - _LutCompressedArray: An object containing the compressed tensor data, - including the lookup tables and indices. - """ - compressed = _LutCompressedArray() - compressed.compression_axis = axis - - if axis is None: - # Compute unique values and indices for the entire tensor - values, indices = np.unique(tensor, return_inverse=True) - compressed.lookup_tables.append(values) - compressed.indices = indices.reshape(tensor.shape) - else: - # Iterate over slices along the compression axis - slice_indices = [] - for slice in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(slice, return_inverse=True) - compressed.lookup_tables.append(values) - indices = indices.reshape(slice.shape) - slice_indices.append(indices) - - # Reconstruct a tensor of indices from the slices - stacked = np.stack(slice_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) - - return compressed - - -def _check_lut_compression(compression) -> spec.LookUpTableCompression: - if len(compression) != 1: - raise CompressionError("Each tensor must have exactly one compression") - if not isinstance(compression[0], spec.LookUpTableCompression): - raise CompressionError('Only "lut" compression may be specified') - - return compression[0] - - -def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: - """Determines the axis along which to compress. - - The axis along which to compress is inferred from the tensor's quantization - parameters. - - Returns: - The axis along which to compress, or None to indicate one value table for - the entire tensor. - - Raises: - CompressionError: If the axis cannot be determined. - """ - q = tensor.quantization - if q is not None \ - and q.scale is not None \ - and q.quantizedDimension < len(tensor.shape): - quantization_channels = len(q.scale) - if quantization_channels == 1: - # Use one value table for the entire tensor - return None - - if quantization_channels == tensor.shape[q.quantizedDimension]: - return q.quantizedDimension - - raise CompressionError( - f"Invalid or no quanitzation parameters from which to " - f"infer the axis along which tensor should be compressed.") - - -def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): - """Applies business logic regarding specified bitwidth. - - It is an error if the bitwidth required to compress a tensor exceeds the - specified bitwith, and a warning if the tensor can be compressed in less than - the specified bitwidth. The latter is allowed, and is not an error, to permit - testing with larger bitwidths without re-binning a model. - """ - if compressed > specified: - raise CompressionError( - f"index_bitwidth too small: {compressed} bits needed to " - f"enumerate unique values in tensor specified in {spec}") - elif compressed < specified: - print( - f"warning: index_bitwidth too large: only {compressed} " - f"bits needed to enumerate unique values in tensor specified in {spec}", - file=sys.stderr) - - -def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: - """Packs indices into a bytearray using bitwidth-sized fields. - """ - endianness = "big" - bits = bitarray.bitarray(endian=endianness) - for i in indices.ravel(): - bits.extend( - bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) - return bits.tobytes() + pruning: Pruning (sparsity) compression for sparse tensors. (Not yet + implemented.) +Compressed models use DECODE operators to decompress tensors at runtime. +""" -def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: - """Packs the value tables of a LutCompressedArray. +# Plugin dispatch table: maps CompressionMethod subclasses to compressor instances +_COMPRESSORS: dict[Type[spec.CompressionMethod], compressor.Compressor] = { + spec.LookUpTableCompression: lut.LutCompressor(), + spec.HuffmanCompression: huffman.HuffmanCompressor(), + spec.PruningCompression: pruning.PruningCompressor(), +} - Pack the value tables of a LutCompressedArray into a bytes object in the - format writable to a value_table buffer in the .tflite flatbuffer. The - tables are concatenated. - """ - buffer = bytearray() - for t in tables: - padding_needed = table_len - len(t) - padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) - buffer.extend(padded.tobytes()) - return buffer +def _get_compressor(method: spec.CompressionMethod) -> compressor.Compressor: + """Get the compressor plugin for a given compression method.""" + compressor_instance = _COMPRESSORS.get(type(method)) + if compressor_instance is None: + raise compressor.CompressionError( + f"No compressor registered for {type(method).__name__}") + return compressor_instance def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: """Applies proper FlatBuffer alignment to a model. - + The Python flatbuffers library doesn't respect `force_align` schema attributes, so we use the C++ wrapper which properly handles alignment requirements. - + Args: model_bytes: The model flatbuffer to align - + Returns: The properly aligned model flatbuffer """ @@ -293,46 +122,66 @@ def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: """Compresses a model .tflite flatbuffer. + Compresses tensors according to the given specs and inserts DECODE operators + to decompress them at runtime. + Args: model_in: the original, uncompressed .tflite flatbuffer specs: an iterable of compression specs, see module spec.py Returns: - A compressed flatbuffer. + A compressed flatbuffer with DECODE operators inserted. """ - model = model_facade.read(model_in) - metadata = _MetadataBuilder() + specs = list(specs) + if not specs: + raise compressor.CompressionError( + "Compression spec is empty; no tensors to compress") - for spec in specs: - try: - tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] - lut_compression = _check_lut_compression(spec.compression) - spec_bitwidth = lut_compression.index_bitwidth - axis = _identify_compression_axis(tensor) - compressed = _lut_compress_array(tensor.array, axis) - _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - - # overwrite tensor data with indices - tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) - - # write value buffer - value_buffer = model.add_buffer() - value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, - 2**spec_bitwidth) - # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) - lut_tensor.tensor = tensor.index - lut_tensor.valueBuffer = value_buffer.index - lut_tensor.indexBitwidth = spec_bitwidth + model = model_editor.read(model_in) + compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} + for tensor_spec in specs: + try: + tensor = model.subgraphs[tensor_spec.subgraph].tensors[ + tensor_spec.tensor] + + # Currently only one compression method per tensor + if len(tensor_spec.compression) != 1: + raise compressor.CompressionError( + "Each tensor must have exactly one compression method") + + method = tensor_spec.compression[0] + plugin = _get_compressor(method) + original_size = len(tensor.buffer.data) if tensor.buffer.data else 0 + result = plugin.compress(tensor, method) + + compressed_size = len(result.encoded_data) + len(result.ancillary_data) + if compressed_size > original_size: + warnings.warn( + f"Compression of tensor {tensor.name!r} (subgraph " + f"{tensor_spec.subgraph}, tensor {tensor_spec.tensor}) resulted " + f"in expansion: {original_size} bytes -> {compressed_size} bytes " + f"(encoded: {len(result.encoded_data)}, " + f"ancillary: {len(result.ancillary_data)})", + stacklevel=2) + + # Replace tensor data with encoded data + tensor.buffer.data = result.encoded_data + + # Store result for DECODE insertion + compression_results[(tensor_spec.subgraph, tensor_spec.tensor)] = result + + except compressor.CompressionError: + raise except Exception as e: - raise CompressionError(f"error compressing {spec}") from e + raise compressor.CompressionError( + f"error compressing {tensor_spec}") from e - # add compression metadata to model - model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) + # Insert DECODE operators into the graph + decode_insert.insert_decode_operators(model, compression_results) - # Compile the model and apply proper alignment - unaligned_model = model.compile() + # Build the model and apply proper alignment + unaligned_model = model.build() return _apply_flatbuffer_alignment(unaligned_model) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 012957acc90..6ee80f200d5 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -11,312 +11,109 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Integration tests for the compression system.""" + +import warnings -import bitarray -import bitarray.util import numpy as np import unittest from tflite_micro.tensorflow.lite.micro.compression import compress -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class TestPackIndices(unittest.TestCase): - - def test_basic_case(self): - indices = np.array([1, 2, 3]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0000]) - self.assertEqual(result, expected_bytes) - - def test_single_element(self): - indices = np.array([10]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_1010]) - self.assertEqual(result, expected_bytes) - - def test_different_bitwidth(self): - indices = np.array([1, 2, 3]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) - self.assertEqual(result, expected_bytes) - - def test_large_numbers(self): - indices = np.array([255, 128, 64]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) - self.assertEqual(result, expected_bytes) - - def test_multidimensional_array(self): - indices = np.array([[1, 2], [3, 4]]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0100]) - self.assertEqual(result, expected_bytes) - - def test_zero_bitwidth(self): - indices = np.array([0, 1, 2]) - bitwidth = 0 - with self.assertRaises(ValueError): - compress._pack_indices(indices, bitwidth) - - def test_empty_array(self): - indices = np.array([]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = b"" - self.assertEqual(result, expected_bytes) - - def test_bitwidth_1(self): - indices = np.array([1, 0, 1, 1, 0, 1]) - bitwidth = 1 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b101101_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_2(self): - indices = np.array([1, 2, 3, 0]) - bitwidth = 2 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b01_10_11_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_3(self): - indices = np.array([1, 3, 5, 7]) - bitwidth = 3 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_5(self): - indices = np.array([1, 2, 16, 31]) - bitwidth = 5 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_7(self): - indices = np.array([1, 64, 127, 32]) - bitwidth = 7 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes( - [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) - self.assertEqual(result, expected_bytes) - - -class TestPackLookupTables(unittest.TestCase): - - def test_int16_positive(self): - tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: - """Helper: extracts the compressed tensor parts for a given spec. - - Returns: - bitwidth - indices - values - """ - subgraph_obj = self.compressed.subgraphs[subgraph] - tensor_obj = subgraph_obj.tensors[tensor] - lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors - lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) - bitwidth = lut_tensor.indexBitwidth - - indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") - n_indices = np.prod(tensor_obj.shape) - indices = indices[:n_indices * bitwidth] # trim possible padding - - value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype) - - return bitwidth, indices, values - - def _make_indices(self, s: str) -> bitarray.bitarray: - """Helper: makes indices from "01" strings for use as expected values.""" - return bitarray.bitarray(s, endian="big") - - def test_compressed_uint8(self): - bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) - self.assertEqual(bitwidth, 4) - - # yapf: disable - expected_indices = self._make_indices(""" - 0000 0001 0010 0011 - 0100 0101 0110 0111 - 1000 1001 1010 1011 - 1100 1101 1110 1111 - """) - # yapf: enable - self.assertEqual(indices, expected_indices) - - expected_values = np.array(range(16), dtype=" [FC1 with weights1] -> output1 + input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2 + + weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which + runs between the two consumers of weights1. + """ + # 4 unique values per tensor for 2-bit LUT compression. Small values avoid + # saturation in chained layers. Different row sums produce varied outputs. + weights1_data = np.array([ + [-1, 0, 0, 1], + [-1, 0, 1, 1], + [-1, 1, 1, 1], + [0, 1, 1, 1], + ], + dtype=np.int8) + weights1 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights1_data, + name="weights1", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + weights2_data = np.array([ + [1, 1, 1, 1], + [1, 1, 2, 2], + [1, 2, 2, 3], + [2, 2, 3, 3], + ], + dtype=np.int8) + weights2 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights2_data, + name="weights2", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + # All tensors need matching quantization for FULLY_CONNECTED + quant = model_editor.Quantization(scales=1.0, zero_points=0) + + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + quantization=quant, + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + quantization=quant, + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + quantization=quant, + ) + intermediate = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="intermediate", + quantization=quant, + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + quantization=quant, + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights1, weights2], + inputs=[input1, input2], + outputs=[output1, output2], + operators=[ + # FC1: uses weights1 + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights1], + outputs=[output1], + ), + # FC2: uses weights2 (runs between FC1 and FC3) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights2], + outputs=[intermediate], + ), + # FC3: uses weights1 (second consumer, after DECODE(weights2)) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[intermediate, weights1], + outputs=[output2], + ), + ], + ) + ]) + return model.build() + + +class AltDecompressionMemoryTest(unittest.TestCase): + """Tests for alternate decompression memory with shared compressed tensors. + + These tests verify correct behavior when compressed tensors are shared + between multiple operators and alternate decompression memory is enabled. + """ + + def test_shared_compressed_tensor_with_alt_memory(self): + """Verify correct results when a shared compressed tensor is used with alt + decompression memory. + + This test uses a graph where a compressed tensor (weights1) is consumed by + two operators (FC1 and FC3), with an intervening DECODE of a different + compressed tensor (weights2) between them. + + The interpreter's alternate decompression memory has a limitation: each + DECODE's Prepare resets the allocation offset to zero. This means all + DECODE outputs are allocated at the same address, so they overwrite each + other. A DECODE output can only be used until the next DECODE runs. + + To work around this limitation, the DECODE insertion code inserts a + separate DECODE immediately before each consumer of a compressed tensor, + rather than sharing one DECODE output among all consumers. + """ + flatbuffer = _build_shared_weights_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, # weights1 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + spec.Tensor( + subgraph=0, + tensor=1, # weights2 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + # Run without alt decompression memory (baseline) + interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Run with alt decompression memory + interp_with_alt = runtime.Interpreter.from_bytes( + bytes(compressed_fb), + alt_decompression_memory_size=256, + ) + + test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8) + test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8) + + interp_no_alt.set_input(test_input1, 0) + interp_no_alt.set_input(test_input2, 1) + interp_no_alt.invoke() + expected1 = interp_no_alt.get_output(0) + expected2 = interp_no_alt.get_output(1) + + interp_with_alt.set_input(test_input1, 0) + interp_with_alt.set_input(test_input2, 1) + interp_with_alt.invoke() + actual1 = interp_with_alt.get_output(0) + actual2 = interp_with_alt.get_output(1) + + np.testing.assert_array_equal( + expected1, actual1, "Output 1 mismatch with alt decompression memory") + np.testing.assert_array_equal( + expected2, actual2, "Output 2 mismatch with alt decompression memory") + + +class HuffmanCompressionTest(unittest.TestCase): + """Integration tests for Huffman compression.""" + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_matches_uncompressed(self): + """Huffman-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_decode_operators_present(self): + """DECODE operators are inserted for Huffman-compressed tensors.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_is_smaller(self): + """Huffman-compressed model is smaller than original.""" + pass + + +class PruningCompressionTest(unittest.TestCase): + """Integration tests for pruning compression.""" + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_matches_uncompressed(self): + """Pruning-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_decode_operators_present(self): + """DECODE operators are inserted for pruning-compressed tensors.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_is_smaller(self): + """Pruning-compressed model is smaller than original.""" + pass + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + unittest.main() diff --git a/tensorflow/lite/micro/compression/compressor.py b/tensorflow/lite/micro/compression/compressor.py new file mode 100644 index 00000000000..3d5a635eb09 --- /dev/null +++ b/tensorflow/lite/micro/compression/compressor.py @@ -0,0 +1,80 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compression plugin interface.""" + +from dataclasses import dataclass +from typing import Protocol + +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + if wrapped_exception: + super().__init__(f"{message}: {str(wrapped_exception)}") + else: + super().__init__(message) + self.original_exception = wrapped_exception + + +@dataclass +class CompressionResult: + """Result of compressing a tensor. + + Attributes: + encoded_data: The compressed tensor data (e.g., packed indices for LUT). + ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific + data). This is the full buffer contents for the ancillary + tensor. + """ + encoded_data: bytes + ancillary_data: bytes + + +class Compressor(Protocol): + """Protocol that compression plugins must implement. + + Each compression method (LUT, Huffman, Pruning) provides a class implementing + this protocol. The compress() function uses duck typing to call the plugin. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """The DecodeType constant for this compression method.""" + ... + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> CompressionResult: + """Compress a tensor according to the specified method. + + Args: + tensor: The tensor to compress. Must have data (tensor.array is not None) + and quantization parameters for axis inference. + method: The compression method spec (e.g., LookUpTableCompression). + + Returns: + CompressionResult with encoded tensor data and ancillary data bytes. + + Raises: + CompressionError: If compression fails (e.g., too many unique values + for specified bitwidth, missing quantization, etc.). + """ + ... diff --git a/tensorflow/lite/micro/compression/decode.py b/tensorflow/lite/micro/compression/decode.py new file mode 100644 index 00000000000..df8428310a3 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode.py @@ -0,0 +1,240 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE compression module.""" + +# Implements the DECODE operator compression scheme described in the +# "TFLM DECODE Operator Design" document, revised May 20, 2025. +# +# The DECODE operator transforms an encoded tensor, alongside a paired +# ancillary data tensor, into a tensor ready for use as input to any +# operator. For example, an encoded tensor might contain compressed +# data, while the paired ancillary data tensor holds the information +# necessary for decompression. The DECODE operator's output is a fully +# decompressed tensor. +# +# DECODE operators are inserted into the TfLite model subgraph +# immediately before each operation that uses a decodable tensor as +# input. +# +# Ancillary Data Tensor +# +# The ancillary data tensor contains the information necessary for +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) +# header, followed by decode-type-specific ancillary data. +# +# DECODE Common Metadata (DCM) +# +# Byte 0: Decode type +# 0-127: TFLM-supported decode operations (see below) +# 128-255: Custom operations requiring application-registered +# handlers +# +# Supported decode types: +# +# 0: LUT decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 1: Huffman decompression using Xtensa format decode tables +# INT8 and INT16 tensor types only, in reference and optimized +# code. +# +# 2: Pruning decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 3-127: Reserved +# +# 128-255: Custom decode types +# Requires user-supplied encoding module and decoding ancillary +# data. +# +# Byte 1: DCM version (currently 1) +# +# Bytes 2-3: Reserved +# +# Bytes 4-15: User-defined +# Used by TFLM decode types to avoid requiring additional alignment +# of metadata or ancillary data. +# +# The 16-byte DCM size ensures that subsequent metadata and ancillary +# data are 128-bit aligned, which is required for some optimized +# decoding operations such as Xtensa LUT decompression. +# +# For TFLM decode types, ancillary data starts immediately after the +# DCM. For custom decode types, the location is determined by +# user-defined metadata. + +from dataclasses import dataclass +from typing import Protocol + + +class DecodeType: + """Decode operation type (0-255). + + Use predefined constants for built-in types or DecodeType.custom() + for custom types: + DecodeType.LUT # 0 + DecodeType.HUFFMAN # 1 + DecodeType.PRUNING # 2 + DecodeType.custom(200) # Custom type 128-255 + """ + + # Built-in decode types (class variables set after class definition) + LUT: 'DecodeType' + HUFFMAN: 'DecodeType' + PRUNING: 'DecodeType' + + def __init__(self, code: int, name: str = None): + """Initialize DecodeType. + + Args: + code: Integer code 0-255 + name: Optional name for the type. If not provided: + - Codes 0-127: Named "TYPE_{code}" + - Codes 128-255: Named "CUSTOM_{code}" + """ + if not 0 <= code <= 255: + raise ValueError(f"Decode type must be 0-255, got {code}") + self.code = code + + # Auto-generate name if not provided + if name is None: + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" + else: + self.name = name + + self._is_custom = code >= 128 + + @property + def is_custom(self) -> bool: + """True if this is a custom decode type (128-255).""" + return self._is_custom + + @classmethod + def custom(cls, code: int) -> 'DecodeType': + """Create custom decode type (128-255). + + Args: + code: Integer code 128-255 + + Returns: + DecodeType with name CUSTOM_{code} + """ + if not 128 <= code <= 255: + raise ValueError(f"Custom decode type must be 128-255, got {code}") + return cls(code) + + def __int__(self): + """Convert to integer for serialization.""" + return self.code + + def __eq__(self, other): + if isinstance(other, DecodeType): + return self.code == other.code + return self.code == other + + def __repr__(self): + return f"DecodeType.{self.name}({self.code})" + + +# Define built-in decode type constants +DecodeType.LUT = DecodeType(0, "LUT") +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") +DecodeType.PRUNING = DecodeType(2, "PRUNING") + + +@dataclass +class DecodeCommonMetadata: + """16-byte DECODE Common Metadata (DCM) header. + + Attributes: + decode_type: Decode operation type. Use DecodeType constants or + DecodeType.custom(code) for custom types. + version: DCM version (currently 1). + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM + decode types to avoid requiring additional alignment of metadata + or ancillary data. + """ + decode_type: DecodeType + version: int = 1 + user_data: bytes = b'\x00' * 12 + + def to_bytes(self) -> bytes: + """Serialize DCM to 16-byte sequence.""" + decode_code = int(self.decode_type) + if not 0 <= self.version <= 255: + raise ValueError(f"version must be 0-255, got {self.version}") + if len(self.user_data) < 12: + # Pad with zeros if user_data is too short + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) + else: + user_data = self.user_data[:12] + + result = bytearray(16) + result[0] = decode_code + result[1] = self.version + # bytes 2-3 remain zero (reserved) + result[4:16] = user_data + return bytes(result) + + +class AncillaryDataSerializer(Protocol): + """Protocol for objects that can serialize ancillary data.""" + + def to_bytes(self) -> bytes: + ... + + +@dataclass +class AncillaryDataTensor: + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. + + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte + DCM header, followed by decode-type-specific ancillary data. + + Attributes: + dcm: The DECODE Common Metadata header. + ancillary_data: The decode-type-specific ancillary data, either as raw bytes + or as an object implementing the AncillaryDataSerializer + protocol. May be None if only the DCM is needed. + """ + dcm: DecodeCommonMetadata + ancillary_data: AncillaryDataSerializer | bytes | None = None + + def with_ancillary_data( + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': + """Create new ADT with ancillary data added. + + Args: + data: Ancillary data to add, either as raw bytes or as an object + implementing AncillaryDataSerializer. + + Returns: + New AncillaryDataTensor with the specified ancillary data. + """ + return AncillaryDataTensor(self.dcm, data) + + def to_bytes(self) -> bytes: + """Serialize entire ADT to bytes. + + Returns: + Byte sequence containing DCM followed by ancillary data (if present). + """ + dcm_bytes = self.dcm.to_bytes() + if self.ancillary_data is None: + return dcm_bytes + if isinstance(self.ancillary_data, bytes): + return dcm_bytes + self.ancillary_data + return dcm_bytes + self.ancillary_data.to_bytes() diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py new file mode 100644 index 00000000000..fa91896e538 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -0,0 +1,280 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE operator insertion into TFLite model graphs. + +This module inserts DECODE operators into a compressed model. DECODE operators +transform encoded tensors (with their paired ancillary data tensors) into +tensors ready for use by downstream operators. + +The DECODE operator is registered as a custom operator named "TFLM_DECODE". +Each DECODE output requires two inputs: the encoded tensor and the ancillary +data tensor (containing the DCM header and decode-type-specific data). +""" + +import warnings +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# Custom operator name for DECODE +DECODE_CUSTOM_OP_NAME = "TFLM_DECODE" + + +@dataclass +class _CompressedTensorInfo: + """Information about a compressed tensor for DECODE insertion.""" + subgraph_idx: int + tensor_idx: int + tensor: model_editor.Tensor + encoded_data: bytes + ancillary_data: bytes + consumers: list[model_editor.Operator] + + +def _find_tensor_consumers( + subgraph: model_editor.Subgraph, + tensor: model_editor.Tensor, +) -> list[model_editor.Operator]: + """Find all operators in subgraph that use tensor as an input.""" + consumers = [] + for op in subgraph.operators: + if tensor in op.inputs: + consumers.append(op) + return consumers + + +def _create_ancillary_tensor( + ancillary_data: bytes, + original_tensor: model_editor.Tensor, +) -> model_editor.Tensor: + """Create an ancillary data tensor for a compressed tensor. + + Args: + ancillary_data: The complete ancillary data (DCM + type-specific data). + original_tensor: The original tensor being decoded, for naming. + + Returns: + A new Tensor containing the ancillary data. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_ancillary" + + return model_editor.Tensor( + shape=(len(ancillary_data), ), + dtype=tflite.TensorType.UINT8, + data=ancillary_data, + name=name, + ) + + +def _create_output_tensor( + original_tensor: model_editor.Tensor, ) -> model_editor.Tensor: + """Create the output tensor for a DECODE operator. + + The output tensor has the same shape, dtype, and quantization as the + original tensor would have when decoded. It has no data---the DECODE + operator produces its values at runtime. + + Args: + original_tensor: The original tensor being decoded. + + Returns: + A new Tensor for the DECODE output. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_decoded" + + return model_editor.Tensor( + shape=original_tensor.shape, + dtype=original_tensor.dtype, + quantization=original_tensor.quantization, + name=name, + ) + + +def _rewire_consumers( + consumers: list[model_editor.Operator], + old_tensor: model_editor.Tensor, + new_tensor: model_editor.Tensor, +) -> None: + """Replace old_tensor with new_tensor in all consumer inputs.""" + for consumer in consumers: + consumer.inputs = [ + new_tensor if t is old_tensor else t for t in consumer.inputs + ] + + +def _rewrite_encoded_tensor( + tensor: model_editor.Tensor, + encoded_data: bytes, +) -> None: + """Rewrite a compressed tensor to hold encoded data. + + The original tensor contained uncompressed values with quantization. After + compression, it holds packed indices (or other encoded form) as raw bytes. + This function updates the tensor in place to reflect its new role. + + Args: + tensor: The tensor to rewrite. + encoded_data: The compressed/encoded data bytes. + """ + tensor.shape = (len(encoded_data), ) + tensor.dtype = tflite.TensorType.UINT8 + tensor.quantization = None + tensor.buffer.data = encoded_data + + +def insert_decode_operators( + model: model_editor.Model, + compression_results: dict[tuple[int, int], compressor.CompressionResult], +) -> None: + """Insert DECODE operators for all compressed tensors. + + This function modifies the model in-place, inserting DECODE operators + before any operator that uses a compressed tensor as input. + + A separate DECODE is inserted before each consumer, rather than sharing one + DECODE output among all consumers. This is required because the interpreter's + alternate decompression memory resets its allocation offset for each DECODE's + Prepare, causing all DECODE outputs to be allocated at the same address. If + two consumers share one DECODE and another DECODE runs between them, the + intervening DECODE overwrites the shared output, corrupting data for the + second consumer. + + For each consumer of a compressed tensor: + 1. Create an ancillary data tensor containing DCM + type-specific data + 2. Create an output tensor with the same shape/dtype as the decoded tensor + 3. Insert a DECODE operator immediately before the consumer + 4. Rewire the consumer to use the DECODE output + + Args: + model: The model to modify in-place. + compression_results: Map from (subgraph_idx, tensor_idx) to the + CompressionResult containing ancillary_data. + """ + # Group compressed tensors by subgraph + by_subgraph: dict[int, list[_CompressedTensorInfo]] = defaultdict(list) + + for (sg_idx, tensor_idx), result in compression_results.items(): + subgraph = model.subgraphs[sg_idx] + tensor = subgraph.tensors[tensor_idx] + consumers = _find_tensor_consumers(subgraph, tensor) + + if not consumers: + # Check if tensor is a subgraph output + is_output = tensor in subgraph.outputs + if is_output: + # TODO: Handle compressed tensors that are subgraph outputs. + # This occurs in multi-subgraph models using IF/WHILE where a + # compressed tensor flows out of a subgraph. + raise NotImplementedError( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) is a subgraph output with no consumers. " + "Compressed subgraph outputs are not yet supported.") + else: + warnings.warn( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) has no consumers and is not a subgraph " + "output. No DECODE operator will be inserted.", + stacklevel=2) + continue + + info = _CompressedTensorInfo( + subgraph_idx=sg_idx, + tensor_idx=tensor_idx, + tensor=tensor, + encoded_data=result.encoded_data, + ancillary_data=result.ancillary_data, + consumers=consumers, + ) + by_subgraph[sg_idx].append(info) + + # Process each subgraph + for sg_idx, tensor_infos in by_subgraph.items(): + subgraph = model.subgraphs[sg_idx] + + # Group tensor infos by consumer so multiple compressed inputs to the + # same operator get batched into a single DECODE. + consumer_to_infos: dict[model_editor.Operator, list[_CompressedTensorInfo]] + consumer_to_infos = defaultdict(list) + for info in tensor_infos: + for consumer in info.consumers: + if info not in consumer_to_infos[consumer]: + consumer_to_infos[consumer].append(info) + + # Sort consumers by position in reverse so insertions don't invalidate + # earlier positions. + sorted_consumers = sorted( + consumer_to_infos.keys(), + key=lambda op: subgraph.operators.index(op), + reverse=True, + ) + + # Cache ancillary tensors by original tensor to avoid duplicates. Each + # DECODE needs its own output tensor, but ancillary data is identical for + # all DECODEs of the same compressed tensor. + ancillary_cache: dict[model_editor.Tensor, model_editor.Tensor] = {} + + # Track tensors to rewrite after all output tensors are created, since + # _create_output_tensor reads the original tensor's shape/dtype/quantization. + tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} + + for consumer in sorted_consumers: + decode_inputs = [] + decode_outputs = [] + + for info in consumer_to_infos[consumer]: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per compressed input) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + decode_inputs.extend([info.tensor, ancillary_tensor]) + decode_outputs.append(output_tensor) + + # Rewire this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Create single DECODE operator for all compressed inputs + decode_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code=DECODE_CUSTOM_OP_NAME, + inputs=decode_inputs, + outputs=decode_outputs, + ) + + # Insert DECODE immediately before this consumer + insert_pos = subgraph.operators.index(consumer) + subgraph.operators.insert(insert_pos, decode_op) + + # Rewrite encoded tensors after all output tensors are created + for tensor, encoded_data in tensors_to_rewrite.items(): + _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py new file mode 100644 index 00000000000..60965b46676 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -0,0 +1,559 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for DECODE operator insertion.""" + +import unittest +import warnings + +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_simple_fc_model(): + """Build a simple model with one FC operator and compressible weights.""" + # yapf: disable + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + # yapf: enable + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model + + +def _build_shared_weights_model(): + """Build model where one tensor is used by multiple operators.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="shared_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights], + outputs=[output1], + ), + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights], + outputs=[output2], + ), + ], + ) + ]) + return model + + +def _make_dummy_ancillary_data(bitwidth=4) -> bytes: + """Create dummy ancillary data for testing.""" + n_entries = 1 << bitwidth + value_tables = bytes(range(1, n_entries + 1)) + value_tables += b'\x00' * ((-len(value_tables)) % 16) + + lut_data = lut.LutAncillaryData( + bitwidth=bitwidth, + value_table_stride=n_entries, + value_tables=value_tables, + ) + dcm = decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.LUT, + user_data=lut_data.to_user_data(), + ) + return dcm.to_bytes() + lut_data.to_bytes() + + +class TestDecodeInsertion(unittest.TestCase): + """Tests for insert_decode_operators function.""" + + def test_insert_single_decode_operator(self): + """DECODE operator inserted before FC that uses compressed weights.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Create compression result + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + # Insert DECODE operators + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 2 operators: DECODE then FC + self.assertEqual(len(sg.operators), 2) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[0].custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + def test_decode_inputs_structure(self): + """DECODE operator has correct inputs: encoded tensor + ancillary.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + + # DECODE has 2 inputs + self.assertEqual(len(decode_op.inputs), 2) + # First input is the encoded tensor (original weights) + self.assertIs(decode_op.inputs[0], weights_tensor) + # Second input is ancillary tensor + self.assertEqual(decode_op.inputs[1].dtype, tflite.TensorType.UINT8) + + def test_decode_output_structure(self): + """DECODE operator output has correct shape and dtype.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Save original properties before rewrite + original_shape = weights_tensor.shape + original_dtype = weights_tensor.dtype + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + output = decode_op.outputs[0] + + # Output matches original (pre-rewrite) tensor shape and dtype + self.assertEqual(output.shape, original_shape) + self.assertEqual(output.dtype, original_dtype) + + def test_consumer_rewired_to_decode_output(self): + """FC operator input rewired to use DECODE output.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + fc_op = model.subgraphs[0].operators[1] + + # FC's second input (weights) should now be DECODE's output + self.assertIs(fc_op.inputs[1], decode_op.outputs[0]) + # Original weights tensor should NOT be in FC inputs + self.assertNotIn(weights_tensor, fc_op.inputs) + + def test_shared_tensor_decode_per_consumer(self): + """Tensor used by multiple ops gets separate DECODE for each consumer.""" + model = _build_shared_weights_model() + weights_tensor = model.subgraphs[0].tensor_by_name("shared_weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC) + self.assertEqual(len(sg.operators), 4) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[3].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + decode_op1 = sg.operators[0] + fc_op1 = sg.operators[1] + decode_op2 = sg.operators[2] + fc_op2 = sg.operators[3] + + # Each FC should use its own DECODE's output + self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0]) + # The two DECODEs should have different outputs + self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0]) + # The two DECODEs should share the same ancillary tensor + self.assertIs(decode_op1.inputs[1], decode_op2.inputs[1]) + + def test_ancillary_tensor_contains_dcm(self): + """Ancillary tensor data contains valid DCM header.""" + model = _build_simple_fc_model() + + ancillary_data = _make_dummy_ancillary_data() + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=ancillary_data, + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary_tensor = decode_op.inputs[1] + + # Ancillary tensor data should match what we provided + self.assertEqual(bytes(ancillary_tensor.array), ancillary_data) + + # Verify DCM header + dcm_bytes = ancillary_tensor.array[:16] + self.assertEqual(dcm_bytes[0], 0) # decode_type = LUT + self.assertEqual(dcm_bytes[1], 1) # DCM version + + def test_no_consumers_no_decode(self): + """Tensor with no consumers gets no DECODE operator and emits warning.""" + # Create model where compressed tensor is not used as input + unused_tensor = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="unused", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + other_weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="other_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[unused_tensor, other_weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, other_weights], + outputs=[output_t], + ) + ], + ) + ]) + + # Compress the unused tensor + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + decode_insert.insert_decode_operators(model, compression_results) + + # Should emit a warning about no consumers + self.assertEqual(len(w), 1) + self.assertIn("no consumers", str(w[0].message)) + self.assertIn("unused", str(w[0].message)) + + # Should still have just 1 operator (no DECODE inserted) + self.assertEqual(len(model.subgraphs[0].operators), 1) + + def test_tensor_naming(self): + """Output and ancillary tensors get appropriate names.""" + model = _build_simple_fc_model() + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary = decode_op.inputs[1] + output = decode_op.outputs[0] + + self.assertEqual(ancillary.name, "weights_ancillary") + self.assertEqual(output.name, "weights_decoded") + + def test_multiple_compressed_inputs_batched(self): + """CONCATENATION with two compressed inputs gets one batched DECODE.""" + weights_a = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_a", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + weights_b = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_b", + quantization=model_editor.Quantization(scales=0.25, zero_points=0), + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights_a, weights_b], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights_a, weights_b], + operators=[concat_op], + ) + ]) + + ancillary_a = _make_dummy_ancillary_data(bitwidth=2) + ancillary_b = _make_dummy_ancillary_data(bitwidth=4) + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=ancillary_a, + ), + (0, 1): + compressor.CompressionResult( + encoded_data=b'\x02\x03', + ancillary_data=ancillary_b, + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + self.assertEqual(decode_op.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(decode_op.custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + + # DECODE has 4 inputs (enc_a, anc_a, enc_b, anc_b) and 2 outputs + self.assertEqual(len(decode_op.inputs), 4) + self.assertEqual(len(decode_op.outputs), 2) + + # Each ancillary tensor carries its own distinct data + self.assertNotEqual(ancillary_a, ancillary_b) + self.assertEqual(bytes(decode_op.inputs[1].array), ancillary_a) + self.assertEqual(bytes(decode_op.inputs[3].array), ancillary_b) + + # CONCATENATION rewired to DECODE outputs + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], decode_op.outputs[1]) + + def test_mixed_compressed_and_uncompressed_inputs(self): + """CONCATENATION with one compressed and one plain input.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + plain = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.zeros((4, 4), dtype=np.int8), + name="plain", + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights, plain], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights, plain], + operators=[concat_op], + ) + ]) + + # Only compress weights, not plain + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=_make_dummy_ancillary_data(), + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + + # DECODE has 2 inputs and 1 output (only the compressed tensor) + self.assertEqual(len(decode_op.inputs), 2) + self.assertEqual(len(decode_op.outputs), 1) + + # CONCATENATION: first input rewired to DECODE output, second unchanged + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], plain) + + def test_encoded_tensor_rewritten(self): + """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + encoded_data = b'\xAB\xCD\xEF' + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + # Original tensor should be rewritten + self.assertEqual(weights_tensor.shape, (len(encoded_data), )) + self.assertEqual(weights_tensor.dtype, tflite.TensorType.UINT8) + self.assertIsNone(weights_tensor.quantization) + self.assertEqual(weights_tensor.buffer.data, encoded_data) + + +class TestHelperFunctions(unittest.TestCase): + """Tests for internal helper functions.""" + + def test_find_tensor_consumers(self): + """_find_tensor_consumers finds all ops using a tensor.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + weights = sg.tensor_by_name("shared_weights") + + consumers = decode_insert._find_tensor_consumers(sg, weights) + + self.assertEqual(len(consumers), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/lite/micro/compression/decode_test.py b/tensorflow/lite/micro/compression/decode_test.py new file mode 100644 index 00000000000..eca3b42b2b4 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_test.py @@ -0,0 +1,155 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import decode + + +class TestDecodeCommonMetadata(unittest.TestCase): + + def testBasicSerialization(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + result = dcm.to_bytes() + + # Should be exactly 16 bytes + self.assertEqual(len(result), 16) + + # Byte 0: decode_type + self.assertEqual(result[0], 0) + + # Byte 1: version (default 1) + self.assertEqual(result[1], 1) + + # Bytes 2-3: reserved (should be zero) + self.assertEqual(result[2], 0) + self.assertEqual(result[3], 0) + + # Bytes 4-15: user_data (default all zeros) + self.assertEqual(result[4:16], b'\x00' * 12) + + def testCustomVersion(self): + dcm = decode.DecodeCommonMetadata(decode_type=1, version=2) + result = dcm.to_bytes() + + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 2) + + def testUserData(self): + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data) + + def testUserDataPadding(self): + # User data shorter than 12 bytes should be padded with zeros + user_data = b'\x01\x02\x03' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + expected = b'\x01\x02\x03' + b'\x00' * 9 + self.assertEqual(result[4:16], expected) + + def testUserDataTruncation(self): + # User data longer than 12 bytes should be truncated + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data[:12]) + + def testDecodeTypeRange(self): + # Valid decode types: 0-255 + decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT).to_bytes() + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(127)).to_bytes() + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.custom(255)).to_bytes() + + # Invalid decode types should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(-1)).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType(256)).to_bytes() + + def testVersionRange(self): + # Valid versions: 0-255 + decode.DecodeCommonMetadata(decode_type=0, version=0).to_bytes() + decode.DecodeCommonMetadata(decode_type=0, version=255).to_bytes() + + # Invalid versions should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=-1).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=256).to_bytes() + + +class TestAncillaryDataTensor(unittest.TestCase): + + def testDcmOnly(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + adt = decode.AncillaryDataTensor(dcm) + result = adt.to_bytes() + + # Should be just the 16-byte DCM + self.assertEqual(len(result), 16) + self.assertEqual(result, dcm.to_bytes()) + + def testWithBytesAncillaryData(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.HUFFMAN) + ancillary = b'\xaa\xbb\xcc\xdd' + adt = decode.AncillaryDataTensor(dcm, ancillary) + result = adt.to_bytes() + + # Should be DCM + ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithAncillaryDataMethod(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.PRUNING) + adt = decode.AncillaryDataTensor(dcm) + + ancillary = b'\x11\x22\x33\x44' + adt_with_data = adt.with_ancillary_data(ancillary) + result = adt_with_data.to_bytes() + + # Original ADT should be unchanged + self.assertEqual(adt.to_bytes(), dcm.to_bytes()) + + # New ADT should have ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithSerializerProtocol(self): + # Test with an object that implements AncillaryDataSerializer + class MockSerializer: + + def to_bytes(self): + return b'\xff\xee\xdd\xcc' + + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType(3)) + serializer = MockSerializer() + adt = decode.AncillaryDataTensor(dcm, serializer) + result = adt.to_bytes() + + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], b'\xff\xee\xdd\xcc') + + +if __name__ == '__main__': + unittest.main() diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py new file mode 100644 index 00000000000..e539827eae4 --- /dev/null +++ b/tensorflow/lite/micro/compression/huffman.py @@ -0,0 +1,60 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Huffman compression plugin (stub). + +This module provides a placeholder for Huffman compression using Xtensa-format +decode tables. The actual implementation is not yet available. + +Supported tensor types (when implemented): INT8, INT16 +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class HuffmanCompressor(compressor.Compressor): + """Huffman compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual Huffman + compression algorithm using Xtensa-format decode tables is not yet + implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.HUFFMAN.""" + return decode.DecodeType.HUFFMAN + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using Huffman encoding. + + Args: + tensor: The tensor to compress. + method: Must be a HuffmanCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Huffman compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py new file mode 100644 index 00000000000..991288f54cc --- /dev/null +++ b/tensorflow/lite/micro/compression/lut.py @@ -0,0 +1,318 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUT (Look-Up Table) compression plugin.""" + +import sys +from dataclasses import dataclass, field +from typing import Optional + +import bitarray +import bitarray.util +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +@dataclass +class LutCompressedArray: + """Intermediate representation of LUT-compressed data. + + Attributes: + compression_axis: The axis along which compression was performed, or None + for per-tensor compression. + lookup_tables: List of value lookup tables. One table for per-tensor + compression, or one per channel for per-channel compression. + indices: Array of indices into the lookup tables, same shape as original. + """ + compression_axis: Optional[int] = None + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None or self.indices.size == 0: + raise ValueError("No indices to compute bitwidth from") + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +@dataclass +class LutAncillaryData: + """LUT-specific ancillary data matching C++ decode_state_lut.cc format. + + The LUT ancillary data uses the DCM user_data bytes (4-15) plus value tables: + - Byte 4: LUT version (currently 1) + - Byte 5: Params (lower 3 bits = bitwidth, 1-7) + - Byte 6: Value table channel stride (elements per channel) + - Bytes 7-15: Reserved (zeros) + - Bytes 16+: Value tables (concatenated, stride elements per channel) + + Attributes: + lut_version: LUT format version (currently 1). + bitwidth: Number of bits per index (1-7). + value_table_stride: Number of elements per channel in value tables. + value_tables: Packed value table data following the DCM. + """ + lut_version: int = 1 + bitwidth: int = 4 + value_table_stride: int = 16 + value_tables: bytes = b'' + + def __post_init__(self): + if not 1 <= self.bitwidth <= 7: + raise ValueError(f"bitwidth must be 1-7, got {self.bitwidth}") + if not 0 <= self.value_table_stride <= 128: + raise ValueError( + f"value_table_stride must be 0-128, got {self.value_table_stride}") + + def to_user_data(self) -> bytes: + """Serialize to 12-byte user_data for DCM bytes 4-15.""" + user_data = bytearray(12) + user_data[0] = self.lut_version + user_data[1] = self.bitwidth & 0x07 + user_data[2] = self.value_table_stride + # Bytes 3-11 (DCM bytes 7-15) remain zero (reserved) + return bytes(user_data) + + def to_bytes(self) -> bytes: + """Serialize for use as AncillaryDataTensor.ancillary_data.""" + # This returns the type-specific data that follows the DCM header. + # For LUT, that's just the value tables since user_data is in DCM. + return self.value_tables + + +def compress_array(tensor: np.ndarray, + axis: Optional[int]) -> LutCompressedArray: + """Compresses the given tensor using lookup tables. + + Args: + tensor: The tensor to be compressed. + axis: The axis along which to compress. If an axis is given, a lookup table + is created for each slice along the axis. If axis is None, a single + lookup table is used for the entire tensor. + + Compressing a tensor with a lookup table per slice along a particular + axis is analogous to quantizing a tensor with different quantization + parameters per slice along a particular axis (dimension). + + Returns: + LutCompressedArray containing lookup tables and indices. + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) + compressed.lookup_tables.append(values) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. Unquantized tensors use per-tensor compression. + + Args: + tensor: The tensor to analyze. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. + + Raises: + CompressionError: If the axis cannot be determined from quantization. + """ + q = tensor.quantization + if q is None: + return None + + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + + if quantization_channels == 1: + return None + + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis + + raise compressor.CompressionError( + "Invalid or no quantization parameters from which to " + "infer the axis along which tensor should be compressed.") + + +def check_bitwidth(compressed: int, specified: int, tensor_spec: spec.Tensor): + """Validates that the specified bitwidth is sufficient. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + + Args: + compressed: The bitwidth required by the compressed data. + specified: The bitwidth specified in the compression spec. + tensor_spec: The tensor spec, for error messages. + + Raises: + CompressionError: If specified bitwidth is too small. + """ + if compressed > specified: + raise compressor.CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {tensor_spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in " + f"{tensor_spec}", + file=sys.stderr) + + +def pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + + Args: + indices: Array of indices to pack. + bitwidth: Number of bits per index. + + Returns: + Packed bytes with indices in big-endian bit order. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables are concatenated. + + Args: + tables: List of numpy arrays containing lookup table values. + table_len: Length to pad each table to (typically 2**bitwidth). + + Returns: + Packed bytes containing all tables concatenated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return bytes(buffer) + + +class LutCompressor(compressor.Compressor): + """LUT compression plugin implementing the Compressor protocol.""" + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.LUT.""" + return decode.DecodeType.LUT + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using LUT compression. + + Args: + tensor: The tensor to compress. + method: Must be a LookUpTableCompression instance. + + Returns: + CompressionResult with packed indices and ancillary data. + + Raises: + CompressionError: If compression fails. + """ + if not isinstance(method, spec.LookUpTableCompression): + raise compressor.CompressionError( + f"LutCompressor requires LookUpTableCompression, got {type(method)}") + + if tensor.array is None: + raise compressor.CompressionError("Tensor has no data to compress") + + spec_bitwidth = method.index_bitwidth + axis = identify_compression_axis(tensor) + compressed = compress_array(tensor.array, axis) + # Note: check_bitwidth requires a spec.Tensor but we don't have it here. + # We'll do a simpler check. + actual_bitwidth = compressed.index_bitwidth + if actual_bitwidth > spec_bitwidth: + raise compressor.CompressionError( + f"index_bitwidth too small: {actual_bitwidth} bits needed, " + f"but only {spec_bitwidth} specified") + elif actual_bitwidth < spec_bitwidth: + print( + f"warning: index_bitwidth larger than necessary: only " + f"{actual_bitwidth} bits needed, but {spec_bitwidth} specified", + file=sys.stderr) + + # Pack indices into bytes + encoded_data = pack_indices(compressed.indices, spec_bitwidth) + + # Pack value tables + table_len = max(len(t) for t in compressed.lookup_tables) + value_tables_bytes = pack_lookup_tables(compressed.lookup_tables, + table_len) + + # Build ancillary data + lut_data = LutAncillaryData( + lut_version=1, + bitwidth=spec_bitwidth, + value_table_stride=table_len, + value_tables=value_tables_bytes, + ) + + # Build complete ancillary data tensor bytes: DCM header + value tables + dcm = decode.DecodeCommonMetadata( + decode_type=self.decode_type, + user_data=lut_data.to_user_data(), + ) + ancillary_data = dcm.to_bytes() + lut_data.to_bytes() + + return compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=ancillary_data, + ) diff --git a/tensorflow/lite/micro/compression/lut_test.py b/tensorflow/lite/micro/compression/lut_test.py new file mode 100644 index 00000000000..d01dcfd4260 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut_test.py @@ -0,0 +1,405 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for LUT compression plugin.""" + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestCompressArray(unittest.TestCase): + """Tests for the compress_array function.""" + + def test_per_tensor_basic(self): + """Per-tensor compression extracts unique values.""" + array = np.array([1, 2, 1, 2, 3, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertIsNone(compressed.compression_axis) + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1, 2, 3]) + # Indices should map back to original values + reconstructed = compressed.lookup_tables[0][compressed.indices] + np.testing.assert_array_equal(reconstructed, array) + + def test_per_tensor_preserves_shape(self): + """Indices array has same shape as input.""" + # yapf: disable + array = np.array([[1, 2], + [3, 1], + [2, 3]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(compressed.indices.shape, array.shape) + + def test_per_channel_axis0(self): + """Per-channel compression along axis 0.""" + # Each row gets its own value table + # yapf: disable + array = np.array([[1, 1, 1], + [5, 5, 5], + [9, 9, 9]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=0) + + self.assertEqual(compressed.compression_axis, 0) + self.assertEqual(len(compressed.lookup_tables), 3) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + np.testing.assert_array_equal(compressed.lookup_tables[2], [9]) + + def test_per_channel_axis1(self): + """Per-channel compression along axis 1.""" + # Each column gets its own value table + # yapf: disable + array = np.array([[1, 5], + [1, 5], + [1, 5]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=1) + + self.assertEqual(compressed.compression_axis, 1) + self.assertEqual(len(compressed.lookup_tables), 2) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + + def test_single_value(self): + """Array with single unique value.""" + array = np.array([7, 7, 7, 7], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [7]) + np.testing.assert_array_equal(compressed.indices, [0, 0, 0, 0]) + + def test_bitwidth_calculation(self): + """Index bitwidth is computed correctly.""" + # 3 unique values -> 2 bits needed + array = np.array([0, 1, 2], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 4 unique values -> 2 bits needed + array = np.array([0, 1, 2, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 5 unique values -> 3 bits needed + array = np.array([0, 1, 2, 3, 4], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 3) + + def test_bitwidth_single_value(self): + """Single unique value requires 1 bit.""" + array = np.array([42, 42, 42], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 1) + + +class TestPackIndices(unittest.TestCase): + """Tests for the pack_indices function.""" + + def test_4bit_packing(self): + """Pack indices into 4-bit fields.""" + indices = np.array([1, 2, 3, 0]) + result = lut.pack_indices(indices, bitwidth=4) + # Big-endian: 0001 0010 | 0011 0000 = 0x12 0x30 + self.assertEqual(result, bytes([0x12, 0x30])) + + def test_2bit_packing(self): + """Pack indices into 2-bit fields.""" + indices = np.array([0, 1, 2, 3]) + result = lut.pack_indices(indices, bitwidth=2) + # Big-endian: 00 01 10 11 = 0x1B + self.assertEqual(result, bytes([0x1B])) + + def test_3bit_packing(self): + """Pack indices into 3-bit fields.""" + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + result = lut.pack_indices(indices, bitwidth=3) + # 000 001 010 011 | 100 101 110 111 + # 00000101 | 00111001 | 01110111 = 0x05 0x39 0x77 + self.assertEqual(result, bytes([0x05, 0x39, 0x77])) + + def test_1bit_packing(self): + """Pack indices into 1-bit fields.""" + indices = np.array([0, 1, 0, 1, 1, 0, 1, 0]) + result = lut.pack_indices(indices, bitwidth=1) + # 0 1 0 1 1 0 1 0 = 0x5A + self.assertEqual(result, bytes([0x5A])) + + def test_multidimensional_flattens(self): + """Multidimensional indices are flattened row-major.""" + # yapf: disable + indices = np.array([[0, 1], + [2, 3]]) + # yapf: enable + result = lut.pack_indices(indices, bitwidth=4) + # 0000 0001 | 0010 0011 = 0x01 0x23 + self.assertEqual(result, bytes([0x01, 0x23])) + + +class TestPackLookupTables(unittest.TestCase): + """Tests for the pack_lookup_tables function.""" + + def test_single_table_int8(self): + """Pack single INT8 lookup table.""" + tables = [np.array([10, 20, 30], dtype=np.int8)] + result = lut.pack_lookup_tables(tables, table_len=4) + # Values: 10, 20, 30, 0 (padding) + self.assertEqual(result, bytes([10, 20, 30, 0])) + + def test_multiple_tables(self): + """Pack multiple lookup tables.""" + tables = [ + np.array([1, 2], dtype=np.int8), + np.array([3, 4], dtype=np.int8), + ] + result = lut.pack_lookup_tables(tables, table_len=4) + # Table 1: 1, 2, 0, 0 | Table 2: 3, 4, 0, 0 + self.assertEqual(result, bytes([1, 2, 0, 0, 3, 4, 0, 0])) + + def test_int16_little_endian(self): + """INT16 values are packed in native byte order.""" + tables = [np.array([0x1234, 0x5678], dtype=' tflite.QuantizationParametersT: + """Convert to TFLite schema object.""" + q = tflite.QuantizationParametersT() + + # Normalize to lists + scales = [self.scales] if isinstance(self.scales, + (int, float)) else self.scales + zeros = [self.zero_points] if isinstance(self.zero_points, + int) else self.zero_points + + q.scale = scales + q.zeroPoint = zeros + if self.axis is not None: + q.quantizedDimension = self.axis + + return q + + +class Tensor: + """Tensor specification wrapping a TensorT flatbuffer object. + + Provides clean APIs for common fields (shape, dtype, name, buffer, + quantization) while preserving all other TensorT fields during + read-modify-write. + + Supports both buffer= and data= parameters for flexibility: + - buffer=: Explicitly provide a Buffer object (can be shared between tensors) + - data=: Convenience parameter that auto-creates a Buffer + + Cannot specify both buffer and data at initialization. + """ + + def __init__(self, + shape=None, + dtype=None, + buffer=None, + data=None, + quantization=None, + name=None, + _fb: tflite.TensorT = None): + """Initialize Tensor. + + Args: + shape: Tensor shape as tuple + dtype: TensorType enum value + buffer: Optional Buffer object (for explicit buffer sharing) + data: Optional numpy array or bytes (convenience, creates Buffer) + quantization: Optional Quantization object + name: Optional tensor name + _fb: Optional TensorT for wrapping existing flatbuffer object + + Raises: + ValueError: If both buffer and data are specified + """ + if data is not None and buffer is not None: + raise ValueError("Cannot specify both data and buffer") + + # Use provided TensorT or create new one + self._fb = _fb if _fb is not None else tflite.TensorT() + self._index = None + + # Buffer object (managed separately; _fb.buffer is just an index) + self.buffer = buffer + + # Quantization object (managed separately; synced to _fb on compile) + self.quantization = quantization + + # Set fields if provided (these override any values in _fb) + if shape is not None: + self.shape = shape + if dtype is not None: + self.dtype = dtype + if name is not None: + self.name = name + + # Convert data to buffer if provided + if data is not None: + buf_data = data if isinstance(data, bytes) else data.tobytes() + self.buffer = Buffer(data=buf_data) + + @property + def shape(self) -> tuple: + """Tensor shape as tuple.""" + return tuple(self._fb.shape) if self._fb.shape is not None else () + + @shape.setter + def shape(self, value): + self._fb.shape = list(value) + + @property + def dtype(self) -> tflite.TensorType: + """Tensor data type.""" + return self._fb.type + + @dtype.setter + def dtype(self, value: tflite.TensorType): + self._fb.type = value + + @property + def name(self) -> Optional[str]: + """Tensor name for debugging.""" + n = self._fb.name + if isinstance(n, bytes): + return n.decode('utf-8') + return n + + @name.setter + def name(self, value: Optional[str]): + self._fb.name = value + + @property + def array(self) -> Optional[np.ndarray]: + """Get tensor data as properly-shaped numpy array. + + Returns: + numpy array with shape matching tensor.shape and dtype matching + tensor.dtype, or None if tensor has no data. + + For low-level byte access, use tensor.buffer.data instead. + """ + if self.buffer is None: + return None + return np.frombuffer(self.buffer.data, + dtype=_dtype_to_numpy(self.dtype)).reshape(self.shape) + + @array.setter + def array(self, value: np.ndarray): + """Set tensor data from numpy array. + + Args: + value: New tensor data as numpy array. Will be converted to bytes + using tobytes() and stored in the buffer. + + Creates a new Buffer if tensor has no buffer, or updates the existing + buffer's data in place. + + For low-level byte access, use tensor.buffer.data instead. + """ + buf_data = value.tobytes() + if self.buffer is None: + self.buffer = Buffer(data=buf_data) + else: + self.buffer.data = buf_data + + @property + def index(self) -> Optional[int]: + """Tensor index in the subgraph's tensor list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + @property + def numpy_dtype(self) -> np.dtype: + """Get numpy dtype corresponding to tensor's TFLite dtype. + + Returns: + numpy dtype object for use with np.frombuffer, np.array, etc. + """ + return _dtype_to_numpy(self.dtype) + + +class OperatorCode: + """Operator code specification wrapping an OperatorCodeT flatbuffer object. + + Provides clean APIs for common fields (builtin_code, custom_code, version) + while preserving all other OperatorCodeT fields during read-modify-write. + """ + + def __init__(self, + builtin_code: tflite.BuiltinOperator = None, + custom_code: Optional[str] = None, + version: int = 1, + _fb: tflite.OperatorCodeT = None): + """Initialize OperatorCode. + + Args: + builtin_code: BuiltinOperator enum value + custom_code: Custom operator name (for CUSTOM opcode) + version: Operator version + _fb: Optional OperatorCodeT for wrapping existing flatbuffer object + """ + # Use provided OperatorCodeT or create new one + self._fb = _fb if _fb is not None else tflite.OperatorCodeT() + + # Set fields if provided (these override any values in _fb) + if builtin_code is not None: + self.builtin_code = builtin_code + if custom_code is not None: + self.custom_code = custom_code + if version != 1 or _fb is None: + self.version = version + + @property + def builtin_code(self) -> tflite.BuiltinOperator: + """Builtin operator code.""" + return self._fb.builtinCode + + @builtin_code.setter + def builtin_code(self, value: tflite.BuiltinOperator): + self._fb.builtinCode = value + + @property + def custom_code(self) -> Optional[str]: + """Custom operator name (for CUSTOM opcode).""" + c = self._fb.customCode + if isinstance(c, bytes): + return c.decode('utf-8') + return c + + @custom_code.setter + def custom_code(self, value: Optional[str]): + self._fb.customCode = value + + @property + def version(self) -> int: + """Operator version.""" + return self._fb.version if self._fb.version else 1 + + @version.setter + def version(self, value: int): + self._fb.version = value + + +class Operator: + """Operator specification wrapping an OperatorT flatbuffer object. + + Provides clean APIs for common fields (opcode, inputs, outputs, custom_code) + while preserving all other OperatorT fields (builtin_options, custom_options, + intermediates, mutating_variable_inputs, etc.) during read-modify-write. + """ + + def __init__(self, + opcode: Union[tflite.BuiltinOperator, int] = None, + inputs: List[Tensor] = None, + outputs: List[Tensor] = None, + custom_code: Optional[str] = None, + opcode_index: Optional[int] = None, + _fb: tflite.OperatorT = None): + """Initialize Operator. + + Args: + opcode: BuiltinOperator enum value or CUSTOM + inputs: List of input Tensor objects + outputs: List of output Tensor objects + custom_code: Custom operator name (for CUSTOM opcode) + opcode_index: Index into operator_codes (set during read) + _fb: Optional OperatorT for wrapping existing flatbuffer object + """ + # Use provided OperatorT or create new one + self._fb = _fb if _fb is not None else tflite.OperatorT() + self._index = None + + # Tensor lists (managed separately; _fb stores indices, not objects) + self.inputs = inputs if inputs is not None else [] + self.outputs = outputs if outputs is not None else [] + + # These are derived from OperatorCode, not stored in OperatorT directly + self._opcode = opcode + self._custom_code = custom_code + self._opcode_index = opcode_index + + @property + def opcode(self) -> Union[tflite.BuiltinOperator, int]: + """Builtin operator code.""" + return self._opcode + + @opcode.setter + def opcode(self, value: Union[tflite.BuiltinOperator, int]): + self._opcode = value + + @property + def custom_code(self) -> Optional[str]: + """Custom operator name (for CUSTOM opcode).""" + return self._custom_code + + @custom_code.setter + def custom_code(self, value: Optional[str]): + self._custom_code = value + + @property + def opcode_index(self) -> Optional[int]: + """Index into operator_codes array (from read or after build).""" + return self._opcode_index + + @opcode_index.setter + def opcode_index(self, value: Optional[int]): + self._opcode_index = value + + @property + def index(self) -> Optional[int]: + """Operator index in the subgraph's operator list.""" + return self._index + + +class Subgraph: + """Subgraph specification wrapping a SubGraphT flatbuffer object. + + Provides clean APIs for common fields (tensors, operators, inputs, outputs, + name) while preserving all other SubGraphT fields during read-modify-write. + """ + + def __init__(self, + tensors: List[Tensor] = None, + operators: List[Operator] = None, + inputs: List[Tensor] = None, + outputs: List[Tensor] = None, + name: Optional[str] = None, + _fb: tflite.SubGraphT = None): + """Initialize Subgraph. + + Args: + tensors: List of Tensor objects + operators: List of Operator objects + inputs: List of input Tensor objects + outputs: List of output Tensor objects + name: Subgraph name for debugging + _fb: Optional SubGraphT for wrapping existing flatbuffer object + """ + # Use provided SubGraphT or create new one + self._fb = _fb if _fb is not None else tflite.SubGraphT() + self._index = None + + # Lists of objects (managed separately; _fb stores indices/arrays) + self.tensors = tensors if tensors is not None else [] + self.operators = operators if operators is not None else [] + self.inputs = inputs if inputs is not None else [] + self.outputs = outputs if outputs is not None else [] + + # Set name if provided (overrides _fb value) + if name is not None: + self.name = name + + @property + def name(self) -> Optional[str]: + """Subgraph name for debugging.""" + n = self._fb.name + if isinstance(n, bytes): + return n.decode('utf-8') + return n + + @name.setter + def name(self, value: Optional[str]): + self._fb.name = value + + def add_tensor(self, **kwargs) -> Tensor: + """Add tensor imperatively and return it.""" + t = Tensor(**kwargs) + t._index = len(self.tensors) + self.tensors.append(t) + return t + + def add_operator(self, **kwargs) -> Operator: + """Add operator imperatively and return it.""" + op = Operator(**kwargs) + op._index = len(self.operators) + self.operators.append(op) + return op + + def tensor_by_name(self, name: str) -> Tensor: + """Look up a tensor by name. + + Args: + name: The tensor name to find. + + Returns: + The Tensor with the given name. + + Raises: + KeyError: If no tensor with that name exists. + """ + for t in self.tensors: + if t.name == name: + return t + raise KeyError(f"No tensor named {name!r}") + + @property + def index(self) -> Optional[int]: + """Subgraph index in the model's subgraph list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + +class Model: + """Model specification wrapping a ModelT flatbuffer object. + + Provides clean APIs for common fields (subgraphs, buffers, operator_codes, + metadata, description) while preserving all other ModelT fields during + read-modify-write. + """ + + def __init__(self, + subgraphs: List[Subgraph] = None, + buffers: _BufferList = None, + operator_codes: List[OperatorCode] = None, + metadata: dict = None, + description: Optional[str] = None, + _fb: tflite.ModelT = None): + """Initialize Model. + + Args: + subgraphs: List of Subgraph objects + buffers: BufferList for tensor data + operator_codes: List of OperatorCode objects + metadata: Dict of metadata name -> bytes + description: Model description string + _fb: Optional ModelT for wrapping existing flatbuffer object + """ + # Use provided ModelT or create new one + self._fb = _fb if _fb is not None else tflite.ModelT() + + # Lists of objects (managed separately; _fb stores arrays) + self.subgraphs = subgraphs if subgraphs is not None else [] + self.buffers = buffers if buffers is not None else _BufferList() + self.operator_codes = operator_codes if operator_codes is not None else [] + self.metadata = metadata if metadata is not None else {} + + # Set description if provided (overrides _fb value) + if description is not None: + self.description = description + + @property + def description(self) -> Optional[str]: + """Model description string.""" + d = self._fb.description + if isinstance(d, bytes): + return d.decode('utf-8') + return d + + @description.setter + def description(self, value: Optional[str]): + self._fb.description = value + + def add_subgraph(self, **kwargs) -> Subgraph: + """Add subgraph imperatively and return it.""" + sg = Subgraph(**kwargs) + sg._index = len(self.subgraphs) + self.subgraphs.append(sg) + return sg + + def build(self) -> bytearray: + """Compile to flatbuffer with automatic bookkeeping.""" + compiler = _ModelCompiler(self) + return compiler.compile() + + +def read(buffer: bytes) -> Model: + """Read a TFLite flatbuffer and return a Model object.""" + fb_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) + + # Create Model wrapping the ModelT; all fields preserved in _fb + model = Model(_fb=fb_model) + + # Create all buffers first (so tensors can reference them) + for i, fb_buf in enumerate(fb_model.buffers): + buf_data = bytes(fb_buf.data) if fb_buf.data is not None else b'' + buf = Buffer(data=buf_data, index=i) + model.buffers.append(buf) + + # Read operator codes + for fb_opcode in fb_model.operatorCodes: + # Create OperatorCode wrapping the OperatorCodeT; all fields preserved in _fb + opcode = OperatorCode(_fb=fb_opcode) + model.operator_codes.append(opcode) + + # Read subgraphs + for sg_idx, fb_sg in enumerate(fb_model.subgraphs): + # Create Subgraph wrapping the SubGraphT; all fields preserved in _fb + sg = Subgraph(_fb=fb_sg) + sg._index = sg_idx + + # Read tensors + for tensor_idx, fb_tensor in enumerate(fb_sg.tensors): + # Resolve buffer reference + # Buffer 0 is the empty buffer (TFLite convention), so treat it as None + buf = None if fb_tensor.buffer == 0 else model.buffers[fb_tensor.buffer] + + # Read quantization parameters if present + quant = None + if fb_tensor.quantization: + fb_quant = fb_tensor.quantization + if fb_quant.scale is not None and len(fb_quant.scale) > 0: + scales = list(fb_quant.scale) + # Copy zero_points as-is, don't expand (per review feedback) + zeros = list( + fb_quant.zeroPoint) if fb_quant.zeroPoint is not None else [0] + # Copy axis if: (1) it's non-zero, or (2) there are multiple scales. + # This preserves per-channel quant with 1 channel (axis non-zero, 1 scale) + # while treating default axis=0 with 1 scale as per-tensor (axis=None). + axis = fb_quant.quantizedDimension + if axis == 0 and len(scales) == 1: + axis = None + quant = Quantization(scales=scales, zero_points=zeros, axis=axis) + + # Create Tensor wrapping the TensorT; all fields preserved in _fb + tensor = Tensor(_fb=fb_tensor, buffer=buf, quantization=quant) + tensor._index = tensor_idx + + sg.tensors.append(tensor) + + # Read operators + for fb_op in fb_sg.operators: + # Get operator code info + opcode_obj = model.operator_codes[fb_op.opcodeIndex] + + # Resolve tensor indices to Tensor objects + inputs = [sg.tensors[i] + for i in fb_op.inputs] if fb_op.inputs is not None else [] + outputs = [sg.tensors[i] + for i in fb_op.outputs] if fb_op.outputs is not None else [] + + # Create Operator wrapping the OperatorT; all fields preserved in _fb + op = Operator( + _fb=fb_op, + opcode=opcode_obj.builtin_code, + inputs=inputs, + outputs=outputs, + custom_code=opcode_obj.custom_code, + opcode_index=fb_op.opcodeIndex, + ) + sg.operators.append(op) + + # Read subgraph inputs/outputs + if fb_sg.inputs is not None and len(fb_sg.inputs) > 0: + sg.inputs = [sg.tensors[i] for i in fb_sg.inputs] + if fb_sg.outputs is not None and len(fb_sg.outputs) > 0: + sg.outputs = [sg.tensors[i] for i in fb_sg.outputs] + + model.subgraphs.append(sg) + + # Read metadata + if fb_model.metadata: + for entry in fb_model.metadata: + # Decode metadata name + name = entry.name + if isinstance(name, bytes): + name = name.decode('utf-8') + + # Get metadata value from buffer + buffer = fb_model.buffers[entry.buffer] + value = bytes(buffer.data) if buffer.data is not None else b'' + + model.metadata[name] = value + + return model + + +def _dtype_to_numpy(dtype: tflite.TensorType) -> np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.FLOAT32: np.float32, + } + return type_map.get(dtype, np.uint8) + + +class _ModelCompiler: + """Internal: compiles Model to flatbuffer with automatic bookkeeping.""" + + def __init__(self, model: Model): + self.model = model + self._buffers = [] + self._buffer_map = {} # Map Buffer object id to index + self._operator_codes = {} + + def compile(self) -> bytearray: + """Compile model using backing ModelT, preserving all fields.""" + # Use the backing ModelT directly---this preserves all fields we don't + # explicitly handle (version, signature_defs, etc.) + root = self.model._fb + + # Initialize buffers + # If model.buffers exists (from read()), preserve those buffers + if self.model.buffers: + for buf in self.model.buffers: + fb_buf = tflite.BufferT() + fb_buf.data = list(buf.data) if buf.data else [] + self._buffers.append(fb_buf) + self._buffer_map[id(buf)] = buf.index + else: + # Creating model from scratch: initialize buffer 0 as empty (TFLite convention) + empty_buffer = tflite.BufferT() + empty_buffer.data = [] + self._buffers = [empty_buffer] + # Note: buffer 0 should not be in _buffer_map since tensors without data use it + + # Auto-collect and register operator codes + self._collect_operator_codes() + root.operatorCodes = list(self._operator_codes.values()) + + # Process subgraphs + root.subgraphs = [] + for sg in self.model.subgraphs: + root.subgraphs.append(self._compile_subgraph(sg)) + + # Process buffers + root.buffers = self._buffers + + # Process metadata + root.metadata = self._compile_metadata() + + # Pack and return + builder = flatbuffers.Builder(4 * 2**20) + builder.Finish(root.Pack(builder)) + return builder.Output() + + def _collect_operator_codes(self): + """Scan all operators and build operator code table.""" + # Build lookup from existing OperatorCodes (from read()) to reuse their _fb + existing_opcodes = { + (oc.builtin_code, oc.custom_code): oc + for oc in self.model.operator_codes + } + + for sg in self.model.subgraphs: + for op in sg.operators: + key = (op.opcode, op.custom_code) + if key not in self._operator_codes: + # Reuse existing OperatorCodeT if available (preserves deprecated_builtin_code) + if key in existing_opcodes: + self._operator_codes[key] = existing_opcodes[key]._fb + else: + # Create new OperatorCodeT for newly added operators + opcode = tflite.OperatorCodeT() + opcode.builtinCode = op.opcode + if op.custom_code: + opcode.customCode = op.custom_code + self._operator_codes[key] = opcode + + def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT: + """Compile subgraph using backing SubGraphT, preserving all fields.""" + # Use the backing SubGraphT directly---this preserves all fields we don't + # explicitly handle (debug_metadata_index, etc.) + sg_t = sg._fb + + # Collect all tensors (from tensor list and inline in operators) + all_tensors = list(sg.tensors) + tensor_to_index = {} + for i, t in enumerate(all_tensors): + t._index = i + tensor_to_index[id(t)] = i + + # Extract inline tensors from operators and subgraph inputs/outputs + inline_sources = [op.inputs + op.outputs for op in sg.operators] + inline_sources.append(sg.inputs) + inline_sources.append(sg.outputs) + for source in inline_sources: + for tensor in source: + if id(tensor) not in tensor_to_index: + tensor._index = len(all_tensors) + tensor_to_index[id(tensor)] = tensor._index + all_tensors.append(tensor) + + # Compile all tensors + sg_t.tensors = [] + for tensor in all_tensors: + sg_t.tensors.append(self._compile_tensor(tensor)) + + # Compile operators + sg_t.operators = [] + for op in sg.operators: + sg_t.operators.append(self._compile_operator(op, tensor_to_index)) + + # Set subgraph inputs/outputs + sg_t.inputs = [tensor_to_index[id(t)] for t in sg.inputs] + sg_t.outputs = [tensor_to_index[id(t)] for t in sg.outputs] + + return sg_t + + def _compile_operator(self, op: Operator, + tensor_to_index: dict) -> tflite.OperatorT: + """Compile operator using backing OperatorT, preserving all fields.""" + # Use the backing OperatorT directly---this preserves all fields we don't + # explicitly handle (builtin_options, custom_options, intermediates, etc.) + op_t = op._fb + + # Get opcode index + key = (op.opcode, op.custom_code) + opcode_index = list(self._operator_codes.keys()).index(key) + op_t.opcodeIndex = opcode_index + + # Resolve tensor references to indices + op_t.inputs = [tensor_to_index[id(inp)] for inp in op.inputs] + op_t.outputs = [tensor_to_index[id(outp)] for outp in op.outputs] + + return op_t + + def _compile_tensor(self, tensor: Tensor) -> tflite.TensorT: + """Compile tensor using backing TensorT, preserving all fields.""" + # Use the backing TensorT directly---this preserves all fields we don't + # explicitly handle (is_variable, sparsity, shape_signature, has_rank, etc.) + t = tensor._fb + + # Handle buffer assignment + if tensor.buffer is None: + # No data: use buffer 0 + t.buffer = 0 + else: + # Has buffer: get or create index for it + buf_id = id(tensor.buffer) + if buf_id not in self._buffer_map: + # First time seeing this buffer, add it + fb_buf = tflite.BufferT() + fb_buf.data = list(tensor.buffer.data) + self._buffers.append(fb_buf) + buf_index = len(self._buffers) - 1 + self._buffer_map[buf_id] = buf_index + tensor.buffer.index = buf_index + t.buffer = self._buffer_map[buf_id] + + # Sync quantization: merge our Quantization object into _fb.quantization + if tensor.quantization: + if t.quantization is None: + t.quantization = tflite.QuantizationParametersT() + # Update only the fields we manage; other fields (min, max, details) + # are preserved from the original _fb.quantization + q = tensor.quantization + scales = [q.scales] if isinstance(q.scales, (int, float)) else q.scales + zeros = [q.zero_points] if isinstance(q.zero_points, + int) else q.zero_points + t.quantization.scale = scales + t.quantization.zeroPoint = zeros + if q.axis is not None: + t.quantization.quantizedDimension = q.axis + + return t + + def _compile_metadata(self): + """Compile metadata, creating buffers for metadata values.""" + if not self.model.metadata: + return [] + + metadata_entries = [] + for name, value in self.model.metadata.items(): + # Create buffer for metadata value + buf = tflite.BufferT() + buf.data = list(value) if isinstance(value, bytes) else list(value) + self._buffers.append(buf) + buf_index = len(self._buffers) - 1 + + # Create metadata entry + entry = tflite.MetadataT() + entry.name = name + entry.buffer = buf_index + metadata_entries.append(entry) + + return metadata_entries diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py new file mode 100644 index 00000000000..1f036b1a1e6 --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -0,0 +1,1409 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for model_editor module. +""" + +import numpy as np +import unittest +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression.model_editor import ( + Buffer, Model, Operator, OperatorCode, Quantization, Subgraph, Tensor) + + +class TestBasicModel(unittest.TestCase): + """Test basic model with tensors and operators.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5]], dtype=np.int8) + cls.weights_data = np.array([[1], [2], [3], [4], [5]], dtype=np.int8) + + cls.model = Model( + description="Test model", + subgraphs=[ + Subgraph(operators=[ + Operator(opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 5), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(5, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + # Build the model to a flatbuffer byte array. This exercises the + # model_editor's build path, which converts the high-level Model API + # representation into the binary TFLite format. + fb = cls.model.build() + + # Read the flatbuffer back through model_editor.read() to create a + # loopback model. This exercises the read path, which parses the + # flatbuffer and reconstructs a high-level Model representation. The + # loopback model should be semantically equivalent to cls.model, + # demonstrating that build() and read() are inverse operations. + cls.loopback_model = model_editor.read(fb) + + # Parse the same flatbuffer using the low-level TFLite schema interface + # (ModelT from schema_py_generated). This provides direct access to the + # raw flatbuffer structure, allowing us to verify that model_editor + # encodes data correctly at the binary level. We compare fb_model + # (low-level) against loopback_model (high-level) to ensure both + # representations are consistent. + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_description(self): + """Verify model description is preserved through loopback.""" + self.assertEqual(self.fb_model.description, b"Test model") + self.assertEqual(self.loopback_model.description, "Test model") + + def test_counts(self): + """Verify subgraph, tensor, and operator counts.""" + self.assertEqual(len(self.fb_model.subgraphs), 1) + self.assertEqual(len(self.loopback_model.subgraphs), 1) + + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.tensors), 3) + self.assertEqual(len(loopback_sg.tensors), 3) + + self.assertEqual(len(fb_sg.operators), 1) + self.assertEqual(len(loopback_sg.operators), 1) + + def test_tensor_names(self): + """Verify tensor names are preserved.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Check that all expected tensor names are present + fb_names = {t.name for t in fb_sg.tensors} + self.assertEqual(fb_names, {b"input", b"weights", b"output"}) + + loopback_names = {t.name for t in loopback_sg.tensors} + self.assertEqual(loopback_names, {"input", "weights", "output"}) + + def test_tensor_properties(self): + """Verify tensor shapes and dtypes.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor + input_fb = next(t for t in fb_sg.tensors if t.name == b"input") + input_loopback = next(t for t in loopback_sg.tensors if t.name == "input") + self.assertEqual(list(input_fb.shape), [1, 5]) + self.assertEqual(input_loopback.shape, (1, 5)) + self.assertEqual(input_fb.type, tflite.TensorType.INT8) + self.assertEqual(input_loopback.dtype, tflite.TensorType.INT8) + + # Weights tensor + weights_fb = next(t for t in fb_sg.tensors if t.name == b"weights") + weights_loopback = next(t for t in loopback_sg.tensors + if t.name == "weights") + self.assertEqual(list(weights_fb.shape), [5, 1]) + self.assertEqual(weights_loopback.shape, (5, 1)) + self.assertEqual(weights_fb.type, tflite.TensorType.INT8) + self.assertEqual(weights_loopback.dtype, tflite.TensorType.INT8) + + # Output tensor + output_fb = next(t for t in fb_sg.tensors if t.name == b"output") + output_loopback = next(t for t in loopback_sg.tensors + if t.name == "output") + self.assertEqual(list(output_fb.shape), [1, 1]) + self.assertEqual(output_loopback.shape, (1, 1)) + self.assertEqual(output_fb.type, tflite.TensorType.INT8) + self.assertEqual(output_loopback.dtype, tflite.TensorType.INT8) + + def test_tensor_data(self): + """Verify tensor data and buffer access.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor data + input_buffer = self.fb_model.buffers[fb_sg.tensors[0].buffer] + self.assertIsNotNone(input_buffer.data) + self.assertEqual(bytes(input_buffer.data), self.input_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[0].array) + np.testing.assert_array_equal(loopback_sg.tensors[0].array, + self.input_data) + + # Weights tensor data + weights_buffer = self.fb_model.buffers[fb_sg.tensors[1].buffer] + self.assertIsNotNone(weights_buffer.data) + self.assertEqual(bytes(weights_buffer.data), self.weights_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[1].array) + np.testing.assert_array_equal(loopback_sg.tensors[1].array, + self.weights_data) + + # Output tensor has no data + self.assertEqual(fb_sg.tensors[2].buffer, 0) + self.assertIsNone(loopback_sg.tensors[2].array) + + def test_buffer_allocation(self): + """Verify buffer allocation and zero convention.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Exact buffer count: buffer 0 (empty) + input + weights = 3 total + self.assertEqual(len(self.fb_model.buffers), 3) + self.assertEqual(len(self.loopback_model.buffers), 3) + + # Buffer 0 is empty + buffer_zero = self.fb_model.buffers[0] + self.assertTrue(buffer_zero.data is None or len(buffer_zero.data) == 0) + + # Verify each buffer is referenced by exactly the expected tensor + # Buffer 0 -> output tensor (no data) + output_tensor = next(t for t in fb_sg.tensors if t.name == b"output") + self.assertEqual(output_tensor.buffer, 0) + + # Buffer 1 and 2 -> input and weights (order may vary) + input_tensor = next(t for t in fb_sg.tensors if t.name == b"input") + weights_tensor = next(t for t in fb_sg.tensors if t.name == b"weights") + self.assertNotEqual(input_tensor.buffer, 0) + self.assertNotEqual(weights_tensor.buffer, 0) + self.assertIn(input_tensor.buffer, [1, 2]) + self.assertIn(weights_tensor.buffer, [1, 2]) + + # Tensors with data point to non-zero buffers in loopback model + loopback_input_tensor = next(t for t in loopback_sg.tensors + if t.name == "input") + self.assertIsNotNone(loopback_input_tensor.buffer) + self.assertIsNotNone(loopback_input_tensor.buffer.index) + self.assertNotEqual(loopback_input_tensor.buffer.index, 0) + self.assertEqual(len(loopback_input_tensor.buffer.data), 5) + self.assertEqual(bytes(loopback_input_tensor.buffer.data), + self.input_data.tobytes()) + + def test_operator_references(self): + """Verify operators reference correct tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Operator input/output references + self.assertEqual(len(fb_sg.operators[0].inputs), 2) + self.assertEqual([t.name for t in loopback_sg.operators[0].inputs], + ["input", "weights"]) + + self.assertEqual(len(fb_sg.operators[0].outputs), 1) + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["output"]) + + # Operator indices are in bounds + num_tensors = len(fb_sg.tensors) + for idx in list(fb_sg.operators[0].inputs) + list( + fb_sg.operators[0].outputs): + self.assertGreaterEqual(idx, 0) + self.assertLess(idx, num_tensors) + + def test_operator_codes(self): + """Verify operator code table is correctly populated.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertIsNotNone(self.fb_model.operatorCodes) + self.assertEqual(len(self.fb_model.operatorCodes), 1) + self.assertEqual(self.fb_model.operatorCodes[0].builtinCode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + self.assertEqual(len(self.loopback_model.operator_codes), 1) + self.assertIsNotNone(loopback_sg.operators[0].opcode_index) + loopback_opcode = self.loopback_model.operator_codes[ + loopback_sg.operators[0].opcode_index] + self.assertEqual(loopback_opcode.builtin_code, + tflite.BuiltinOperator.FULLY_CONNECTED) + + +class TestAdvancedModel(unittest.TestCase): + """Test multiple operators, custom ops, shared tensors, and mixed references.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=np.int8) + cls.weights_data = np.array( + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=np.int8) + cls.bias_data = np.array([10], dtype=np.int8) + # Int16 data to test endianness: values that will show byte order issues + cls.int16_data = np.array([256, 512, 1024], + dtype=np.int16) # 0x0100, 0x0200, 0x0400 + + # Pre-declare shared tensor (output of FC, input to custom op) + cls.hidden = Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="hidden") + + # Create explicit shared buffer to test buffer sharing between tensors + cls.shared_buffer_data = np.array([100, 127], dtype=np.int8) + cls.shared_buf = Buffer(data=cls.shared_buffer_data.tobytes()) + + cls.model = Model( + description="Advanced model", + metadata={ + "version": b"1.0.0", + "author": b"test_suite", + "custom_data": bytes([0xDE, 0xAD, 0xBE, 0xEF]) + }, + subgraphs=[ + Subgraph( + tensors=[ + cls.hidden, # Mixed: pre-declared shared tensor + # Int16 tensor to test endianness + Tensor(shape=(3, ), + dtype=tflite.TensorType.INT16, + data=cls.int16_data, + name="int16_tensor"), + # Two tensors sharing same buffer to test buffer deduplication + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor1"), + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor2") + ], + operators=[ + # Multiple operators: FULLY_CONNECTED + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(10, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[cls.hidden + ] # Shared: reference to pre-declared + ), + # Custom operator + Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code="MyCustomOp", + inputs=[cls.hidden], # Shared: reuse hidden tensor + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed") + ]), + # Multiple operators: ADD + Operator( + opcode=tflite.BuiltinOperator.ADD, + inputs=[ + Tensor( + shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed_ref" # Mixed: inline tensor + ), + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + data=cls.bias_data, + name="bias") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_operator_counts(self): + """Verify correct number of operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.operators), 3) + self.assertEqual(len(loopback_sg.operators), 3) + + def test_operator_code_table(self): + """Verify operator code table contains all operator types.""" + self.assertEqual(len(self.fb_model.operatorCodes), 3) + self.assertEqual(len(self.loopback_model.operator_codes), 3) + + opcodes_fb = {op.builtinCode for op in self.fb_model.operatorCodes} + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_fb) + + opcodes_loopback = { + op.builtin_code + for op in self.loopback_model.operator_codes + } + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_loopback) + + def test_custom_operator(self): + """Verify custom operator code preservation.""" + loopback_sg = self.loopback_model.subgraphs[0] + + # Custom code in operator code table + custom_opcode_fb = next(op for op in self.fb_model.operatorCodes + if op.builtinCode == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_fb.customCode, b"MyCustomOp") + + custom_opcode_loopback = next( + op for op in self.loopback_model.operator_codes + if op.builtin_code == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_loopback.custom_code, "MyCustomOp") + + # Custom operator references custom code + custom_op_loopback = loopback_sg.operators[1] + self.assertEqual(custom_op_loopback.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_op_loopback.custom_code, "MyCustomOp") + + def test_shared_tensor_references(self): + """Verify tensors shared between operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Hidden tensor is at index 0 (pre-declared) + self.assertEqual(fb_sg.tensors[0].name, b"hidden") + self.assertEqual(loopback_sg.tensors[0].name, "hidden") + + # FC operator outputs to hidden + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["hidden"]) + + # Custom operator inputs from hidden + self.assertEqual([t.name for t in loopback_sg.operators[1].inputs], + ["hidden"]) + + # Same Tensor object is referenced by both operators + fc_output = loopback_sg.operators[0].outputs[0] + custom_input = loopback_sg.operators[1].inputs[0] + self.assertIs(fc_output, custom_input) + + def test_mixed_tensor_references(self): + """Verify mix of pre-declared and inline tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Total: hidden, int16_tensor, shared_buf_tensor1, shared_buf_tensor2 (pre-declared) + # + input, weights, processed, processed_ref, bias, output (inline from operators) + self.assertEqual(len(fb_sg.tensors), 10) + self.assertEqual(len(loopback_sg.tensors), 10) + + def test_int16_endianness(self): + """Verify int16 data is stored in little-endian byte order.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Find int16 tensor by name + int16_tensor_fb = next(t for t in fb_sg.tensors + if t.name == b"int16_tensor") + int16_tensor_loopback = next(t for t in loopback_sg.tensors + if t.name == "int16_tensor") + + # Verify dtype + self.assertEqual(int16_tensor_fb.type, tflite.TensorType.INT16) + self.assertEqual(int16_tensor_loopback.dtype, tflite.TensorType.INT16) + + # Check flatbuffer buffer has correct little-endian bytes + # For [256, 512, 1024] = [0x0100, 0x0200, 0x0400] + # Little-endian bytes: [0x00, 0x01, 0x00, 0x02, 0x00, 0x04] + int16_buffer_fb = self.fb_model.buffers[int16_tensor_fb.buffer] + self.assertIsNotNone(int16_buffer_fb.data) + expected_bytes = self.int16_data.astype(np.int16).astype('buffer mapping from flatbuffer + metadata_map_fb = {} + for entry in self.fb_model.metadata: + buffer_idx = entry.buffer + self.assertLess(buffer_idx, len(self.fb_model.buffers)) + buffer = self.fb_model.buffers[buffer_idx] + if buffer.data is not None: + metadata_map_fb[entry.name] = bytes(buffer.data) + + # Verify flatbuffer metadata values + self.assertEqual(metadata_map_fb[b"version"], b"1.0.0") + self.assertEqual(metadata_map_fb[b"author"], b"test_suite") + self.assertEqual(metadata_map_fb[b"custom_data"], + bytes([0xDE, 0xAD, 0xBE, 0xEF])) + + # Check loopback model metadata + self.assertIsNotNone(self.loopback_model.metadata) + self.assertEqual(len(self.loopback_model.metadata), 3) + + # Verify loopback metadata values (decoded from bytes) + self.assertEqual(self.loopback_model.metadata["version"], b"1.0.0") + self.assertEqual(self.loopback_model.metadata["author"], b"test_suite") + self.assertEqual(self.loopback_model.metadata["custom_data"], + bytes([0xDE, 0xAD, 0xBE, 0xEF])) + + def test_buffer_allocation(self): + """Verify no orphaned buffers and shared buffer deduplication.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Collect all buffer references (from tensors and metadata) + referenced_buffers = {0} # Buffer 0 is special (always referenced) + + # Collect buffer references from tensors + for tensor in fb_sg.tensors: + referenced_buffers.add(tensor.buffer) + + # Collect buffer references from metadata + for entry in self.fb_model.metadata: + referenced_buffers.add(entry.buffer) + + # Verify no orphaned buffers (all buffers are referenced) + for i in range(len(self.fb_model.buffers)): + self.assertIn( + i, referenced_buffers, + f"Buffer {i} is orphaned (not referenced by any tensor or metadata)") + + # Verify shared buffer deduplication: two tensors share one buffer + tensor1_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor1") + tensor2_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor2") + + # Both tensors should point to the same buffer index + self.assertEqual(tensor1_fb.buffer, tensor2_fb.buffer) + self.assertNotEqual(tensor1_fb.buffer, 0) + + # Verify loopback preserves shared buffer (same Buffer object) + tensor1_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor1") + tensor2_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor2") + + self.assertIs(tensor1_loopback.buffer, tensor2_loopback.buffer) + self.assertEqual(bytes(tensor1_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + self.assertEqual(bytes(tensor2_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + + +class TestQuantization(unittest.TestCase): + """Test per-tensor and per-channel quantization parameters.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + # Per-channel quantization parameters + cls.per_channel_scales = [0.1, 0.2, 0.3, 0.4] + cls.per_channel_zeros = [0, 1, 2, 3] + + cls.model = Model( + description="Quantization test model", + subgraphs=[ + Subgraph(tensors=[ + # Per-tensor quantized tensor (single scale/zero_point) + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((1, 10), dtype=np.int8), + name="per_tensor", + quantization=Quantization(scales=0.5, zero_points=10)), + # Per-channel quantized tensor (array of scales/zero_points, axis) + Tensor(shape=(4, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 10), dtype=np.int8), + name="per_channel", + quantization=Quantization( + scales=cls.per_channel_scales, + zero_points=cls.per_channel_zeros, + axis=0)) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_per_tensor_quantization_flatbuffer(self): + """Verify per-tensor quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Scale and zero_point encoded as single-element arrays + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 1) + self.assertEqual(tensor.quantization.scale[0], 0.5) + + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 1) + self.assertEqual(tensor.quantization.zeroPoint[0], 10) + + def test_per_tensor_quantization_loopback(self): + """Verify per-tensor quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, [0.5]) + self.assertEqual(tensor.quantization.zero_points, [10]) + self.assertIsNone(tensor.quantization.axis) + + def test_per_channel_quantization_flatbuffer(self): + """Verify per-channel quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_channel") + self.assertIsNotNone(tensor.quantization) + + # All scales encoded + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 4) + self.assertEqual(list(tensor.quantization.scale), self.per_channel_scales) + + # All zero_points encoded + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 4) + self.assertEqual(list(tensor.quantization.zeroPoint), + self.per_channel_zeros) + + # Axis encoded as quantizedDimension + self.assertEqual(tensor.quantization.quantizedDimension, 0) + + def test_per_channel_quantization_loopback(self): + """Verify per-channel quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_channel") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, self.per_channel_scales) + self.assertEqual(tensor.quantization.zero_points, self.per_channel_zeros) + self.assertEqual(tensor.quantization.axis, 0) + + +class TestReadModifyWrite(unittest.TestCase): + """Test read-modify-write workflows.""" + + @classmethod + def setUpClass(cls): + """Create a simple base model for modification tests.""" + cls.original_data = np.array([[1, 2, 3]], dtype=np.int8) + cls.model = Model( + description="Base model", + metadata={"original": b"metadata"}, + subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=cls.original_data, + name="weights"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="input"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="output") + ]) + ]) + + cls.fb = cls.model.build() + + def test_modify_tensor_data(self): + """Read model, modify tensor data, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify tensor data using array setter (high-level API) + weights_tensor = next(t for t in model2.subgraphs[0].tensors + if t.name == "weights") + new_data = np.array([[10, 20, 30]], dtype=np.int8) + weights_tensor.array = new_data # Uses array setter + + # Build modified model + fb2 = model2.build() + + # Read back and verify modification + model3 = model_editor.read(fb2) + modified_weights = next(t for t in model3.subgraphs[0].tensors + if t.name == "weights") + np.testing.assert_array_equal(modified_weights.array, new_data) + + # Verify other tensors unchanged + self.assertEqual(len(model3.subgraphs[0].tensors), 3) + + def test_add_tensor_and_operator(self): + """Read model, add new tensor and operator, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + sg = model2.subgraphs[0] + + # Get existing tensors + input_tensor = next(t for t in sg.tensors if t.name == "input") + output_tensor = next(t for t in sg.tensors if t.name == "output") + + # Add new tensor using imperative API + new_weights = np.array([[5, 10, 15]], dtype=np.int8) + new_weights_tensor = sg.add_tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=new_weights, + name="new_weights") + + # Add new operator using imperative API + sg.add_operator(opcode=tflite.BuiltinOperator.ADD, + inputs=[input_tensor, new_weights_tensor], + outputs=[output_tensor]) + + # Build modified model + fb2 = model2.build() + + # Read back and verify additions + model3 = model_editor.read(fb2) + sg3 = model3.subgraphs[0] + + # Verify tensor was added + self.assertEqual(len(sg3.tensors), 4) + added_tensor = next(t for t in sg3.tensors if t.name == "new_weights") + self.assertIsNotNone(added_tensor) + np.testing.assert_array_equal(added_tensor.array, new_weights) + + # Verify operator was added + self.assertEqual(len(sg3.operators), 1) + added_op = sg3.operators[0] + self.assertEqual([t.name for t in added_op.inputs], + ["input", "new_weights"]) + self.assertEqual([t.name for t in added_op.outputs], ["output"]) + + def test_modify_metadata(self): + """Read model, modify metadata, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify existing metadata + model2.metadata["original"] = b"modified_metadata" + + # Add new metadata + model2.metadata["new_key"] = b"new_value" + + # Build modified model + fb2 = model2.build() + + # Read back and verify modifications + model3 = model_editor.read(fb2) + + self.assertEqual(len(model3.metadata), 2) + self.assertEqual(model3.metadata["original"], b"modified_metadata") + self.assertEqual(model3.metadata["new_key"], b"new_value") + + +class TestSubgraphInputsOutputs(unittest.TestCase): + """Test subgraph inputs and outputs are set correctly.""" + + def test_subgraph_inputs_outputs_set(self): + """Verify subgraph inputs/outputs are set in the flatbuffer.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + fb_sg = fb_model.subgraphs[0] + + # Verify inputs/outputs are set (as tensor indices) + self.assertEqual(len(fb_sg.inputs), 1) + self.assertEqual(len(fb_sg.outputs), 1) + + # Verify indices point to correct tensors + input_idx = fb_sg.inputs[0] + output_idx = fb_sg.outputs[0] + self.assertEqual(fb_sg.tensors[input_idx].name, b"input") + self.assertEqual(fb_sg.tensors[output_idx].name, b"output") + + def test_subgraph_inputs_outputs_loopback(self): + """Verify inputs/outputs survive read/build loopback.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + loopback = model_editor.read(fb) + sg = loopback.subgraphs[0] + + # Verify high-level inputs/outputs are populated + self.assertEqual(len(sg.inputs), 1) + self.assertEqual(len(sg.outputs), 1) + self.assertEqual(sg.inputs[0].name, "input") + self.assertEqual(sg.outputs[0].name, "output") + + def test_tensor_by_name_not_found_raises(self): + """tensor_by_name raises KeyError when name not found.""" + model = Model(subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(4, ), dtype=tflite.TensorType.INT8, name="exists") + ]) + ]) + + with self.assertRaises(KeyError): + model.subgraphs[0].tensor_by_name("nonexistent") + + +class TestReadEdgeCases(unittest.TestCase): + """Test model_editor.read() with edge cases from real-world models. + + These tests construct models using the low-level TFLite schema to create + edge cases that may not be producible via model_editor.build(), but can + appear in models from other sources (e.g., TFLite converter). + """ + + def _build_model_with_schema(self, model_t): + """Build a flatbuffer from a ModelT using the low-level schema.""" + import flatbuffers + builder = flatbuffers.Builder(1024) + builder.Finish(model_t.Pack(builder)) + return bytes(builder.Output()) + + def test_read_scalar_tensor(self): + """Verify read() handles tensors with None shape (scalars). + + Some TFLite models have scalar tensors where shape is None rather than + an empty list. This can occur with constant scalars produced by certain + converters. + """ + # Build a minimal model with a scalar tensor (shape=None) + model_t = tflite.ModelT() + model_t.version = 3 + + # Buffer 0 is always empty, buffer 1 holds scalar data + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + buf1.data = [42] # Single byte scalar value + + model_t.buffers = [buf0, buf1] + + # Create operator code + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + # Create subgraph with scalar tensor + sg = tflite.SubGraphT() + + # Tensor with shape=None (scalar) + scalar_tensor = tflite.TensorT() + scalar_tensor.name = b"scalar" + scalar_tensor.type = tflite.TensorType.INT8 + scalar_tensor.buffer = 1 + scalar_tensor.shape = None # This is the edge case + + # Normal tensor for comparison + normal_tensor = tflite.TensorT() + normal_tensor.name = b"normal" + normal_tensor.type = tflite.TensorType.INT8 + normal_tensor.buffer = 0 + normal_tensor.shape = [1, 4] + + sg.tensors = [scalar_tensor, normal_tensor] + sg.inputs = [1] + sg.outputs = [1] + sg.operators = [] + + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify scalar tensor was read with empty shape tuple + self.assertEqual(model.subgraphs[0].tensors[0].shape, ()) + self.assertEqual(model.subgraphs[0].tensors[0].name, "scalar") + + # Verify normal tensor shape is preserved + self.assertEqual(model.subgraphs[0].tensors[1].shape, (1, 4)) + + def test_read_operator_with_empty_inputs(self): + """Verify read() handles operators with None inputs/outputs. + + Some operators (e.g., certain control flow or custom ops) may have + empty input or output lists represented as None in the flatbuffer. + """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + # Custom op that might have unusual input/output patterns + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoInputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + # Single output tensor + output_tensor = tflite.TensorT() + output_tensor.name = b"output" + output_tensor.type = tflite.TensorType.INT8 + output_tensor.buffer = 0 + output_tensor.shape = [1] + + sg.tensors = [output_tensor] + sg.inputs = [] + sg.outputs = [0] + + # Operator with None inputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = None # This is the edge case + op.outputs = [0] + + sg.operators = [op] + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify operator was read with empty inputs list + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(model.subgraphs[0].operators[0].inputs, []) + self.assertEqual(len(model.subgraphs[0].operators[0].outputs), 1) + + def test_read_operator_with_empty_outputs(self): + """Verify read() handles operators with None outputs. + + Similar to empty inputs, some operators may have None outputs. + """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoOutputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + input_tensor = tflite.TensorT() + input_tensor.name = b"input" + input_tensor.type = tflite.TensorType.INT8 + input_tensor.buffer = 0 + input_tensor.shape = [1] + + sg.tensors = [input_tensor] + sg.inputs = [0] + sg.outputs = [] + + # Operator with None outputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = [0] + op.outputs = None # This is the edge case + + sg.operators = [op] + model_t.subgraphs = [sg] + + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(len(model.subgraphs[0].operators[0].inputs), 1) + self.assertEqual(model.subgraphs[0].operators[0].outputs, []) + + def test_int64_tensor(self): + """Verify INT64 tensors are correctly handled.""" + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + # INT64 data: [1, 2, 3, 4] as little-endian 8-byte values + int64_data = np.array([1, 2, 3, 4], dtype=np.int64) + buf1.data = list(int64_data.tobytes()) + + model_t.buffers = [buf0, buf1] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + tensor = tflite.TensorT() + tensor.name = b"int64_tensor" + tensor.type = tflite.TensorType.INT64 + tensor.buffer = 1 + tensor.shape = [4] + + sg.tensors = [tensor] + sg.inputs = [0] + sg.outputs = [0] + sg.operators = [] + model_t.subgraphs = [sg] + + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + t = model.subgraphs[0].tensors[0] + self.assertEqual(t.dtype, tflite.TensorType.INT64) + np.testing.assert_array_equal(t.array, int64_data) + + +class TestFieldPreservation(unittest.TestCase): + """Test that schema fields are preserved during read-modify-write. + + These tests verify that fields not explicitly handled by model_editor + are still preserved when reading a model, modifying it, and writing + it back. This catches regressions where adding wrapper classes might + accidentally drop fields. + """ + + def _build_model_with_schema(self, model_t): + """Build a flatbuffer from a ModelT using the low-level schema.""" + import flatbuffers + builder = flatbuffers.Builder(1024) + builder.Finish(model_t.Pack(builder)) + return bytes(builder.Output()) + + def _create_base_model(self): + """Create a minimal valid model for testing.""" + model_t = tflite.ModelT() + model_t.version = 3 + model_t.description = b"test" + + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + buf1.data = [1, 2, 3, 4] + model_t.buffers = [buf0, buf1] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + t0 = tflite.TensorT() + t0.name = b"input" + t0.type = tflite.TensorType.INT8 + t0.buffer = 1 + t0.shape = [4] + + t1 = tflite.TensorT() + t1.name = b"output" + t1.type = tflite.TensorType.INT8 + t1.buffer = 0 + t1.shape = [4] + + sg.tensors = [t0, t1] + sg.inputs = [0] + sg.outputs = [1] + + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = [0] + op.outputs = [1] + sg.operators = [op] + + model_t.subgraphs = [sg] + return model_t + + def test_tensor_is_variable_preserved(self): + """Verify Tensor.isVariable is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].isVariable = True + + fb = self._build_model_with_schema(model_t) + + # Read, modify, write + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + # Verify field preserved + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertTrue(model_t2.subgraphs[0].tensors[0].isVariable) + + def test_tensor_shape_signature_preserved(self): + """Verify Tensor.shapeSignature is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].shapeSignature = [-1, 4] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].tensors[0].shapeSignature), + [-1, 4]) + + def test_operator_builtin_options_preserved(self): + """Verify Operator.builtinOptions is preserved through read-modify-write.""" + model_t = self._create_base_model() + + # Use ADD operator with AddOptions (must also set builtinOptionsType for union) + add_options = tflite.AddOptionsT() + add_options.fusedActivationFunction = tflite.ActivationFunctionType.RELU + model_t.subgraphs[0].operators[0].builtinOptions = add_options + model_t.subgraphs[0].operators[ + 0].builtinOptionsType = tflite.BuiltinOptions.AddOptions + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertIsNotNone(model_t2.subgraphs[0].operators[0].builtinOptions) + self.assertEqual( + model_t2.subgraphs[0].operators[0].builtinOptions. + fusedActivationFunction, tflite.ActivationFunctionType.RELU) + + def test_operator_custom_options_preserved(self): + """Verify Operator.customOptions is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].customOptions = [0xDE, 0xAD, 0xBE, 0xEF] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].operators[0].customOptions), + [0xDE, 0xAD, 0xBE, 0xEF]) + + def test_operator_intermediates_preserved(self): + """Verify Operator.intermediates is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].intermediates = [0, 1] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].operators[0].intermediates), + [0, 1]) + + def test_operator_debug_metadata_index_preserved(self): + """Verify Operator.debugMetadataIndex is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].debugMetadataIndex = 7 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.subgraphs[0].operators[0].debugMetadataIndex, 7) + + def test_operator_code_deprecated_builtin_code_preserved(self): + """Verify OperatorCode.deprecatedBuiltinCode is preserved.""" + model_t = self._create_base_model() + # Set deprecated code to a value different from the new builtin code + model_t.operatorCodes[0].deprecatedBuiltinCode = 42 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.operatorCodes[0].deprecatedBuiltinCode, 42) + + def test_subgraph_debug_metadata_index_preserved(self): + """Verify SubGraph.debugMetadataIndex is preserved.""" + model_t = self._create_base_model() + model_t.subgraphs[0].debugMetadataIndex = 5 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.subgraphs[0].debugMetadataIndex, 5) + + def test_model_version_preserved(self): + """Verify Model.version is preserved (not hardcoded to 3).""" + model_t = self._create_base_model() + model_t.version = 42 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.version, 42) + + def test_model_signature_defs_preserved(self): + """Verify Model.signatureDefs is preserved.""" + model_t = self._create_base_model() + + sig_def = tflite.SignatureDefT() + sig_def.signatureKey = b"serving_default" + sig_def.subgraphIndex = 0 + model_t.signatureDefs = [sig_def] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertIsNotNone(model_t2.signatureDefs) + self.assertEqual(len(model_t2.signatureDefs), 1) + self.assertEqual(model_t2.signatureDefs[0].signatureKey, + b"serving_default") + + def test_quantization_min_max_preserved(self): + """Verify QuantizationParameters.min/max are preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5] + quant.zeroPoint = [128] + quant.min = [0.0] + quant.max = [1.0] + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(list(quant2.min), [0.0]) + self.assertEqual(list(quant2.max), [1.0]) + + def test_tensor_sparsity_preserved(self): + """Verify Tensor.sparsity is preserved through read-modify-write.""" + model_t = self._create_base_model() + + sparsity = tflite.SparsityParametersT() + sparsity.traversalOrder = [0, 1] + sparsity.blockMap = [0] + model_t.subgraphs[0].tensors[0].sparsity = sparsity + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + sparsity2 = model_t2.subgraphs[0].tensors[0].sparsity + self.assertIsNotNone(sparsity2) + self.assertEqual(list(sparsity2.traversalOrder), [0, 1]) + self.assertEqual(list(sparsity2.blockMap), [0]) + + def test_tensor_has_rank_preserved(self): + """Verify Tensor.hasRank is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].hasRank = True + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertTrue(model_t2.subgraphs[0].tensors[0].hasRank) + + def test_operator_builtin_options_2_preserved(self): + """Verify Operator.builtinOptions2 is preserved through read-modify-write.""" + model_t = self._create_base_model() + + options2 = tflite.StablehloConcatenateOptionsT() + options2.dimension = 42 + model_t.subgraphs[0].operators[0].builtinOptions2 = options2 + model_t.subgraphs[0].operators[0].builtinOptions2Type = ( + tflite.BuiltinOptions2.StablehloConcatenateOptions) + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + options2_out = model_t2.subgraphs[0].operators[0].builtinOptions2 + self.assertIsNotNone(options2_out) + self.assertEqual(options2_out.dimension, 42) + + def test_quantization_axis_preserved(self): + """Verify QuantizationParameters.quantizedDimension is preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5, 0.25] + quant.zeroPoint = [0, 0] + quant.quantizedDimension = 1 + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(quant2.quantizedDimension, 1) + + def test_quantization_zero_point_preserved(self): + """Verify QuantizationParameters.zeroPoint is preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5] + quant.zeroPoint = [128] + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(list(quant2.zeroPoint), [128]) + + def test_quantization_zero_point_not_expanded(self): + """Single zeroPoint with multiple scales is preserved as-is. + + TFLite converter optimizes by storing single zeroPoint when all channels + have the same zero point. This must be preserved, not expanded. + """ + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5, 0.25, 0.125, 0.0625] # 4 scales + quant.zeroPoint = [128] # Single zero point (converter optimization) + quant.quantizedDimension = 0 + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + # Should still be single-element, not expanded to 4 + self.assertEqual(len(quant2.zeroPoint), 1) + self.assertEqual(quant2.zeroPoint[0], 128) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py deleted file mode 100644 index 2e58d8080f1..00000000000 --- a/tensorflow/lite/micro/compression/model_facade.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""A facade for working with tflite.Model. - -This module provides convenient navigation, data type conversions, and -utilities for working with a tflite.Model, which can be tedious and verbose to -work with directly. - -Usage: - model = model_facade.read(flatbuffer) - # manipulate - new_flatbuffer = model.compile() -""" - -from __future__ import annotations - -import flatbuffers -import numpy as np -from numpy.typing import NDArray -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from typing import ByteString, Generic, TypeVar - -_IteratorTo = TypeVar("_IteratorTo") - - -class _Iterator(Generic[_IteratorTo]): - - def __init__(self, sequence, cls, parent): - self._sequence = sequence - self._cls = cls - self._index = 0 - self._parent = parent - - def __getitem__(self, key) -> _IteratorTo: - return self._cls(self._sequence[key], key, self._parent) - - def __len__(self): - return len(self._sequence) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _IndirectIterator(Generic[_IteratorTo]): - - def __init__(self, indices, sequence): - self._indices = indices - self._index = 0 - self._sequence = sequence - - def __getitem__(self, key) -> _IteratorTo: - index = self._indices[key] - return self._sequence[index] - - def __len__(self): - return len(self._indices) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _Operator: - - def __init__(self, operator, index, subgraph): - self.operator = operator - self.index = index - self.subgraph = subgraph - - @property - def opcode(self) -> tflite.OperatorCodeT: - return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] - - @property - def inputs(self): - return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) - - -_NP_DTYPES = { - tflite.TensorType.FLOAT16: np.dtype(" _Buffer: - return self.subgraph.model.buffers[self._tensor_t.buffer] - - @property - def data(self) -> bytes: - return self.buffer.data - - @property - def dtype(self) -> np.dtype: - return _NP_DTYPES[self._tensor_t.type] - - @property - def array(self) -> np.ndarray: - """Returns an array created from the Tensor's data, type, and shape. - - Note the bytes in the data buffer and the Tensor's type and shape may be - inconsistent, and thus the returned array invalid, if the data buffer has - been altered according to the compression schema, in which the data buffer - is an array of fixed-width, integer fields. - """ - return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor_t.shape) - - @property - def quantization(self) -> tflite.QuantizationParametersT | None: - return self._tensor_t.quantization - - -class _Buffer: - - def __init__(self, buffer_t: tflite.BufferT, index, model): - self._buffer_t = buffer_t - self.index = index - self.model = model - - @property - def data(self) -> bytes: - return bytes(self._buffer_t.data) - - @data.setter - def data(self, value: ByteString): - self._buffer_t.data = list(value) - - def extend(self, values: NDArray): - self._buffer_t.data.extend(values.tobytes()) - - -class _Subgraph: - - def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): - self._subgraph_t = subgraph_t - self.index = index - self.model = model - - @property - def operators(self) -> _Iterator[_Operator]: - return _Iterator(self._subgraph_t.operators, _Operator, parent=self) - - @property - def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) - - -class _Model: - """A facade for manipulating tflite.Model. - """ - - def __init__(self, model_t: tflite.ModelT): - self._model_t = model_t - - def compile(self) -> bytearray: - """Returns a tflite.Model flatbuffer. - """ - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self._model_t.Pack(builder)) - return builder.Output() - - def add_buffer(self) -> _Buffer: - """Adds a buffer to the model. - """ - buffer = tflite.BufferT() - buffer.data = [] - self._model_t.buffers.append(buffer) - index = len(self._model_t.buffers) - 1 - return _Buffer(buffer, index, self._model_t) - - def add_metadata(self, key, value): - """Adds a key-value pair, writing value to a newly created buffer. - """ - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self._model_t.metadata.append(metadata) - - @property - def metadata(self) -> dict[str, _Buffer]: - """Returns the model's metadata as a dictionary to Buffer objects. - """ - result = {} - for m in self._model_t.metadata: - name = m.name.decode("utf-8") # type: ignore (fb library is wrong) - buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, - self._model_t) - result[name] = buffer - - return result - - @property - def operatorCodes(self): - return self._model_t.operatorCodes - - @property - def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) - - @property - def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self._model_t.buffers, _Buffer, parent=self) - - -def read(buffer: ByteString): - """Reads a tflite.Model and returns a model facade. - """ - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py deleted file mode 100644 index 87e71fa968b..00000000000 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import unittest -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from tflite_micro.tensorflow.lite.micro.compression import model_facade -from tflite_micro.tensorflow.lite.micro.compression import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - 1: { - "name": "metadata1", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (8, 1), - "type": tflite.TensorType.INT16, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (4, 1), - "type": tflite.TensorType.INT32, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (2, 1), - "type": tflite.TensorType.INT64, - "buffer": 4, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.UINT16: np.uint16, + tflite.TensorType.UINT32: np.uint32, + tflite.TensorType.FLOAT16: np.float16, + tflite.TensorType.FLOAT32: np.float32, + tflite.TensorType.FLOAT64: np.float64, + tflite.TensorType.BOOL: np.bool_, + } + return type_map.get(dtype, np.uint8) + + +class ProprietaryModelTest(unittest.TestCase): + """Integration tests using proprietary models.""" + + # Parsed from command line in main() + models_dir = None + + @classmethod + def setUpClass(cls): + if not cls.models_dir: + raise unittest.SkipTest( + "No models directory provided. " + "Usage: bazel test ... --test_arg=/path/to/models") + + cls.model_paths = sorted( + glob.glob(os.path.join(cls.models_dir, '*.tflite'))) + if not cls.model_paths: + raise unittest.SkipTest(f"No .tflite files found in {cls.models_dir}") + + def test_all_models(self): + """Run compression test on each discovered model.""" + for model_path in self.model_paths: + with self.subTest(model=os.path.basename(model_path)): + self._test_model_compression(model_path) + + def _test_model_compression(self, model_path): + """Test that a compressed model produces same outputs as original.""" + with open(model_path, 'rb') as f: + flatbuffer = f.read() + + # Load compression spec from sidecar file + specs = self._load_compression_spec(model_path) + + # Load tolerance config + rtol, atol = self._load_tolerance(model_path) + + # Compress the model + compressed_fb = compress.compress(flatbuffer, specs) + + # Create interpreters + original_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Generate random inputs and compare outputs + np.random.seed(42) + model = model_editor.read(flatbuffer) + sg = model.subgraphs[0] + + for trial in range(5): + # Set inputs + for i, input_tensor in enumerate(sg.inputs): + test_input = self._generate_input(input_tensor) + original_interp.set_input(test_input, i) + compressed_interp.set_input(test_input, i) + + # Run inference + original_interp.invoke() + compressed_interp.invoke() + + # Compare outputs + for i in range(len(sg.outputs)): + expected = original_interp.get_output(i) + actual = compressed_interp.get_output(i) + self._compare_outputs(expected, actual, rtol, atol, + f"trial {trial}, output {i}") + + def _generate_input(self, tensor): + """Generate random input respecting tensor dtype.""" + shape = tensor.shape + dtype = _dtype_to_numpy(tensor.dtype) + + if np.issubdtype(dtype, np.floating): + return np.random.uniform(-1.0, 1.0, shape).astype(dtype) + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return np.random.randint(info.min, info.max + 1, shape, dtype=dtype) + elif dtype == np.bool_: + return np.random.choice([False, True], shape) + return np.zeros(shape, dtype=dtype) + + def _load_compression_spec(self, model_path): + """Load compression spec from sidecar YAML file. + + Raises: + FileNotFoundError: If no spec file is found. + """ + spec_path = model_path.replace('.tflite', '.spec.yaml') + if os.path.exists(spec_path): + with open(spec_path) as f: + return spec.parse_yaml(f.read()) + + raise FileNotFoundError( + f"No compression spec file found for {model_path}. " + f"Expected: {spec_path}") + + def _load_tolerance(self, model_path): + """Load tolerance from sidecar config if present. + + Returns (0, 0) for exact match if no config file exists. + """ + config_path = model_path.replace('.tflite', '.config.json') + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config.get('rtol', 0), config.get('atol', 0) + return 0, 0 + + def _compare_outputs(self, expected, actual, rtol, atol, context=""): + """Compare outputs with optional tolerance.""" + msg = f"Output mismatch ({context})" if context else "Output mismatch" + if rtol == 0 and atol == 0: + np.testing.assert_array_equal(expected, actual, err_msg=msg) + else: + np.testing.assert_allclose(expected, + actual, + rtol=rtol, + atol=atol, + err_msg=msg) + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + + # Parse models directory from args, then strip it so tf.test doesn't see it + for arg in sys.argv[1:]: + if not arg.startswith('-') and os.path.isdir(arg): + ProprietaryModelTest.models_dir = arg + sys.argv.remove(arg) + break + + unittest.main() diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py new file mode 100644 index 00000000000..5c95e3e87e9 --- /dev/null +++ b/tensorflow/lite/micro/compression/pruning.py @@ -0,0 +1,59 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pruning compression plugin (stub). + +This module provides a placeholder for pruning (sparsity) compression. +The actual implementation is not yet available. + +Supported tensor types (when implemented): All TFLM tensor types +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class PruningCompressor(compressor.Compressor): + """Pruning compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual pruning + compression algorithm for sparse tensors is not yet implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.PRUNING.""" + return decode.DecodeType.PRUNING + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using pruning (sparsity) encoding. + + Args: + tensor: The tensor to compress. + method: Must be a PruningCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Pruning compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index 6f782e92d7a..5c0f81885bc 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -58,10 +58,32 @@ class Tensor: @dataclass class LookUpTableCompression(CompressionMethod): + """LUT compression using lookup tables. + Attributes: + index_bitwidth: Number of bits per index (1-7). + """ index_bitwidth: int +@dataclass +class HuffmanCompression(CompressionMethod): + """Huffman compression using Xtensa-format decode tables. + + Supported tensor types: INT8, INT16 only. + """ + pass + + +@dataclass +class PruningCompression(CompressionMethod): + """Pruning (sparsity) compression. + + Supported tensor types: All TFLM tensor types. + """ + pass + + class ParseError(Exception): "Raised when the spec string cannot be parsed." @@ -70,6 +92,18 @@ def __init__(self, message="error parsing spec", wrapped_exception=None): self.original_exception = wrapped_exception +def _parse_compression_method(comp: dict) -> CompressionMethod: + """Parse a single compression method from YAML dict.""" + if "lut" in comp: + return LookUpTableCompression(index_bitwidth=comp["lut"]["index_bitwidth"]) + elif "huffman" in comp: + return HuffmanCompression() + elif "pruning" in comp: + return PruningCompression() + else: + raise ParseError(f"Unknown compression method: {list(comp.keys())}") + + def parse_yaml(y: str) -> list[Tensor]: "Parses a compression spec in a YAML string into its Python representation." try: @@ -77,14 +111,19 @@ def parse_yaml(y: str) -> list[Tensor]: tensors = [] for item in config["tensors"]: - bitwidth = item["compression"][0]["lut"]["index_bitwidth"] - tensor = Tensor(subgraph=item["subgraph"], - tensor=item["tensor"], - compression=[ - LookUpTableCompression(index_bitwidth=bitwidth), - ]) + methods = [] + for comp in item["compression"]: + methods.append(_parse_compression_method(comp)) + + tensor = Tensor( + subgraph=item["subgraph"], + tensor=item["tensor"], + compression=methods, + ) tensors.append(tensor) + except ParseError: + raise except Exception as e: raise ParseError() from e diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py deleted file mode 100644 index 80286d17359..00000000000 --- a/tensorflow/lite/micro/compression/test_models.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Tools for constructing flatbuffers for testing. - -This module provides tools for constructing .tflite flatbuffers from a Python -dictionary representation of a model, a prototype of which can be found in -EXAMPLE_MODEL. - -Example usage: - model_definition = {...} # use EXAMPLE_MODEL as prototype - flatbuffer: bytearray = test_models.build(model_definition) -""" - -# This module must remain low-level and independent from any helpers in this -# project which make constructing model and flatbuffers easier, because this -# module is used to define tests for those helpers. - -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - -EXAMPLE_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, - "inputs": ( - 0, - 1, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, - "inputs": ( - 3, - 2, - ), - "outputs": (4, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 0, - }, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Builds a .tflite flatbuffer from a model definition. - - Args: - model_definition: A dictionary representation of the model, a prototype of - which can be found in the EXAMPLE_MODEL attribute of this module. - - Returns: - A tflite flatbuffer. - """ - root = tflite.ModelT() - description = model_definition.get("description") - if description is not None: - root.description = description - - root.operatorCodes = [] - for id, operator_code in model_definition["operator_codes"].items(): - assert id == len(root.operatorCodes) - opcode_t = tflite.OperatorCodeT() - root.operatorCodes.append(opcode_t) - opcode_t.builtinCode = operator_code["builtin_code"] - - root.metadata = [] - if "metadata" in model_definition: - for _, metadata in model_definition["metadata"].items(): - metadata_t = tflite.MetadataT() - metadata_t.name = metadata["name"] - metadata_t.buffer = metadata["buffer"] - root.metadata.append(metadata_t) - - root.subgraphs = [] - for id, subgraph in model_definition["subgraphs"].items(): - assert id == len(root.subgraphs) - subgraph_t = tflite.SubGraphT() - root.subgraphs.append(subgraph_t) - - subgraph_t.operators = [] - for id, operator in subgraph["operators"].items(): - assert id == len(subgraph_t.operators) - operator_t = tflite.OperatorT() - operator_t.opcodeIndex = operator["opcode_index"] - operator_t.inputs = operator["inputs"] - operator_t.outputs = operator["outputs"] - subgraph_t.operators.append(operator_t) - - subgraph_t.tensors = [] - for id, tensor in subgraph["tensors"].items(): - assert id == len(subgraph_t.tensors) - tensor_t = tflite.TensorT() - tensor_t.name = tensor.get("name", None) - tensor_t.shape = tensor["shape"] - tensor_t.type = tensor["type"] - tensor_t.buffer = tensor["buffer"] - - if "quantization" in tensor: - tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = \ - tensor["quantization"].get("quantized_dimension", None) - tensor_t.quantization.scale = \ - tensor["quantization"].get("scale", None) - tensor_t.quantization.zeroPoint = \ - tensor["quantization"].get("zero_point", None) - - subgraph_t.tensors.append(tensor_t) - - root.buffers = [] - for id, data in model_definition["buffers"].items(): - assert id == len(root.buffers) - buffer_t = tflite.BufferT() - - if data is None: - buffer_t.data = [] - elif isinstance(data, np.ndarray): - array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian - buffer_t.data = list(array.tobytes()) - else: - raise TypeError(f"buffer_id {id} must be None or an np.ndarray") - - root.buffers.append(buffer_t) - - size_hint = 1 * 2**20 - builder = flatbuffers.Builder(size_hint) - builder.Finish(root.Pack(builder)) - flatbuffer = builder.Output() - return flatbuffer diff --git a/tensorflow/lite/micro/compression/test_models_test.py b/tensorflow/lite/micro/compression/test_models_test.py deleted file mode 100644 index d7e961c2dd9..00000000000 --- a/tensorflow/lite/micro/compression/test_models_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from tflite_micro.tensorflow.lite.micro.compression import test_models -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - - -class TestBuild(unittest.TestCase): - - def setUp(self): - self.flatbuffer = test_models.build(test_models.EXAMPLE_MODEL) - - def testNotDegenerate(self): - model = tflite.ModelT.InitFromPackedBuf(self.flatbuffer, 0) - self.assertEqual(model.operatorCodes[0].builtinCode, - tflite.BuiltinOperator.FULLY_CONNECTED) - - -if __name__ == "__main__": - unittest.main()