From b17bb64889feb386d4ef21a36168dbb79f01ef96 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH] feat(compression): implement model_editor for TFLite model manipulation Implement unified module for creating, reading, and modifying TFLite models with a clean API. The module eliminates manual index tracking and buffer management through automatic bookkeeping, supporting both declarative and imperative construction styles. Core design uses first-class Buffer objects that can be shared between tensors, with automatic deduplication during build. Tensors reference Buffers directly, matching the TFLite schema structure. The compiler automatically extracts inline tensor declarations, builds operator code tables, and handles index assignment according to TFLite conventions. Supports quantization parameters (per-tensor and per-channel), metadata key-value pairs, and read-modify-write workflows. The read() function preserves the object graph structure, enabling models to be read, modified, and rebuilt. Add comprehensive test coverage for core functionality, advanced features, quantization, and modification workflows. 
--- tensorflow/lite/micro/compression/BUILD | 22 + .../lite/micro/compression/model_editor.py | 557 +++++++++++++ .../micro/compression/model_editor_test.py | 735 ++++++++++++++++++ 3 files changed, 1314 insertions(+) create mode 100644 tensorflow/lite/micro/compression/model_editor.py create mode 100644 tensorflow/lite/micro/compression/model_editor_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 69847cba6bd..7e9ca410b5c 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -237,6 +237,28 @@ py_test( ], ) +py_library( + name = "model_editor", + srcs = ["model_editor.py"], + deps = [ + "//tensorflow/lite/python:schema_py", + requirement("flatbuffers"), + requirement("numpy"), + ], +) + +py_test( + name = "model_editor_test", + size = "small", + srcs = ["model_editor_test.py"], + deps = [ + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + requirement("tensorflow-cpu"), + ], +) + py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/model_editor.py b/tensorflow/lite/micro/compression/model_editor.py new file mode 100644 index 00000000000..541636b7e42 --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor.py @@ -0,0 +1,557 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unified TFLite model manipulation module. 
class _BufferList(list):
  """List of buffers that keeps each buffer's `index` equal to its position.

  Every operation that adds buffers at the end of the list (append(),
  extend(), and `+=`) sets `buf.index` to the position the buffer lands on,
  so append-only workflows never have to assign indices by hand.

  Operations that move or remove elements (insert(), pop(), del, sort(),
  slice assignment, ...) are inherited from list unchanged and do NOT
  renumber the remaining buffers.
  """

  def append(self, buf):
    """Append `buf` and set `buf.index` to its position in the list."""
    buf.index = len(self)
    super().append(buf)

  def extend(self, bufs):
    """Append every buffer in `bufs`, indexing each one like append().

    Args:
      bufs: Iterable of buffer objects (anything with a writable `index`).
    """
    # Snapshot first so that extending the list with itself terminates.
    for buf in list(bufs):
      self.append(buf)

  def __iadd__(self, bufs):
    """Implement `buffers += bufs` with the same bookkeeping as extend()."""
    self.extend(bufs)
    return self
+ + Supports both buffer= and data= parameters for flexibility: + - buffer=: Explicitly provide a Buffer object (can be shared between tensors) + - data=: Convenience parameter that auto-creates a Buffer + + Cannot specify both buffer and data at initialization. + """ + shape: tuple + dtype: tflite.TensorType + buffer: Optional[Buffer] = None + quantization: Optional[Quantization] = None + name: Optional[str] = None + + # Internal field for data initialization only + _data_init: Optional[Union[bytes, np.ndarray]] = field(default=None, + init=False, + repr=False) + + # Auto-populated during build/read + _index: Optional[int] = field(default=None, init=False, repr=False) + + def __init__(self, + shape, + dtype, + buffer=None, + data=None, + quantization=None, + name=None): + """Initialize Tensor. + + Args: + shape: Tensor shape as tuple + dtype: TensorType enum value + buffer: Optional Buffer object (for explicit buffer sharing) + data: Optional numpy array or bytes (convenience parameter, creates Buffer) + quantization: Optional Quantization object + name: Optional tensor name + + Raises: + ValueError: If both buffer and data are specified + """ + if data is not None and buffer is not None: + raise ValueError("Cannot specify both data and buffer") + + self.shape = shape + self.dtype = dtype + self.buffer = buffer + self.quantization = quantization + self.name = name + self._index = None + + # Convert data to buffer if provided + if data is not None: + buf_data = data if isinstance(data, bytes) else data.tobytes() + self.buffer = Buffer(data=buf_data) + + @property + def array(self) -> Optional[np.ndarray]: + """Get tensor data as properly-shaped numpy array. + + Returns: + numpy array with shape matching tensor.shape and dtype matching + tensor.dtype, or None if tensor has no data. + + For low-level byte access, use tensor.buffer.data instead. 
+ """ + if self.buffer is None: + return None + return np.frombuffer(self.buffer.data, + dtype=_dtype_to_numpy(self.dtype)).reshape(self.shape) + + @array.setter + def array(self, value: np.ndarray): + """Set tensor data from numpy array. + + Args: + value: New tensor data as numpy array. Will be converted to bytes + using tobytes() and stored in the buffer. + + Creates a new Buffer if tensor has no buffer, or updates the existing + buffer's data in place. + + For low-level byte access, use tensor.buffer.data instead. + """ + buf_data = value.tobytes() + if self.buffer is None: + self.buffer = Buffer(data=buf_data) + else: + self.buffer.data = buf_data + + @property + def index(self) -> Optional[int]: + """Tensor index in the subgraph's tensor list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + @property + def numpy_dtype(self) -> np.dtype: + """Get numpy dtype corresponding to tensor's TFLite dtype. + + Returns: + numpy dtype object for use with np.frombuffer, np.array, etc. 
+ """ + return _dtype_to_numpy(self.dtype) + + +@dataclass +class OperatorCode: + """Operator code specification.""" + builtin_code: tflite.BuiltinOperator + custom_code: Optional[str] = None + version: int = 1 + + +@dataclass +class Operator: + """Declarative operator specification.""" + opcode: Union[tflite.BuiltinOperator, int] + inputs: List[Tensor] + outputs: List[Tensor] + custom_code: Optional[str] = None + + # Set when reading from existing model + opcode_index: Optional[int] = None + + _index: Optional[int] = field(default=None, init=False, repr=False) + + +@dataclass +class Subgraph: + """Declarative subgraph specification with imperative methods.""" + tensors: List[Tensor] = field(default_factory=list) + operators: List[Operator] = field(default_factory=list) + name: Optional[str] = None + + _index: Optional[int] = field(default=None, init=False, repr=False) + + def add_tensor(self, **kwargs) -> Tensor: + """Add tensor imperatively and return it.""" + t = Tensor(**kwargs) + t._index = len(self.tensors) + self.tensors.append(t) + return t + + def add_operator(self, **kwargs) -> Operator: + """Add operator imperatively and return it.""" + op = Operator(**kwargs) + op._index = len(self.operators) + self.operators.append(op) + return op + + @property + def index(self) -> Optional[int]: + """Subgraph index in the model's subgraph list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. 
def read(buffer: bytes) -> Model:
  """Read a TFLite flatbuffer and return a Model object.

  Args:
    buffer: Serialized TFLite model.

  Returns:
    A Model in which:
      - model.buffers holds one Buffer per flatbuffer buffer, in file order.
      - Each Tensor references its Buffer object directly; tensors that use
        buffer 0 get buffer=None.
      - Each Operator references Tensor objects of its own subgraph. An
        omitted optional input/output (a negative index in the flatbuffer)
        is represented by None in Operator.inputs / Operator.outputs.
      - model.metadata maps each metadata name to the bytes of its buffer.
  """
  fb_model = tflite.ModelT.InitFromPackedBuf(buffer, 0)

  def _items(vector):
    # Vector fields that are absent from the flatbuffer may be None.
    return vector if vector is not None else []

  def _resolve(tensors, indices):
    # A negative index marks an omitted optional tensor. It must not be used
    # as a Python index, where -1 would silently select the last tensor.
    return [None if i < 0 else tensors[i] for i in _items(indices)]

  # Decode bytes to strings where needed
  description = fb_model.description
  if isinstance(description, bytes):
    description = description.decode('utf-8')

  model = Model(description=description)

  # Create all buffers first (so tensors can reference them)
  fb_buffers = _items(fb_model.buffers)
  for i, fb_buf in enumerate(fb_buffers):
    buf_data = bytes(fb_buf.data) if fb_buf.data is not None else b''
    model.buffers.append(Buffer(data=buf_data, index=i))

  # Read operator codes
  for fb_opcode in _items(fb_model.operatorCodes):
    custom_code = fb_opcode.customCode
    if isinstance(custom_code, bytes):
      custom_code = custom_code.decode('utf-8')

    model.operator_codes.append(
        OperatorCode(builtin_code=fb_opcode.builtinCode,
                     custom_code=custom_code,
                     version=fb_opcode.version if fb_opcode.version else 1))

  # Read subgraphs
  for sg_idx, fb_sg in enumerate(_items(fb_model.subgraphs)):
    sg = Subgraph()
    sg._index = sg_idx

    # Read tensors
    for tensor_idx, fb_tensor in enumerate(_items(fb_sg.tensors)):
      name = fb_tensor.name
      if isinstance(name, bytes):
        name = name.decode('utf-8')

      # Buffer 0 is the empty buffer (TFLite convention), so treat it as None
      buf = None if fb_tensor.buffer == 0 else model.buffers[fb_tensor.buffer]

      # Read quantization parameters if present
      quant = None
      fb_quant = fb_tensor.quantization
      if fb_quant and fb_quant.scale is not None and len(fb_quant.scale) > 0:
        scales = list(fb_quant.scale)
        if fb_quant.zeroPoint is not None:
          zeros = list(fb_quant.zeroPoint)
        else:
          zeros = [0] * len(scales)

        # Only report an axis for per-channel data (more than one scale)
        axis = None
        if len(scales) > 1 and fb_quant.quantizedDimension is not None:
          axis = fb_quant.quantizedDimension

        quant = Quantization(scales=scales, zero_points=zeros, axis=axis)

      # A tensor without a shape vector is treated as having shape ().
      shape = tuple(fb_tensor.shape) if fb_tensor.shape is not None else ()

      tensor = Tensor(shape=shape,
                      dtype=fb_tensor.type,
                      buffer=buf,
                      name=name,
                      quantization=quant)
      tensor._index = tensor_idx
      sg.tensors.append(tensor)

    # Read operators
    for fb_op in _items(fb_sg.operators):
      opcode_obj = model.operator_codes[fb_op.opcodeIndex]
      sg.operators.append(
          Operator(opcode=opcode_obj.builtin_code,
                   inputs=_resolve(sg.tensors, fb_op.inputs),
                   outputs=_resolve(sg.tensors, fb_op.outputs),
                   custom_code=opcode_obj.custom_code,
                   opcode_index=fb_op.opcodeIndex))

    model.subgraphs.append(sg)

  # Read metadata: each entry names a buffer that holds its value
  for entry in _items(fb_model.metadata):
    name = entry.name
    if isinstance(name, bytes):
      name = name.decode('utf-8')

    # Named `meta_buf` so the `buffer` parameter is not shadowed.
    meta_buf = fb_buffers[entry.buffer]
    model.metadata[name] = (bytes(meta_buf.data)
                            if meta_buf.data is not None else b'')

  return model
class _ModelCompiler:
  """Internal: compiles Model to flatbuffer with automatic bookkeeping.

  An instance is meant for a single compile() call: compile() accumulates
  buffers and operator codes in the instance.
  """

  def __init__(self, model: Model):
    self.model = model
    self._buffers = []  # tflite.BufferT objects in output order
    self._buffer_map = {}  # id(Buffer) -> index into self._buffers
    self._operator_codes = {}  # (opcode, custom_code) -> OperatorCodeT
    self._opcode_indices = {}  # (opcode, custom_code) -> opcode table index

  def compile(self) -> bytearray:
    """Compile model to flatbuffer.

    Returns:
      The serialized model as produced by flatbuffers.Builder.Output().
    """
    root = tflite.ModelT()
    root.version = 3
    root.description = self.model.description

    if self.model.buffers:
      # Preserve existing buffers (e.g. from read()). Index each buffer by
      # the position it is emitted at: a stored buf.index may be None or
      # stale after the list was edited.
      for position, buf in enumerate(self.model.buffers):
        fb_buf = tflite.BufferT()
        fb_buf.data = list(buf.data) if buf.data else []
        self._buffers.append(fb_buf)
        buf.index = position
        self._buffer_map[id(buf)] = position
    else:
      # Creating model from scratch: buffer 0 is empty (TFLite convention).
      # It is deliberately absent from _buffer_map; tensors without data
      # refer to it by the literal index 0.
      empty_buffer = tflite.BufferT()
      empty_buffer.data = []
      self._buffers = [empty_buffer]

    # Auto-collect and register operator codes
    self._collect_operator_codes()
    root.operatorCodes = list(self._operator_codes.values())

    root.subgraphs = [
        self._compile_subgraph(sg) for sg in self.model.subgraphs
    ]

    # Metadata appends its value buffers to self._buffers, so compile it
    # before handing the buffer list to the root object.
    root.metadata = self._compile_metadata()
    root.buffers = self._buffers

    # Pack and return
    builder = flatbuffers.Builder(4 * 2**20)
    builder.Finish(root.Pack(builder))
    return builder.Output()

  def _declared_version(self, op: Operator, key: tuple) -> int:
    """Return the version recorded in model.operator_codes for `op`.

    Prefers the entry named by op.opcode_index when it still describes the
    same (opcode, custom_code) pair, then any entry with that pair, and
    falls back to 1 when the model declares none.
    """
    declared = self.model.operator_codes
    idx = op.opcode_index
    candidates = []
    if idx is not None and 0 <= idx < len(declared):
      candidates.append(declared[idx])
    candidates.extend(declared)
    for entry in candidates:
      if (entry.builtin_code, entry.custom_code) == key:
        return entry.version if entry.version else 1
    return 1

  def _collect_operator_codes(self):
    """Scan all operators and build operator code table."""
    for sg in self.model.subgraphs:
      for op in sg.operators:
        key = (op.opcode, op.custom_code)
        if key in self._operator_codes:
          continue
        opcode = tflite.OperatorCodeT()
        opcode.builtinCode = op.opcode
        if op.custom_code:
          opcode.customCode = op.custom_code
        # Keep the version of operator codes that came from read().
        opcode.version = self._declared_version(op, key)
        # Dicts preserve insertion order, so this is the table position.
        self._opcode_indices[key] = len(self._operator_codes)
        self._operator_codes[key] = opcode

  def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT:
    """Compile subgraph, extracting inline tensors from operators.

    Tensors listed in sg.tensors keep their order; tensors that appear only
    in operator inputs/outputs are appended in order of first use. Every
    tensor's _index is updated to its final position.
    """
    sg_t = tflite.SubGraphT()
    sg_t.name = sg.name

    all_tensors = list(sg.tensors)
    tensor_to_index = {}
    for i, t in enumerate(all_tensors):
      t._index = i
      tensor_to_index[id(t)] = i

    # Extract inline tensors from operators
    for op in sg.operators:
      for tensor in (*op.inputs, *op.outputs):
        # None marks an omitted optional tensor; there is nothing to emit.
        if tensor is None or id(tensor) in tensor_to_index:
          continue
        tensor._index = len(all_tensors)
        tensor_to_index[id(tensor)] = tensor._index
        all_tensors.append(tensor)

    sg_t.tensors = [self._compile_tensor(t) for t in all_tensors]
    sg_t.operators = [
        self._compile_operator(op, tensor_to_index) for op in sg.operators
    ]
    return sg_t

  def _compile_operator(self, op: Operator,
                        tensor_to_index: dict) -> tflite.OperatorT:
    """Compile operator, resolving tensor references and opcodes.

    Args:
      op: Operator to compile.
      tensor_to_index: Maps id(Tensor) to the tensor's index in the subgraph.

    Returns:
      OperatorT with opcodeIndex, inputs and outputs filled in. A None
      tensor reference is written as index -1 (omitted optional tensor).
    """

    def _index_of(tensor):
      return -1 if tensor is None else tensor_to_index[id(tensor)]

    op_t = tflite.OperatorT()
    op_t.opcodeIndex = self._opcode_indices[(op.opcode, op.custom_code)]
    op_t.inputs = [_index_of(inp) for inp in op.inputs]
    op_t.outputs = [_index_of(outp) for outp in op.outputs]
    return op_t

  def _compile_tensor(self, tensor: Tensor) -> tflite.TensorT:
    """Compile tensor, reusing or creating buffer as needed.

    A tensor without a Buffer is assigned buffer 0. A Buffer seen for the
    first time is appended to the output and its index recorded, so tensors
    that share one Buffer object end up sharing one output buffer.
    """
    t = tflite.TensorT()
    t.shape = list(tensor.shape)
    t.type = tensor.dtype
    t.name = tensor.name

    if tensor.buffer is None:
      # No data: use buffer 0
      t.buffer = 0
    else:
      buf_id = id(tensor.buffer)
      if buf_id not in self._buffer_map:
        # First time seeing this buffer, add it
        fb_buf = tflite.BufferT()
        fb_buf.data = list(tensor.buffer.data)
        self._buffers.append(fb_buf)
        buf_index = len(self._buffers) - 1
        self._buffer_map[buf_id] = buf_index
        tensor.buffer.index = buf_index
      t.buffer = self._buffer_map[buf_id]

    if tensor.quantization:
      t.quantization = tensor.quantization.to_tflite()

    return t

  def _compile_metadata(self):
    """Compile metadata, creating buffers for metadata values.

    Returns:
      List of MetadataT entries, one per item of model.metadata; each value
      is stored in a newly appended output buffer. Empty list when the
      model has no metadata.
    """
    if not self.model.metadata:
      return []

    metadata_entries = []
    for name, value in self.model.metadata.items():
      # Create buffer for metadata value
      buf = tflite.BufferT()
      buf.data = list(value)
      self._buffers.append(buf)

      entry = tflite.MetadataT()
      entry.name = name
      entry.buffer = len(self._buffers) - 1
      metadata_entries.append(entry)

    return metadata_entries
+""" + +import numpy as np +import tensorflow as tf +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression.model_editor import ( + Buffer, Model, Operator, OperatorCode, Quantization, Subgraph, Tensor) + + +class TestBasicModel(tf.test.TestCase): + """Test basic model with tensors and operators.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5]], dtype=np.int8) + cls.weights_data = np.array([[1], [2], [3], [4], [5]], dtype=np.int8) + + cls.model = Model( + description="Test model", + subgraphs=[ + Subgraph(operators=[ + Operator(opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 5), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(5, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + # Build the model to a flatbuffer byte array. This exercises the + # model_editor's build path, which converts the high-level Model API + # representation into the binary TFLite format. + fb = cls.model.build() + + # Read the flatbuffer back through model_editor.read() to create a + # loopback model. This exercises the read path, which parses the + # flatbuffer and reconstructs a high-level Model representation. The + # loopback model should be semantically equivalent to cls.model, + # demonstrating that build() and read() are inverse operations. + cls.loopback_model = model_editor.read(fb) + + # Parse the same flatbuffer using the low-level TFLite schema interface + # (ModelT from schema_py_generated). 
This provides direct access to the + # raw flatbuffer structure, allowing us to verify that model_editor + # encodes data correctly at the binary level. We compare fb_model + # (low-level) against loopback_model (high-level) to ensure both + # representations are consistent. + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_description(self): + """Verify model description is preserved through loopback.""" + self.assertEqual(self.fb_model.description, b"Test model") + self.assertEqual(self.loopback_model.description, "Test model") + + def test_counts(self): + """Verify subgraph, tensor, and operator counts.""" + self.assertEqual(len(self.fb_model.subgraphs), 1) + self.assertEqual(len(self.loopback_model.subgraphs), 1) + + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.tensors), 3) + self.assertEqual(len(loopback_sg.tensors), 3) + + self.assertEqual(len(fb_sg.operators), 1) + self.assertEqual(len(loopback_sg.operators), 1) + + def test_tensor_names(self): + """Verify tensor names are preserved.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Check that all expected tensor names are present + fb_names = {t.name for t in fb_sg.tensors} + self.assertEqual(fb_names, {b"input", b"weights", b"output"}) + + loopback_names = {t.name for t in loopback_sg.tensors} + self.assertEqual(loopback_names, {"input", "weights", "output"}) + + def test_tensor_properties(self): + """Verify tensor shapes and dtypes.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor + input_fb = next(t for t in fb_sg.tensors if t.name == b"input") + input_loopback = next(t for t in loopback_sg.tensors if t.name == "input") + self.assertEqual(list(input_fb.shape), [1, 5]) + self.assertEqual(input_loopback.shape, (1, 5)) + self.assertEqual(input_fb.type, tflite.TensorType.INT8) + self.assertEqual(input_loopback.dtype, 
tflite.TensorType.INT8) + + # Weights tensor + weights_fb = next(t for t in fb_sg.tensors if t.name == b"weights") + weights_loopback = next(t for t in loopback_sg.tensors + if t.name == "weights") + self.assertEqual(list(weights_fb.shape), [5, 1]) + self.assertEqual(weights_loopback.shape, (5, 1)) + self.assertEqual(weights_fb.type, tflite.TensorType.INT8) + self.assertEqual(weights_loopback.dtype, tflite.TensorType.INT8) + + # Output tensor + output_fb = next(t for t in fb_sg.tensors if t.name == b"output") + output_loopback = next(t for t in loopback_sg.tensors + if t.name == "output") + self.assertEqual(list(output_fb.shape), [1, 1]) + self.assertEqual(output_loopback.shape, (1, 1)) + self.assertEqual(output_fb.type, tflite.TensorType.INT8) + self.assertEqual(output_loopback.dtype, tflite.TensorType.INT8) + + def test_tensor_data(self): + """Verify tensor data and buffer access.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor data + input_buffer = self.fb_model.buffers[fb_sg.tensors[0].buffer] + self.assertIsNotNone(input_buffer.data) + self.assertEqual(bytes(input_buffer.data), self.input_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[0].array) + self.assertAllEqual(loopback_sg.tensors[0].array, self.input_data) + + # Weights tensor data + weights_buffer = self.fb_model.buffers[fb_sg.tensors[1].buffer] + self.assertIsNotNone(weights_buffer.data) + self.assertEqual(bytes(weights_buffer.data), self.weights_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[1].array) + self.assertAllEqual(loopback_sg.tensors[1].array, self.weights_data) + + # Output tensor has no data + self.assertEqual(fb_sg.tensors[2].buffer, 0) + self.assertIsNone(loopback_sg.tensors[2].array) + + def test_buffer_allocation(self): + """Verify buffer allocation and zero convention.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Exact buffer count: buffer 0 (empty) 
+ input + weights = 3 total + self.assertEqual(len(self.fb_model.buffers), 3) + self.assertEqual(len(self.loopback_model.buffers), 3) + + # Buffer 0 is empty + buffer_zero = self.fb_model.buffers[0] + self.assertTrue(buffer_zero.data is None or len(buffer_zero.data) == 0) + + # Verify each buffer is referenced by exactly the expected tensor + # Buffer 0 -> output tensor (no data) + output_tensor = next(t for t in fb_sg.tensors if t.name == b"output") + self.assertEqual(output_tensor.buffer, 0) + + # Buffer 1 and 2 -> input and weights (order may vary) + input_tensor = next(t for t in fb_sg.tensors if t.name == b"input") + weights_tensor = next(t for t in fb_sg.tensors if t.name == b"weights") + self.assertNotEqual(input_tensor.buffer, 0) + self.assertNotEqual(weights_tensor.buffer, 0) + self.assertIn(input_tensor.buffer, [1, 2]) + self.assertIn(weights_tensor.buffer, [1, 2]) + + # Tensors with data point to non-zero buffers in loopback model + loopback_input_tensor = next(t for t in loopback_sg.tensors + if t.name == "input") + self.assertIsNotNone(loopback_input_tensor.buffer) + self.assertIsNotNone(loopback_input_tensor.buffer.index) + self.assertNotEqual(loopback_input_tensor.buffer.index, 0) + self.assertEqual(len(loopback_input_tensor.buffer.data), 5) + self.assertEqual(bytes(loopback_input_tensor.buffer.data), + self.input_data.tobytes()) + + def test_operator_references(self): + """Verify operators reference correct tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Operator input/output references + self.assertEqual(len(fb_sg.operators[0].inputs), 2) + self.assertEqual([t.name for t in loopback_sg.operators[0].inputs], + ["input", "weights"]) + + self.assertEqual(len(fb_sg.operators[0].outputs), 1) + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["output"]) + + # Operator indices are in bounds + num_tensors = len(fb_sg.tensors) + for idx in list(fb_sg.operators[0].inputs) + list( 
+ fb_sg.operators[0].outputs): + self.assertGreaterEqual(idx, 0) + self.assertLess(idx, num_tensors) + + def test_operator_codes(self): + """Verify operator code table is correctly populated.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertIsNotNone(self.fb_model.operatorCodes) + self.assertEqual(len(self.fb_model.operatorCodes), 1) + self.assertEqual(self.fb_model.operatorCodes[0].builtinCode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + self.assertEqual(len(self.loopback_model.operator_codes), 1) + self.assertIsNotNone(loopback_sg.operators[0].opcode_index) + loopback_opcode = self.loopback_model.operator_codes[ + loopback_sg.operators[0].opcode_index] + self.assertEqual(loopback_opcode.builtin_code, + tflite.BuiltinOperator.FULLY_CONNECTED) + + +class TestAdvancedModel(tf.test.TestCase): + """Test multiple operators, custom ops, shared tensors, and mixed references.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=np.int8) + cls.weights_data = np.array( + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=np.int8) + cls.bias_data = np.array([10], dtype=np.int8) + # Int16 data to test endianness: values that will show byte order issues + cls.int16_data = np.array([256, 512, 1024], + dtype=np.int16) # 0x0100, 0x0200, 0x0400 + + # Pre-declare shared tensor (output of FC, input to custom op) + cls.hidden = Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="hidden") + + # Create explicit shared buffer to test buffer sharing between tensors + cls.shared_buffer_data = np.array([100, 127], dtype=np.int8) + cls.shared_buf = Buffer(data=cls.shared_buffer_data.tobytes()) + + cls.model = Model( + description="Advanced model", + metadata={ + "version": b"1.0.0", + "author": b"test_suite", + "custom_data": bytes([0xDE, 0xAD, 0xBE, 0xEF]) + }, + subgraphs=[ + Subgraph( + tensors=[ + 
cls.hidden, # Mixed: pre-declared shared tensor + # Int16 tensor to test endianness + Tensor(shape=(3, ), + dtype=tflite.TensorType.INT16, + data=cls.int16_data, + name="int16_tensor"), + # Two tensors sharing same buffer to test buffer deduplication + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor1"), + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor2") + ], + operators=[ + # Multiple operators: FULLY_CONNECTED + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(10, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[cls.hidden + ] # Shared: reference to pre-declared + ), + # Custom operator + Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code="MyCustomOp", + inputs=[cls.hidden], # Shared: reuse hidden tensor + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed") + ]), + # Multiple operators: ADD + Operator( + opcode=tflite.BuiltinOperator.ADD, + inputs=[ + Tensor( + shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed_ref" # Mixed: inline tensor + ), + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + data=cls.bias_data, + name="bias") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_operator_counts(self): + """Verify correct number of operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.operators), 3) + self.assertEqual(len(loopback_sg.operators), 3) + + def test_operator_code_table(self): + """Verify operator code table contains all operator types.""" 
+ self.assertEqual(len(self.fb_model.operatorCodes), 3) + self.assertEqual(len(self.loopback_model.operator_codes), 3) + + opcodes_fb = {op.builtinCode for op in self.fb_model.operatorCodes} + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_fb) + + opcodes_loopback = { + op.builtin_code + for op in self.loopback_model.operator_codes + } + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_loopback) + + def test_custom_operator(self): + """Verify custom operator code preservation.""" + loopback_sg = self.loopback_model.subgraphs[0] + + # Custom code in operator code table + custom_opcode_fb = next(op for op in self.fb_model.operatorCodes + if op.builtinCode == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_fb.customCode, b"MyCustomOp") + + custom_opcode_loopback = next( + op for op in self.loopback_model.operator_codes + if op.builtin_code == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_loopback.custom_code, "MyCustomOp") + + # Custom operator references custom code + custom_op_loopback = loopback_sg.operators[1] + self.assertEqual(custom_op_loopback.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_op_loopback.custom_code, "MyCustomOp") + + def test_shared_tensor_references(self): + """Verify tensors shared between operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Hidden tensor is at index 0 (pre-declared) + self.assertEqual(fb_sg.tensors[0].name, b"hidden") + self.assertEqual(loopback_sg.tensors[0].name, "hidden") + + # FC operator outputs to hidden + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["hidden"]) + + # Custom operator inputs from hidden + 
# NOTE(review): the first statements finish test_shared_tensor_references,
# whose opening lines precede this span.
    self.assertEqual([t.name for t in loopback_sg.operators[1].inputs],
                     ["hidden"])

    # Same Tensor object is referenced by both operators
    fc_output = loopback_sg.operators[0].outputs[0]
    custom_input = loopback_sg.operators[1].inputs[0]
    self.assertIs(fc_output, custom_input)

  def test_mixed_tensor_references(self):
    """Verify mix of pre-declared and inline tensors.

    Only the total tensor count is checked, once on the raw schema object
    tree and once on the loopback model.
    """
    fb_sg = self.fb_model.subgraphs[0]
    loopback_sg = self.loopback_model.subgraphs[0]

    # Total: hidden, int16_tensor, shared_buf_tensor1, shared_buf_tensor2 (pre-declared)
    # + input, weights, processed, processed_ref, bias, output (inline from operators)
    self.assertEqual(len(fb_sg.tensors), 10)
    self.assertEqual(len(loopback_sg.tensors), 10)

  def test_int16_endianness(self):
    """Verify int16 data is stored in little-endian byte order."""
    fb_sg = self.fb_model.subgraphs[0]
    loopback_sg = self.loopback_model.subgraphs[0]

    # Find int16 tensor by name (bytes on the schema object, str on loopback)
    int16_tensor_fb = next(t for t in fb_sg.tensors
                           if t.name == b"int16_tensor")
    int16_tensor_loopback = next(t for t in loopback_sg.tensors
                                 if t.name == "int16_tensor")

    # Verify dtype
    self.assertEqual(int16_tensor_fb.type, tflite.TensorType.INT16)
    self.assertEqual(int16_tensor_loopback.dtype, tflite.TensorType.INT16)

    # Check flatbuffer buffer has correct little-endian bytes
    # For [256, 512, 1024] = [0x0100, 0x0200, 0x0400]
    # Little-endian bytes: [0x00, 0x01, 0x00, 0x02, 0x00, 0x04]
    int16_buffer_fb = self.fb_model.buffers[int16_tensor_fb.buffer]
    self.assertIsNotNone(int16_buffer_fb.data)
    # NOTE(review): the text is truncated inside the next statement. Everything
    # between the opening quote of the second astype() argument and the words
    # "buffer mapping from flatbuffer" is missing. The gap appears to swallow
    # the rest of this test and the `def` line of the metadata test that the
    # statements below belong to -- TODO restore from the upstream file rather
    # than reconstructing by hand.
    expected_bytes = self.int16_data.astype(np.int16).astype('buffer mapping from flatbuffer
    metadata_map_fb = {}
    for entry in self.fb_model.metadata:
      buffer_idx = entry.buffer
      # Every metadata entry must point at an existing buffer.
      self.assertLess(buffer_idx, len(self.fb_model.buffers))
      buffer = self.fb_model.buffers[buffer_idx]
      if buffer.data is not None:
        # Keyed by entry.name as stored on the schema object (looked up with
        # bytes keys below).
        metadata_map_fb[entry.name] = bytes(buffer.data)

    # Verify flatbuffer metadata values
    self.assertEqual(metadata_map_fb[b"version"], b"1.0.0")
    self.assertEqual(metadata_map_fb[b"author"], b"test_suite")
    self.assertEqual(metadata_map_fb[b"custom_data"],
                     bytes([0xDE, 0xAD, 0xBE, 0xEF]))

    # Check loopback model metadata
    self.assertIsNotNone(self.loopback_model.metadata)
    self.assertEqual(len(self.loopback_model.metadata), 3)

    # Verify loopback metadata values (looked up by str key; the values are
    # compared as bytes)
    self.assertEqual(self.loopback_model.metadata["version"], b"1.0.0")
    self.assertEqual(self.loopback_model.metadata["author"], b"test_suite")
    self.assertEqual(self.loopback_model.metadata["custom_data"],
                     bytes([0xDE, 0xAD, 0xBE, 0xEF]))

  def test_buffer_allocation(self):
    """Verify no orphaned buffers and shared buffer deduplication."""
    fb_sg = self.fb_model.subgraphs[0]
    loopback_sg = self.loopback_model.subgraphs[0]

    # Collect all buffer references (from tensors and metadata)
    referenced_buffers = {0}  # Buffer 0 is special (always referenced)

    # Collect buffer references from tensors
    for tensor in fb_sg.tensors:
      referenced_buffers.add(tensor.buffer)

    # Collect buffer references from metadata
    for entry in self.fb_model.metadata:
      referenced_buffers.add(entry.buffer)

    # Verify no orphaned buffers (all buffers are referenced)
    for i in range(len(self.fb_model.buffers)):
      self.assertIn(
          i, referenced_buffers,
          f"Buffer {i} is orphaned (not referenced by any tensor or metadata)")

    # Verify shared buffer deduplication: two tensors share one buffer
    tensor1_fb = next(t for t in fb_sg.tensors
                      if t.name == b"shared_buf_tensor1")
    tensor2_fb = next(t for t in fb_sg.tensors
                      if t.name == b"shared_buf_tensor2")

    # Both tensors should point to the same buffer index, and that index must
    # not be the special buffer 0.
    self.assertEqual(tensor1_fb.buffer, tensor2_fb.buffer)
    self.assertNotEqual(tensor1_fb.buffer, 0)

    # Verify loopback preserves shared buffer (same Buffer object)
    tensor1_loopback = next(t for t in loopback_sg.tensors
                            if t.name == "shared_buf_tensor1")
tensor2_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor2") + + self.assertIs(tensor1_loopback.buffer, tensor2_loopback.buffer) + self.assertEqual(bytes(tensor1_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + self.assertEqual(bytes(tensor2_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + + +class TestQuantization(tf.test.TestCase): + """Test per-tensor and per-channel quantization parameters.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + # Per-channel quantization parameters + cls.per_channel_scales = [0.1, 0.2, 0.3, 0.4] + cls.per_channel_zeros = [0, 1, 2, 3] + + cls.model = Model( + description="Quantization test model", + subgraphs=[ + Subgraph(tensors=[ + # Per-tensor quantized tensor (single scale/zero_point) + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((1, 10), dtype=np.int8), + name="per_tensor", + quantization=Quantization(scales=0.5, zero_points=10)), + # Per-channel quantized tensor (array of scales/zero_points, axis) + Tensor(shape=(4, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 10), dtype=np.int8), + name="per_channel", + quantization=Quantization( + scales=cls.per_channel_scales, + zero_points=cls.per_channel_zeros, + axis=0)) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_per_tensor_quantization_flatbuffer(self): + """Verify per-tensor quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Scale and zero_point encoded as single-element arrays + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 1) + self.assertEqual(tensor.quantization.scale[0], 0.5) + + self.assertIsNotNone(tensor.quantization.zeroPoint) + 
self.assertEqual(len(tensor.quantization.zeroPoint), 1) + self.assertEqual(tensor.quantization.zeroPoint[0], 10) + + def test_per_tensor_quantization_loopback(self): + """Verify per-tensor quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, [0.5]) + self.assertEqual(tensor.quantization.zero_points, [10]) + self.assertIsNone(tensor.quantization.axis) + + def test_per_channel_quantization_flatbuffer(self): + """Verify per-channel quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_channel") + self.assertIsNotNone(tensor.quantization) + + # All scales encoded + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 4) + self.assertEqual(list(tensor.quantization.scale), self.per_channel_scales) + + # All zero_points encoded + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 4) + self.assertEqual(list(tensor.quantization.zeroPoint), + self.per_channel_zeros) + + # Axis encoded as quantizedDimension + self.assertEqual(tensor.quantization.quantizedDimension, 0) + + def test_per_channel_quantization_loopback(self): + """Verify per-channel quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_channel") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, self.per_channel_scales) + self.assertEqual(tensor.quantization.zero_points, self.per_channel_zeros) + self.assertEqual(tensor.quantization.axis, 0) + + +class TestReadModifyWrite(tf.test.TestCase): + """Test read-modify-write workflows.""" + + @classmethod + 
def setUpClass(cls): + """Create a simple base model for modification tests.""" + cls.original_data = np.array([[1, 2, 3]], dtype=np.int8) + cls.model = Model( + description="Base model", + metadata={"original": b"metadata"}, + subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=cls.original_data, + name="weights"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="input"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="output") + ]) + ]) + + cls.fb = cls.model.build() + + def test_modify_tensor_data(self): + """Read model, modify tensor data, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify tensor data using array setter (high-level API) + weights_tensor = next(t for t in model2.subgraphs[0].tensors + if t.name == "weights") + new_data = np.array([[10, 20, 30]], dtype=np.int8) + weights_tensor.array = new_data # Uses array setter + + # Build modified model + fb2 = model2.build() + + # Read back and verify modification + model3 = model_editor.read(fb2) + modified_weights = next(t for t in model3.subgraphs[0].tensors + if t.name == "weights") + self.assertAllEqual(modified_weights.array, new_data) + + # Verify other tensors unchanged + self.assertEqual(len(model3.subgraphs[0].tensors), 3) + + def test_add_tensor_and_operator(self): + """Read model, add new tensor and operator, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + sg = model2.subgraphs[0] + + # Get existing tensors + input_tensor = next(t for t in sg.tensors if t.name == "input") + output_tensor = next(t for t in sg.tensors if t.name == "output") + + # Add new tensor using imperative API + new_weights = np.array([[5, 10, 15]], dtype=np.int8) + new_weights_tensor = sg.add_tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=new_weights, + name="new_weights") + + # Add new operator using imperative API + sg.add_operator(opcode=tflite.BuiltinOperator.ADD, + 
inputs=[input_tensor, new_weights_tensor], + outputs=[output_tensor]) + + # Build modified model + fb2 = model2.build() + + # Read back and verify additions + model3 = model_editor.read(fb2) + sg3 = model3.subgraphs[0] + + # Verify tensor was added + self.assertEqual(len(sg3.tensors), 4) + added_tensor = next(t for t in sg3.tensors if t.name == "new_weights") + self.assertIsNotNone(added_tensor) + self.assertAllEqual(added_tensor.array, new_weights) + + # Verify operator was added + self.assertEqual(len(sg3.operators), 1) + added_op = sg3.operators[0] + self.assertEqual([t.name for t in added_op.inputs], + ["input", "new_weights"]) + self.assertEqual([t.name for t in added_op.outputs], ["output"]) + + def test_modify_metadata(self): + """Read model, modify metadata, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify existing metadata + model2.metadata["original"] = b"modified_metadata" + + # Add new metadata + model2.metadata["new_key"] = b"new_value" + + # Build modified model + fb2 = model2.build() + + # Read back and verify modifications + model3 = model_editor.read(fb2) + + self.assertEqual(len(model3.metadata), 2) + self.assertEqual(model3.metadata["original"], b"modified_metadata") + self.assertEqual(model3.metadata["new_key"], b"new_value") + + +if __name__ == "__main__": + tf.test.main()