From f5371676c430ec5714309bbd1252af7a08b606dc Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 22:55:34 -0500 Subject: [PATCH 01/21] feat(compression): implement model_editor for TFLite model manipulation Implement unified module for creating, reading, and modifying TFLite models with a clean API. The module eliminates manual index tracking and buffer management through automatic bookkeeping, supporting both declarative and imperative construction styles. Wrapper classes (Tensor, Operator, Subgraph, Model) hold the underlying flatbuffer T objects as backing storage rather than copying fields into dataclasses. This ensures all schema fields are preserved during read-modify-write cycles, even fields not explicitly handled by model_editor. Future schema additions will be preserved automatically. Add comprehensive test coverage including field preservation tests that verify unhandled schema fields survive read-modify-write. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 21 + .../lite/micro/compression/model_editor.py | 826 ++++++++++ .../micro/compression/model_editor_test.py | 1409 +++++++++++++++++ 3 files changed, 2256 insertions(+) create mode 100644 tensorflow/lite/micro/compression/model_editor.py create mode 100644 tensorflow/lite/micro/compression/model_editor_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 408fc33912e..23e4fe93ca5 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -247,6 +247,27 @@ py_test( ], ) +tflm_py_library( + name = "model_editor", + srcs = ["model_editor.py"], + deps = [ + "//tensorflow/lite/python:schema_py", + requirement("flatbuffers"), + requirement("numpy"), + ], +) + +tflm_py_test( + name = "model_editor_test", + size = "small", + srcs = ["model_editor_test.py"], + deps = [ + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/model_editor.py b/tensorflow/lite/micro/compression/model_editor.py new file mode 100644 index 00000000000..b42b66f81cc --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor.py @@ -0,0 +1,826 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unified TFLite model manipulation module. + +Provides a clean API for creating, reading, and modifying TFLite models. +""" + +from dataclasses import dataclass, field +from typing import Optional, Union, List +import numpy as np +import flatbuffers +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class _BufferList(list): + """Custom list that auto-sets buffer.index on append. + + When a buffer is appended, automatically sets buffer.index to its position. + This enables append-only workflows to work seamlessly. + """ + + def append(self, buf): + """Append buffer and auto-set its index.""" + buf.index = len(self) + super().append(buf) + + +@dataclass +class Buffer: + """Buffer holding tensor data. + + The index field indicates the buffer's position in the model's buffer array. + It is automatically populated during: + - read(): Set from flatbuffer + - build(): Set during compilation + - model.buffers.append(): Auto-set to len(model.buffers) - 1 + + The index may become stale after: + - Deleting buffers from model.buffers + - Reordering buffers in model.buffers + + For append-only workflows (the common case), buffer.index can be trusted. + """ + data: bytes + index: Optional[int] = None + + def __len__(self): + return len(self.data) + + def __bytes__(self): + return self.data + + +@dataclass +class Quantization: + """Quantization parameters helper.""" + scales: Union[float, List[float]] + zero_points: Union[int, List[int]] = 0 + axis: Optional[int] = None + + def to_tflite(self) -> tflite.QuantizationParametersT: + """Convert to TFLite schema object.""" + q = tflite.QuantizationParametersT() + + # Normalize to lists + scales = [self.scales] if isinstance(self.scales, + (int, float)) else self.scales + zeros = [self.zero_points] if isinstance(self.zero_points, + int) else self.zero_points + + q.scale = scales + q.zeroPoint = zeros + if self.axis is not None: + q.quantizedDimension = self.axis + + return q + + +class Tensor: + """Tensor specification wrapping a TensorT flatbuffer object. + + Provides clean APIs for common fields (shape, dtype, name, buffer, + quantization) while preserving all other TensorT fields during + read-modify-write. + + Supports both buffer= and data= parameters for flexibility: + - buffer=: Explicitly provide a Buffer object (can be shared between tensors) + - data=: Convenience parameter that auto-creates a Buffer + + Cannot specify both buffer and data at initialization. + """ + + def __init__(self, + shape=None, + dtype=None, + buffer=None, + data=None, + quantization=None, + name=None, + _fb: tflite.TensorT = None): + """Initialize Tensor. + + Args: + shape: Tensor shape as tuple + dtype: TensorType enum value + buffer: Optional Buffer object (for explicit buffer sharing) + data: Optional numpy array or bytes (convenience, creates Buffer) + quantization: Optional Quantization object + name: Optional tensor name + _fb: Optional TensorT for wrapping existing flatbuffer object + + Raises: + ValueError: If both buffer and data are specified + """ + if data is not None and buffer is not None: + raise ValueError("Cannot specify both data and buffer") + + # Use provided TensorT or create new one + self._fb = _fb if _fb is not None else tflite.TensorT() + self._index = None + + # Buffer object (managed separately; _fb.buffer is just an index) + self.buffer = buffer + + # Quantization object (managed separately; synced to _fb on compile) + self.quantization = quantization + + # Set fields if provided (these override any values in _fb) + if shape is not None: + self.shape = shape + if dtype is not None: + self.dtype = dtype + if name is not None: + self.name = name + + # Convert data to buffer if provided + if data is not None: + buf_data = data if isinstance(data, bytes) else data.tobytes() + self.buffer = Buffer(data=buf_data) + + @property + def shape(self) -> tuple: + """Tensor shape as tuple.""" + return tuple(self._fb.shape) if self._fb.shape is not None else () + + @shape.setter + def shape(self, value): + self._fb.shape = list(value) + + @property + def dtype(self) -> tflite.TensorType: + """Tensor data type.""" + return self._fb.type + + @dtype.setter + def dtype(self, value: tflite.TensorType): + self._fb.type = value + + @property + def name(self) -> Optional[str]: + """Tensor name for debugging.""" + n = self._fb.name + if isinstance(n, bytes): + return n.decode('utf-8') + return n + + @name.setter + def name(self, value: Optional[str]): + self._fb.name = value + + @property + def array(self) -> Optional[np.ndarray]: + """Get tensor data as properly-shaped numpy array. + + Returns: + numpy array with shape matching tensor.shape and dtype matching + tensor.dtype, or None if tensor has no data. + + For low-level byte access, use tensor.buffer.data instead. + """ + if self.buffer is None: + return None + return np.frombuffer(self.buffer.data, + dtype=_dtype_to_numpy(self.dtype)).reshape(self.shape) + + @array.setter + def array(self, value: np.ndarray): + """Set tensor data from numpy array. + + Args: + value: New tensor data as numpy array. Will be converted to bytes + using tobytes() and stored in the buffer. + + Creates a new Buffer if tensor has no buffer, or updates the existing + buffer's data in place. + + For low-level byte access, use tensor.buffer.data instead. + """ + buf_data = value.tobytes() + if self.buffer is None: + self.buffer = Buffer(data=buf_data) + else: + self.buffer.data = buf_data + + @property + def index(self) -> Optional[int]: + """Tensor index in the subgraph's tensor list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + @property + def numpy_dtype(self) -> np.dtype: + """Get numpy dtype corresponding to tensor's TFLite dtype. + + Returns: + numpy dtype object for use with np.frombuffer, np.array, etc. + """ + return _dtype_to_numpy(self.dtype) + + +class OperatorCode: + """Operator code specification wrapping an OperatorCodeT flatbuffer object. + + Provides clean APIs for common fields (builtin_code, custom_code, version) + while preserving all other OperatorCodeT fields during read-modify-write. + """ + + def __init__(self, + builtin_code: tflite.BuiltinOperator = None, + custom_code: Optional[str] = None, + version: int = 1, + _fb: tflite.OperatorCodeT = None): + """Initialize OperatorCode. + + Args: + builtin_code: BuiltinOperator enum value + custom_code: Custom operator name (for CUSTOM opcode) + version: Operator version + _fb: Optional OperatorCodeT for wrapping existing flatbuffer object + """ + # Use provided OperatorCodeT or create new one + self._fb = _fb if _fb is not None else tflite.OperatorCodeT() + + # Set fields if provided (these override any values in _fb) + if builtin_code is not None: + self.builtin_code = builtin_code + if custom_code is not None: + self.custom_code = custom_code + if version != 1 or _fb is None: + self.version = version + + @property + def builtin_code(self) -> tflite.BuiltinOperator: + """Builtin operator code.""" + return self._fb.builtinCode + + @builtin_code.setter + def builtin_code(self, value: tflite.BuiltinOperator): + self._fb.builtinCode = value + + @property + def custom_code(self) -> Optional[str]: + """Custom operator name (for CUSTOM opcode).""" + c = self._fb.customCode + if isinstance(c, bytes): + return c.decode('utf-8') + return c + + @custom_code.setter + def custom_code(self, value: Optional[str]): + self._fb.customCode = value + + @property + def version(self) -> int: + """Operator version.""" + return self._fb.version if self._fb.version else 1 + + @version.setter + def version(self, value: int): + self._fb.version = value + + +class Operator: + """Operator specification wrapping an OperatorT flatbuffer object. + + Provides clean APIs for common fields (opcode, inputs, outputs, custom_code) + while preserving all other OperatorT fields (builtin_options, custom_options, + intermediates, mutating_variable_inputs, etc.) during read-modify-write. + """ + + def __init__(self, + opcode: Union[tflite.BuiltinOperator, int] = None, + inputs: List[Tensor] = None, + outputs: List[Tensor] = None, + custom_code: Optional[str] = None, + opcode_index: Optional[int] = None, + _fb: tflite.OperatorT = None): + """Initialize Operator. + + Args: + opcode: BuiltinOperator enum value or CUSTOM + inputs: List of input Tensor objects + outputs: List of output Tensor objects + custom_code: Custom operator name (for CUSTOM opcode) + opcode_index: Index into operator_codes (set during read) + _fb: Optional OperatorT for wrapping existing flatbuffer object + """ + # Use provided OperatorT or create new one + self._fb = _fb if _fb is not None else tflite.OperatorT() + self._index = None + + # Tensor lists (managed separately; _fb stores indices, not objects) + self.inputs = inputs if inputs is not None else [] + self.outputs = outputs if outputs is not None else [] + + # These are derived from OperatorCode, not stored in OperatorT directly + self._opcode = opcode + self._custom_code = custom_code + self._opcode_index = opcode_index + + @property + def opcode(self) -> Union[tflite.BuiltinOperator, int]: + """Builtin operator code.""" + return self._opcode + + @opcode.setter + def opcode(self, value: Union[tflite.BuiltinOperator, int]): + self._opcode = value + + @property + def custom_code(self) -> Optional[str]: + """Custom operator name (for CUSTOM opcode).""" + return self._custom_code + + @custom_code.setter + def custom_code(self, value: Optional[str]): + self._custom_code = value + + @property + def opcode_index(self) -> Optional[int]: + """Index into operator_codes array (from read or after build).""" + return self._opcode_index + + @opcode_index.setter + def opcode_index(self, value: Optional[int]): + self._opcode_index = value + + @property + def index(self) -> Optional[int]: + """Operator index in the subgraph's operator list.""" + return self._index + + +class Subgraph: + """Subgraph specification wrapping a SubGraphT flatbuffer object. + + Provides clean APIs for common fields (tensors, operators, inputs, outputs, + name) while preserving all other SubGraphT fields during read-modify-write. + """ + + def __init__(self, + tensors: List[Tensor] = None, + operators: List[Operator] = None, + inputs: List[Tensor] = None, + outputs: List[Tensor] = None, + name: Optional[str] = None, + _fb: tflite.SubGraphT = None): + """Initialize Subgraph. + + Args: + tensors: List of Tensor objects + operators: List of Operator objects + inputs: List of input Tensor objects + outputs: List of output Tensor objects + name: Subgraph name for debugging + _fb: Optional SubGraphT for wrapping existing flatbuffer object + """ + # Use provided SubGraphT or create new one + self._fb = _fb if _fb is not None else tflite.SubGraphT() + self._index = None + + # Lists of objects (managed separately; _fb stores indices/arrays) + self.tensors = tensors if tensors is not None else [] + self.operators = operators if operators is not None else [] + self.inputs = inputs if inputs is not None else [] + self.outputs = outputs if outputs is not None else [] + + # Set name if provided (overrides _fb value) + if name is not None: + self.name = name + + @property + def name(self) -> Optional[str]: + """Subgraph name for debugging.""" + n = self._fb.name + if isinstance(n, bytes): + return n.decode('utf-8') + return n + + @name.setter + def name(self, value: Optional[str]): + self._fb.name = value + + def add_tensor(self, **kwargs) -> Tensor: + """Add tensor imperatively and return it.""" + t = Tensor(**kwargs) + t._index = len(self.tensors) + self.tensors.append(t) + return t + + def add_operator(self, **kwargs) -> Operator: + """Add operator imperatively and return it.""" + op = Operator(**kwargs) + op._index = len(self.operators) + self.operators.append(op) + return op + + def tensor_by_name(self, name: str) -> Tensor: + """Look up a tensor by name. + + Args: + name: The tensor name to find. + + Returns: + The Tensor with the given name. + + Raises: + KeyError: If no tensor with that name exists. + """ + for t in self.tensors: + if t.name == name: + return t + raise KeyError(f"No tensor named {name!r}") + + @property + def index(self) -> Optional[int]: + """Subgraph index in the model's subgraph list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + +class Model: + """Model specification wrapping a ModelT flatbuffer object. + + Provides clean APIs for common fields (subgraphs, buffers, operator_codes, + metadata, description) while preserving all other ModelT fields during + read-modify-write. + """ + + def __init__(self, + subgraphs: List[Subgraph] = None, + buffers: _BufferList = None, + operator_codes: List[OperatorCode] = None, + metadata: dict = None, + description: Optional[str] = None, + _fb: tflite.ModelT = None): + """Initialize Model. + + Args: + subgraphs: List of Subgraph objects + buffers: BufferList for tensor data + operator_codes: List of OperatorCode objects + metadata: Dict of metadata name -> bytes + description: Model description string + _fb: Optional ModelT for wrapping existing flatbuffer object + """ + # Use provided ModelT or create new one + self._fb = _fb if _fb is not None else tflite.ModelT() + + # Lists of objects (managed separately; _fb stores arrays) + self.subgraphs = subgraphs if subgraphs is not None else [] + self.buffers = buffers if buffers is not None else _BufferList() + self.operator_codes = operator_codes if operator_codes is not None else [] + self.metadata = metadata if metadata is not None else {} + + # Set description if provided (overrides _fb value) + if description is not None: + self.description = description + + @property + def description(self) -> Optional[str]: + """Model description string.""" + d = self._fb.description + if isinstance(d, bytes): + return d.decode('utf-8') + return d + + @description.setter + def description(self, value: Optional[str]): + self._fb.description = value + + def add_subgraph(self, **kwargs) -> Subgraph: + """Add subgraph imperatively and return it.""" + sg = Subgraph(**kwargs) + sg._index = len(self.subgraphs) + self.subgraphs.append(sg) + return sg + + def build(self) -> bytearray: + """Compile to flatbuffer with automatic bookkeeping.""" + compiler = _ModelCompiler(self) + return compiler.compile() + + +def read(buffer: bytes) -> Model: + """Read a TFLite flatbuffer and return a Model object.""" + fb_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) + + # Create Model wrapping the ModelT; all fields preserved in _fb + model = Model(_fb=fb_model) + + # Create all buffers first (so tensors can reference them) + for i, fb_buf in enumerate(fb_model.buffers): + buf_data = bytes(fb_buf.data) if fb_buf.data is not None else b'' + buf = Buffer(data=buf_data, index=i) + model.buffers.append(buf) + + # Read operator codes + for fb_opcode in fb_model.operatorCodes: + # Create OperatorCode wrapping the OperatorCodeT; all fields preserved in _fb + opcode = OperatorCode(_fb=fb_opcode) + model.operator_codes.append(opcode) + + # Read subgraphs + for sg_idx, fb_sg in enumerate(fb_model.subgraphs): + # Create Subgraph wrapping the SubGraphT; all fields preserved in _fb + sg = Subgraph(_fb=fb_sg) + sg._index = sg_idx + + # Read tensors + for tensor_idx, fb_tensor in enumerate(fb_sg.tensors): + # Resolve buffer reference + # Buffer 0 is the empty buffer (TFLite convention), so treat it as None + buf = None if fb_tensor.buffer == 0 else model.buffers[fb_tensor.buffer] + + # Read quantization parameters if present + quant = None + if fb_tensor.quantization: + fb_quant = fb_tensor.quantization + if fb_quant.scale is not None and len(fb_quant.scale) > 0: + scales = list(fb_quant.scale) + # Copy zero_points as-is, don't expand (per review feedback) + zeros = list( + fb_quant.zeroPoint) if fb_quant.zeroPoint is not None else [0] + # Copy axis if: (1) it's non-zero, or (2) there are multiple scales. + # This preserves per-channel quant with 1 channel (axis non-zero, 1 scale) + # while treating default axis=0 with 1 scale as per-tensor (axis=None). + axis = fb_quant.quantizedDimension + if axis == 0 and len(scales) == 1: + axis = None + quant = Quantization(scales=scales, zero_points=zeros, axis=axis) + + # Create Tensor wrapping the TensorT; all fields preserved in _fb + tensor = Tensor(_fb=fb_tensor, buffer=buf, quantization=quant) + tensor._index = tensor_idx + + sg.tensors.append(tensor) + + # Read operators + for fb_op in fb_sg.operators: + # Get operator code info + opcode_obj = model.operator_codes[fb_op.opcodeIndex] + + # Resolve tensor indices to Tensor objects + inputs = [sg.tensors[i] + for i in fb_op.inputs] if fb_op.inputs is not None else [] + outputs = [sg.tensors[i] + for i in fb_op.outputs] if fb_op.outputs is not None else [] + + # Create Operator wrapping the OperatorT; all fields preserved in _fb + op = Operator( + _fb=fb_op, + opcode=opcode_obj.builtin_code, + inputs=inputs, + outputs=outputs, + custom_code=opcode_obj.custom_code, + opcode_index=fb_op.opcodeIndex, + ) + sg.operators.append(op) + + # Read subgraph inputs/outputs + if fb_sg.inputs is not None and len(fb_sg.inputs) > 0: + sg.inputs = [sg.tensors[i] for i in fb_sg.inputs] + if fb_sg.outputs is not None and len(fb_sg.outputs) > 0: + sg.outputs = [sg.tensors[i] for i in fb_sg.outputs] + + model.subgraphs.append(sg) + + # Read metadata + if fb_model.metadata: + for entry in fb_model.metadata: + # Decode metadata name + name = entry.name + if isinstance(name, bytes): + name = name.decode('utf-8') + + # Get metadata value from buffer + buffer = fb_model.buffers[entry.buffer] + value = bytes(buffer.data) if buffer.data is not None else b'' + + model.metadata[name] = value + + return model + + +def _dtype_to_numpy(dtype: tflite.TensorType) -> np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.FLOAT32: np.float32, + } + return type_map.get(dtype, np.uint8) + + +class _ModelCompiler: + """Internal: compiles Model to flatbuffer with automatic bookkeeping.""" + + def __init__(self, model: Model): + self.model = model + self._buffers = [] + self._buffer_map = {} # Map Buffer object id to index + self._operator_codes = {} + + def compile(self) -> bytearray: + """Compile model using backing ModelT, preserving all fields.""" + # Use the backing ModelT directly---this preserves all fields we don't + # explicitly handle (version, signature_defs, etc.) + root = self.model._fb + + # Initialize buffers + # If model.buffers exists (from read()), preserve those buffers + if self.model.buffers: + for buf in self.model.buffers: + fb_buf = tflite.BufferT() + fb_buf.data = list(buf.data) if buf.data else [] + self._buffers.append(fb_buf) + self._buffer_map[id(buf)] = buf.index + else: + # Creating model from scratch: initialize buffer 0 as empty (TFLite convention) + empty_buffer = tflite.BufferT() + empty_buffer.data = [] + self._buffers = [empty_buffer] + # Note: buffer 0 should not be in _buffer_map since tensors without data use it + + # Auto-collect and register operator codes + self._collect_operator_codes() + root.operatorCodes = list(self._operator_codes.values()) + + # Process subgraphs + root.subgraphs = [] + for sg in self.model.subgraphs: + root.subgraphs.append(self._compile_subgraph(sg)) + + # Process buffers + root.buffers = self._buffers + + # Process metadata + root.metadata = self._compile_metadata() + + # Pack and return + builder = flatbuffers.Builder(4 * 2**20) + builder.Finish(root.Pack(builder)) + return builder.Output() + + def _collect_operator_codes(self): + """Scan all operators and build operator code table.""" + # Build lookup from existing OperatorCodes (from read()) to reuse their _fb + existing_opcodes = { + (oc.builtin_code, oc.custom_code): oc + for oc in self.model.operator_codes + } + + for sg in self.model.subgraphs: + for op in sg.operators: + key = (op.opcode, op.custom_code) + if key not in self._operator_codes: + # Reuse existing OperatorCodeT if available (preserves deprecated_builtin_code) + if key in existing_opcodes: + self._operator_codes[key] = existing_opcodes[key]._fb + else: + # Create new OperatorCodeT for newly added operators + opcode = tflite.OperatorCodeT() + opcode.builtinCode = op.opcode + if op.custom_code: + opcode.customCode = op.custom_code + self._operator_codes[key] = opcode + + def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT: + """Compile subgraph using backing SubGraphT, preserving all fields.""" + # Use the backing SubGraphT directly---this preserves all fields we don't + # explicitly handle (debug_metadata_index, etc.) + sg_t = sg._fb + + # Collect all tensors (from tensor list and inline in operators) + all_tensors = list(sg.tensors) + tensor_to_index = {} + for i, t in enumerate(all_tensors): + t._index = i + tensor_to_index[id(t)] = i + + # Extract inline tensors from operators and subgraph inputs/outputs + inline_sources = [op.inputs + op.outputs for op in sg.operators] + inline_sources.append(sg.inputs) + inline_sources.append(sg.outputs) + for source in inline_sources: + for tensor in source: + if id(tensor) not in tensor_to_index: + tensor._index = len(all_tensors) + tensor_to_index[id(tensor)] = tensor._index + all_tensors.append(tensor) + + # Compile all tensors + sg_t.tensors = [] + for tensor in all_tensors: + sg_t.tensors.append(self._compile_tensor(tensor)) + + # Compile operators + sg_t.operators = [] + for op in sg.operators: + sg_t.operators.append(self._compile_operator(op, tensor_to_index)) + + # Set subgraph inputs/outputs + sg_t.inputs = [tensor_to_index[id(t)] for t in sg.inputs] + sg_t.outputs = [tensor_to_index[id(t)] for t in sg.outputs] + + return sg_t + + def _compile_operator(self, op: Operator, + tensor_to_index: dict) -> tflite.OperatorT: + """Compile operator using backing OperatorT, preserving all fields.""" + # Use the backing OperatorT directly---this preserves all fields we don't + # explicitly handle (builtin_options, custom_options, intermediates, etc.) + op_t = op._fb + + # Get opcode index + key = (op.opcode, op.custom_code) + opcode_index = list(self._operator_codes.keys()).index(key) + op_t.opcodeIndex = opcode_index + + # Resolve tensor references to indices + op_t.inputs = [tensor_to_index[id(inp)] for inp in op.inputs] + op_t.outputs = [tensor_to_index[id(outp)] for outp in op.outputs] + + return op_t + + def _compile_tensor(self, tensor: Tensor) -> tflite.TensorT: + """Compile tensor using backing TensorT, preserving all fields.""" + # Use the backing TensorT directly---this preserves all fields we don't + # explicitly handle (is_variable, sparsity, shape_signature, has_rank, etc.) + t = tensor._fb + + # Handle buffer assignment + if tensor.buffer is None: + # No data: use buffer 0 + t.buffer = 0 + else: + # Has buffer: get or create index for it + buf_id = id(tensor.buffer) + if buf_id not in self._buffer_map: + # First time seeing this buffer, add it + fb_buf = tflite.BufferT() + fb_buf.data = list(tensor.buffer.data) + self._buffers.append(fb_buf) + buf_index = len(self._buffers) - 1 + self._buffer_map[buf_id] = buf_index + tensor.buffer.index = buf_index + t.buffer = self._buffer_map[buf_id] + + # Sync quantization: merge our Quantization object into _fb.quantization + if tensor.quantization: + if t.quantization is None: + t.quantization = tflite.QuantizationParametersT() + # Update only the fields we manage; other fields (min, max, details) + # are preserved from the original _fb.quantization + q = tensor.quantization + scales = [q.scales] if isinstance(q.scales, (int, float)) else q.scales + zeros = [q.zero_points] if isinstance(q.zero_points, + int) else q.zero_points + t.quantization.scale = scales + t.quantization.zeroPoint = zeros + if q.axis is not None: + t.quantization.quantizedDimension = q.axis + + return t + + def _compile_metadata(self): + """Compile metadata, creating buffers for metadata values.""" + if not self.model.metadata: + return [] + + metadata_entries = [] + for name, value in self.model.metadata.items(): + # Create buffer for metadata value + buf = tflite.BufferT() + buf.data = list(value) if isinstance(value, bytes) else list(value) + self._buffers.append(buf) + buf_index = len(self._buffers) - 1 + + # Create metadata entry + entry = tflite.MetadataT() + entry.name = name + entry.buffer = buf_index + metadata_entries.append(entry) + + return metadata_entries diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py new file mode 100644 index 00000000000..1f036b1a1e6 --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -0,0 +1,1409 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for model_editor module. +""" + +import numpy as np +import unittest +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression.model_editor import ( + Buffer, Model, Operator, OperatorCode, Quantization, Subgraph, Tensor) + + +class TestBasicModel(unittest.TestCase): + """Test basic model with tensors and operators.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5]], dtype=np.int8) + cls.weights_data = np.array([[1], [2], [3], [4], [5]], dtype=np.int8) + + cls.model = Model( + description="Test model", + subgraphs=[ + Subgraph(operators=[ + Operator(opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 5), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(5, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + # Build the model to a flatbuffer byte array. This exercises the + # model_editor's build path, which converts the high-level Model API + # representation into the binary TFLite format. + fb = cls.model.build() + + # Read the flatbuffer back through model_editor.read() to create a + # loopback model. This exercises the read path, which parses the + # flatbuffer and reconstructs a high-level Model representation. The + # loopback model should be semantically equivalent to cls.model, + # demonstrating that build() and read() are inverse operations. + cls.loopback_model = model_editor.read(fb) + + # Parse the same flatbuffer using the low-level TFLite schema interface + # (ModelT from schema_py_generated). This provides direct access to the + # raw flatbuffer structure, allowing us to verify that model_editor + # encodes data correctly at the binary level. We compare fb_model + # (low-level) against loopback_model (high-level) to ensure both + # representations are consistent. + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_description(self): + """Verify model description is preserved through loopback.""" + self.assertEqual(self.fb_model.description, b"Test model") + self.assertEqual(self.loopback_model.description, "Test model") + + def test_counts(self): + """Verify subgraph, tensor, and operator counts.""" + self.assertEqual(len(self.fb_model.subgraphs), 1) + self.assertEqual(len(self.loopback_model.subgraphs), 1) + + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.tensors), 3) + self.assertEqual(len(loopback_sg.tensors), 3) + + self.assertEqual(len(fb_sg.operators), 1) + self.assertEqual(len(loopback_sg.operators), 1) + + def test_tensor_names(self): + """Verify tensor names are preserved.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Check that all expected tensor names are present + fb_names = {t.name for t in fb_sg.tensors} + self.assertEqual(fb_names, {b"input", b"weights", b"output"}) + + loopback_names = {t.name for t in loopback_sg.tensors} + self.assertEqual(loopback_names, {"input", "weights", "output"}) + + def test_tensor_properties(self): + """Verify tensor shapes and dtypes.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor + input_fb = next(t for t in fb_sg.tensors if t.name == b"input") + input_loopback = next(t for t in loopback_sg.tensors if t.name == "input") + self.assertEqual(list(input_fb.shape), [1, 5]) + self.assertEqual(input_loopback.shape, (1, 5)) + self.assertEqual(input_fb.type, tflite.TensorType.INT8) + self.assertEqual(input_loopback.dtype, tflite.TensorType.INT8) + + # Weights tensor + weights_fb = next(t for t in fb_sg.tensors if t.name == b"weights") + weights_loopback = next(t for t in loopback_sg.tensors + if t.name == "weights") + self.assertEqual(list(weights_fb.shape), [5, 1]) + self.assertEqual(weights_loopback.shape, (5, 1)) + self.assertEqual(weights_fb.type, tflite.TensorType.INT8) + self.assertEqual(weights_loopback.dtype, tflite.TensorType.INT8) + + # Output tensor + output_fb = next(t for t in fb_sg.tensors if t.name == b"output") + output_loopback = next(t for t in loopback_sg.tensors + if t.name == "output") + self.assertEqual(list(output_fb.shape), [1, 1]) + self.assertEqual(output_loopback.shape, (1, 1)) + self.assertEqual(output_fb.type, tflite.TensorType.INT8) + self.assertEqual(output_loopback.dtype, tflite.TensorType.INT8) + + def test_tensor_data(self): + """Verify tensor data and buffer access.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor data + input_buffer = self.fb_model.buffers[fb_sg.tensors[0].buffer] + self.assertIsNotNone(input_buffer.data) + self.assertEqual(bytes(input_buffer.data), self.input_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[0].array) + np.testing.assert_array_equal(loopback_sg.tensors[0].array, + self.input_data) + + # Weights tensor data + weights_buffer = self.fb_model.buffers[fb_sg.tensors[1].buffer] + self.assertIsNotNone(weights_buffer.data) + self.assertEqual(bytes(weights_buffer.data), self.weights_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[1].array) + np.testing.assert_array_equal(loopback_sg.tensors[1].array, + self.weights_data) + + # Output tensor has no data + self.assertEqual(fb_sg.tensors[2].buffer, 0) + self.assertIsNone(loopback_sg.tensors[2].array) + + def test_buffer_allocation(self): + """Verify buffer allocation and zero convention.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Exact buffer count: buffer 0 (empty) + input + weights = 3 total + self.assertEqual(len(self.fb_model.buffers), 3) + self.assertEqual(len(self.loopback_model.buffers), 3) + + # Buffer 0 is empty + buffer_zero = self.fb_model.buffers[0] + self.assertTrue(buffer_zero.data is None or len(buffer_zero.data) == 0) + + # Verify each buffer is referenced by exactly the expected tensor + # Buffer 0 -> output tensor (no data) + output_tensor = next(t for t in fb_sg.tensors if t.name == b"output") + self.assertEqual(output_tensor.buffer, 0) + + # Buffer 1 and 2 -> input and weights (order may vary) + input_tensor = next(t for t in fb_sg.tensors if t.name == b"input") + weights_tensor = next(t for t in fb_sg.tensors if t.name == b"weights") + self.assertNotEqual(input_tensor.buffer, 0) + self.assertNotEqual(weights_tensor.buffer, 0) + self.assertIn(input_tensor.buffer, [1, 2]) + self.assertIn(weights_tensor.buffer, [1, 2]) + + # Tensors with data point to non-zero buffers in loopback model + loopback_input_tensor = next(t for t in loopback_sg.tensors + if t.name == "input") + self.assertIsNotNone(loopback_input_tensor.buffer) + self.assertIsNotNone(loopback_input_tensor.buffer.index) + self.assertNotEqual(loopback_input_tensor.buffer.index, 0) + self.assertEqual(len(loopback_input_tensor.buffer.data), 5) + self.assertEqual(bytes(loopback_input_tensor.buffer.data), + self.input_data.tobytes()) + + def test_operator_references(self): + """Verify operators reference correct tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Operator input/output references + self.assertEqual(len(fb_sg.operators[0].inputs), 2) + self.assertEqual([t.name for t in loopback_sg.operators[0].inputs], + ["input", "weights"]) + + self.assertEqual(len(fb_sg.operators[0].outputs), 1) + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["output"]) + + # Operator indices are in bounds + num_tensors = len(fb_sg.tensors) + for idx in list(fb_sg.operators[0].inputs) + list( + fb_sg.operators[0].outputs): + self.assertGreaterEqual(idx, 0) + self.assertLess(idx, num_tensors) + + def test_operator_codes(self): + """Verify operator code table is correctly populated.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertIsNotNone(self.fb_model.operatorCodes) + self.assertEqual(len(self.fb_model.operatorCodes), 1) + self.assertEqual(self.fb_model.operatorCodes[0].builtinCode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + self.assertEqual(len(self.loopback_model.operator_codes), 1) + self.assertIsNotNone(loopback_sg.operators[0].opcode_index) + loopback_opcode = self.loopback_model.operator_codes[ + loopback_sg.operators[0].opcode_index] + self.assertEqual(loopback_opcode.builtin_code, + tflite.BuiltinOperator.FULLY_CONNECTED) + + +class TestAdvancedModel(unittest.TestCase): + """Test multiple operators, custom ops, shared tensors, and mixed references.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=np.int8) + cls.weights_data = np.array( + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=np.int8) + cls.bias_data = np.array([10], dtype=np.int8) + # Int16 data to test endianness: values that will show byte order issues + cls.int16_data = np.array([256, 512, 1024], + dtype=np.int16) # 0x0100, 0x0200, 0x0400 + + # Pre-declare shared tensor (output of FC, input to custom op) + cls.hidden = Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="hidden") + + # Create explicit shared buffer to test buffer sharing between tensors + cls.shared_buffer_data = np.array([100, 127], dtype=np.int8) + cls.shared_buf = Buffer(data=cls.shared_buffer_data.tobytes()) + + cls.model = Model( + description="Advanced model", + metadata={ + "version": b"1.0.0", + "author": b"test_suite", + "custom_data": bytes([0xDE, 0xAD, 0xBE, 0xEF]) + }, + subgraphs=[ + Subgraph( + tensors=[ + cls.hidden, # Mixed: pre-declared shared tensor + # Int16 tensor to test endianness + Tensor(shape=(3, ), + dtype=tflite.TensorType.INT16, + data=cls.int16_data, + name="int16_tensor"), + # Two tensors sharing same buffer to test buffer deduplication + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor1"), + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor2") + ], + operators=[ + # Multiple operators: FULLY_CONNECTED + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(10, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[cls.hidden + ] # Shared: reference to pre-declared + ), + # Custom operator + Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code="MyCustomOp", + inputs=[cls.hidden], # Shared: reuse hidden tensor + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed") + ]), + # Multiple operators: ADD + Operator( + opcode=tflite.BuiltinOperator.ADD, + inputs=[ + Tensor( + shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed_ref" # Mixed: inline tensor + ), + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + data=cls.bias_data, + name="bias") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_operator_counts(self): + """Verify correct number of operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.operators), 3) + self.assertEqual(len(loopback_sg.operators), 3) + + def test_operator_code_table(self): + """Verify operator code table contains all operator types.""" + self.assertEqual(len(self.fb_model.operatorCodes), 3) + self.assertEqual(len(self.loopback_model.operator_codes), 3) + + opcodes_fb = {op.builtinCode for op in self.fb_model.operatorCodes} + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_fb) + + opcodes_loopback = { + op.builtin_code + for op in self.loopback_model.operator_codes + } + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_loopback) + + def test_custom_operator(self): + """Verify custom operator code preservation.""" + loopback_sg = self.loopback_model.subgraphs[0] + + # Custom code in operator code table + custom_opcode_fb = next(op for op in self.fb_model.operatorCodes + if op.builtinCode == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_fb.customCode, b"MyCustomOp") + + custom_opcode_loopback = next( + op for op in self.loopback_model.operator_codes + if op.builtin_code == tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_opcode_loopback.custom_code, "MyCustomOp") + + # Custom operator references custom code + custom_op_loopback = loopback_sg.operators[1] + self.assertEqual(custom_op_loopback.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(custom_op_loopback.custom_code, "MyCustomOp") + + def test_shared_tensor_references(self): + """Verify tensors shared between operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Hidden tensor is at index 0 (pre-declared) + self.assertEqual(fb_sg.tensors[0].name, b"hidden") + self.assertEqual(loopback_sg.tensors[0].name, "hidden") + + # FC operator outputs to hidden + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["hidden"]) + + # Custom operator inputs from hidden + self.assertEqual([t.name for t in loopback_sg.operators[1].inputs], + ["hidden"]) + + # Same Tensor object is referenced by both operators + fc_output = loopback_sg.operators[0].outputs[0] + custom_input = loopback_sg.operators[1].inputs[0] + self.assertIs(fc_output, custom_input) + + def test_mixed_tensor_references(self): + """Verify mix of pre-declared and inline tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Total: hidden, int16_tensor, shared_buf_tensor1, shared_buf_tensor2 (pre-declared) + # + input, weights, processed, processed_ref, bias, output (inline from operators) + self.assertEqual(len(fb_sg.tensors), 10) + self.assertEqual(len(loopback_sg.tensors), 10) + + def test_int16_endianness(self): + """Verify int16 data is stored in little-endian byte order.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Find int16 tensor by name + int16_tensor_fb = next(t for t in fb_sg.tensors + if t.name == b"int16_tensor") + int16_tensor_loopback = next(t for t in loopback_sg.tensors + if t.name == "int16_tensor") + + # Verify dtype + self.assertEqual(int16_tensor_fb.type, tflite.TensorType.INT16) + self.assertEqual(int16_tensor_loopback.dtype, tflite.TensorType.INT16) + + # Check flatbuffer buffer has correct little-endian bytes + # For [256, 512, 1024] = [0x0100, 0x0200, 0x0400] + # Little-endian bytes: [0x00, 0x01, 0x00, 0x02, 0x00, 0x04] + int16_buffer_fb = self.fb_model.buffers[int16_tensor_fb.buffer] + self.assertIsNotNone(int16_buffer_fb.data) + expected_bytes = self.int16_data.astype(np.int16).astype('buffer mapping from flatbuffer + metadata_map_fb = {} + for entry in self.fb_model.metadata: + buffer_idx = entry.buffer + self.assertLess(buffer_idx, len(self.fb_model.buffers)) + buffer = self.fb_model.buffers[buffer_idx] + if buffer.data is not None: + metadata_map_fb[entry.name] = bytes(buffer.data) + + # Verify flatbuffer metadata values + self.assertEqual(metadata_map_fb[b"version"], b"1.0.0") + self.assertEqual(metadata_map_fb[b"author"], b"test_suite") + self.assertEqual(metadata_map_fb[b"custom_data"], + bytes([0xDE, 0xAD, 0xBE, 0xEF])) + + # Check loopback model metadata + self.assertIsNotNone(self.loopback_model.metadata) + self.assertEqual(len(self.loopback_model.metadata), 3) + + # Verify loopback metadata values (decoded from bytes) + self.assertEqual(self.loopback_model.metadata["version"], b"1.0.0") + self.assertEqual(self.loopback_model.metadata["author"], b"test_suite") + self.assertEqual(self.loopback_model.metadata["custom_data"], + bytes([0xDE, 0xAD, 0xBE, 0xEF])) + + def test_buffer_allocation(self): + """Verify no orphaned buffers and shared buffer deduplication.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Collect all buffer references (from tensors and metadata) + referenced_buffers = {0} # Buffer 0 is special (always referenced) + + # Collect buffer references from tensors + for tensor in fb_sg.tensors: + referenced_buffers.add(tensor.buffer) + + # Collect buffer references from metadata + for entry in self.fb_model.metadata: + referenced_buffers.add(entry.buffer) + + # Verify no orphaned buffers (all buffers are referenced) + for i in range(len(self.fb_model.buffers)): + self.assertIn( + i, referenced_buffers, + f"Buffer {i} is orphaned (not referenced by any tensor or metadata)") + + # Verify shared buffer deduplication: two tensors share one buffer + tensor1_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor1") + tensor2_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor2") + + # Both tensors should point to the same buffer index + self.assertEqual(tensor1_fb.buffer, tensor2_fb.buffer) + self.assertNotEqual(tensor1_fb.buffer, 0) + + # Verify loopback preserves shared buffer (same Buffer object) + tensor1_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor1") + tensor2_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor2") + + self.assertIs(tensor1_loopback.buffer, tensor2_loopback.buffer) + self.assertEqual(bytes(tensor1_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + self.assertEqual(bytes(tensor2_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + + +class TestQuantization(unittest.TestCase): + """Test per-tensor and per-channel quantization parameters.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + # Per-channel quantization parameters + cls.per_channel_scales = [0.1, 0.2, 0.3, 0.4] + cls.per_channel_zeros = [0, 1, 2, 3] + + cls.model = Model( + description="Quantization test model", + subgraphs=[ + Subgraph(tensors=[ + # Per-tensor quantized tensor (single scale/zero_point) + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((1, 10), dtype=np.int8), + name="per_tensor", + quantization=Quantization(scales=0.5, zero_points=10)), + # Per-channel quantized tensor (array of scales/zero_points, axis) + Tensor(shape=(4, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 10), dtype=np.int8), + name="per_channel", + quantization=Quantization( + scales=cls.per_channel_scales, + zero_points=cls.per_channel_zeros, + axis=0)) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_per_tensor_quantization_flatbuffer(self): + """Verify per-tensor quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Scale and zero_point encoded as single-element arrays + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 1) + self.assertEqual(tensor.quantization.scale[0], 0.5) + + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 1) + self.assertEqual(tensor.quantization.zeroPoint[0], 10) + + def test_per_tensor_quantization_loopback(self): + """Verify per-tensor quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, [0.5]) + self.assertEqual(tensor.quantization.zero_points, [10]) + self.assertIsNone(tensor.quantization.axis) + + def test_per_channel_quantization_flatbuffer(self): + """Verify per-channel quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_channel") + self.assertIsNotNone(tensor.quantization) + + # All scales encoded + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 4) + self.assertEqual(list(tensor.quantization.scale), self.per_channel_scales) + + # All zero_points encoded + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 4) + self.assertEqual(list(tensor.quantization.zeroPoint), + self.per_channel_zeros) + + # Axis encoded as quantizedDimension + self.assertEqual(tensor.quantization.quantizedDimension, 0) + + def test_per_channel_quantization_loopback(self): + """Verify per-channel quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_channel") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, self.per_channel_scales) + self.assertEqual(tensor.quantization.zero_points, self.per_channel_zeros) + self.assertEqual(tensor.quantization.axis, 0) + + +class TestReadModifyWrite(unittest.TestCase): + """Test read-modify-write workflows.""" + + @classmethod + def setUpClass(cls): + """Create a simple base model for modification tests.""" + cls.original_data = np.array([[1, 2, 3]], dtype=np.int8) + cls.model = Model( + description="Base model", + metadata={"original": b"metadata"}, + subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=cls.original_data, + name="weights"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="input"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="output") + ]) + ]) + + cls.fb = cls.model.build() + + def test_modify_tensor_data(self): + """Read model, modify tensor data, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify tensor data using array setter (high-level API) + weights_tensor = next(t for t in model2.subgraphs[0].tensors + if t.name == "weights") + new_data = np.array([[10, 20, 30]], dtype=np.int8) + weights_tensor.array = new_data # Uses array setter + + # Build modified model + fb2 = model2.build() + + # Read back and verify modification + model3 = model_editor.read(fb2) + modified_weights = next(t for t in model3.subgraphs[0].tensors + if t.name == "weights") + np.testing.assert_array_equal(modified_weights.array, new_data) + + # Verify other tensors unchanged + self.assertEqual(len(model3.subgraphs[0].tensors), 3) + + def test_add_tensor_and_operator(self): + """Read model, add new tensor and operator, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + sg = model2.subgraphs[0] + + # Get existing tensors + input_tensor = next(t for t in sg.tensors if t.name == "input") + output_tensor = next(t for t in sg.tensors if t.name == "output") + + # Add new tensor using imperative API + new_weights = np.array([[5, 10, 15]], dtype=np.int8) + new_weights_tensor = sg.add_tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=new_weights, + name="new_weights") + + # Add new operator using imperative API + sg.add_operator(opcode=tflite.BuiltinOperator.ADD, + inputs=[input_tensor, new_weights_tensor], + outputs=[output_tensor]) + + # Build modified model + fb2 = model2.build() + + # Read back and verify additions + model3 = model_editor.read(fb2) + sg3 = model3.subgraphs[0] + + # Verify tensor was added + self.assertEqual(len(sg3.tensors), 4) + added_tensor = next(t for t in sg3.tensors if t.name == "new_weights") + self.assertIsNotNone(added_tensor) + np.testing.assert_array_equal(added_tensor.array, new_weights) + + # Verify operator was added + self.assertEqual(len(sg3.operators), 1) + added_op = sg3.operators[0] + self.assertEqual([t.name for t in added_op.inputs], + ["input", "new_weights"]) + self.assertEqual([t.name for t in added_op.outputs], ["output"]) + + def test_modify_metadata(self): + """Read model, modify metadata, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify existing metadata + model2.metadata["original"] = b"modified_metadata" + + # Add new metadata + model2.metadata["new_key"] = b"new_value" + + # Build modified model + fb2 = model2.build() + + # Read back and verify modifications + model3 = model_editor.read(fb2) + + self.assertEqual(len(model3.metadata), 2) + self.assertEqual(model3.metadata["original"], b"modified_metadata") + self.assertEqual(model3.metadata["new_key"], b"new_value") + + +class TestSubgraphInputsOutputs(unittest.TestCase): + """Test subgraph inputs and outputs are set correctly.""" + + def test_subgraph_inputs_outputs_set(self): + """Verify subgraph inputs/outputs are set in the flatbuffer.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + fb_sg = fb_model.subgraphs[0] + + # Verify inputs/outputs are set (as tensor indices) + self.assertEqual(len(fb_sg.inputs), 1) + self.assertEqual(len(fb_sg.outputs), 1) + + # Verify indices point to correct tensors + input_idx = fb_sg.inputs[0] + output_idx = fb_sg.outputs[0] + self.assertEqual(fb_sg.tensors[input_idx].name, b"input") + self.assertEqual(fb_sg.tensors[output_idx].name, b"output") + + def test_subgraph_inputs_outputs_loopback(self): + """Verify inputs/outputs survive read/build loopback.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + loopback = model_editor.read(fb) + sg = loopback.subgraphs[0] + + # Verify high-level inputs/outputs are populated + self.assertEqual(len(sg.inputs), 1) + self.assertEqual(len(sg.outputs), 1) + self.assertEqual(sg.inputs[0].name, "input") + self.assertEqual(sg.outputs[0].name, "output") + + def test_tensor_by_name_not_found_raises(self): + """tensor_by_name raises KeyError when name not found.""" + model = Model(subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(4, ), dtype=tflite.TensorType.INT8, name="exists") + ]) + ]) + + with self.assertRaises(KeyError): + model.subgraphs[0].tensor_by_name("nonexistent") + + +class TestReadEdgeCases(unittest.TestCase): + """Test model_editor.read() with edge cases from real-world models. + + These tests construct models using the low-level TFLite schema to create + edge cases that may not be producible via model_editor.build(), but can + appear in models from other sources (e.g., TFLite converter). + """ + + def _build_model_with_schema(self, model_t): + """Build a flatbuffer from a ModelT using the low-level schema.""" + import flatbuffers + builder = flatbuffers.Builder(1024) + builder.Finish(model_t.Pack(builder)) + return bytes(builder.Output()) + + def test_read_scalar_tensor(self): + """Verify read() handles tensors with None shape (scalars). + + Some TFLite models have scalar tensors where shape is None rather than + an empty list. This can occur with constant scalars produced by certain + converters. + """ + # Build a minimal model with a scalar tensor (shape=None) + model_t = tflite.ModelT() + model_t.version = 3 + + # Buffer 0 is always empty, buffer 1 holds scalar data + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + buf1.data = [42] # Single byte scalar value + + model_t.buffers = [buf0, buf1] + + # Create operator code + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + # Create subgraph with scalar tensor + sg = tflite.SubGraphT() + + # Tensor with shape=None (scalar) + scalar_tensor = tflite.TensorT() + scalar_tensor.name = b"scalar" + scalar_tensor.type = tflite.TensorType.INT8 + scalar_tensor.buffer = 1 + scalar_tensor.shape = None # This is the edge case + + # Normal tensor for comparison + normal_tensor = tflite.TensorT() + normal_tensor.name = b"normal" + normal_tensor.type = tflite.TensorType.INT8 + normal_tensor.buffer = 0 + normal_tensor.shape = [1, 4] + + sg.tensors = [scalar_tensor, normal_tensor] + sg.inputs = [1] + sg.outputs = [1] + sg.operators = [] + + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify scalar tensor was read with empty shape tuple + self.assertEqual(model.subgraphs[0].tensors[0].shape, ()) + self.assertEqual(model.subgraphs[0].tensors[0].name, "scalar") + + # Verify normal tensor shape is preserved + self.assertEqual(model.subgraphs[0].tensors[1].shape, (1, 4)) + + def test_read_operator_with_empty_inputs(self): + """Verify read() handles operators with None inputs/outputs. + + Some operators (e.g., certain control flow or custom ops) may have + empty input or output lists represented as None in the flatbuffer. + """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + # Custom op that might have unusual input/output patterns + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoInputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + # Single output tensor + output_tensor = tflite.TensorT() + output_tensor.name = b"output" + output_tensor.type = tflite.TensorType.INT8 + output_tensor.buffer = 0 + output_tensor.shape = [1] + + sg.tensors = [output_tensor] + sg.inputs = [] + sg.outputs = [0] + + # Operator with None inputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = None # This is the edge case + op.outputs = [0] + + sg.operators = [op] + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify operator was read with empty inputs list + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(model.subgraphs[0].operators[0].inputs, []) + self.assertEqual(len(model.subgraphs[0].operators[0].outputs), 1) + + def test_read_operator_with_empty_outputs(self): + """Verify read() handles operators with None outputs. + + Similar to empty inputs, some operators may have None outputs. + """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoOutputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + input_tensor = tflite.TensorT() + input_tensor.name = b"input" + input_tensor.type = tflite.TensorType.INT8 + input_tensor.buffer = 0 + input_tensor.shape = [1] + + sg.tensors = [input_tensor] + sg.inputs = [0] + sg.outputs = [] + + # Operator with None outputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = [0] + op.outputs = None # This is the edge case + + sg.operators = [op] + model_t.subgraphs = [sg] + + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(len(model.subgraphs[0].operators[0].inputs), 1) + self.assertEqual(model.subgraphs[0].operators[0].outputs, []) + + def test_int64_tensor(self): + """Verify INT64 tensors are correctly handled.""" + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + # INT64 data: [1, 2, 3, 4] as little-endian 8-byte values + int64_data = np.array([1, 2, 3, 4], dtype=np.int64) + buf1.data = list(int64_data.tobytes()) + + model_t.buffers = [buf0, buf1] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + tensor = tflite.TensorT() + tensor.name = b"int64_tensor" + tensor.type = tflite.TensorType.INT64 + tensor.buffer = 1 + tensor.shape = [4] + + sg.tensors = [tensor] + sg.inputs = [0] + sg.outputs = [0] + sg.operators = [] + model_t.subgraphs = [sg] + + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + t = model.subgraphs[0].tensors[0] + self.assertEqual(t.dtype, tflite.TensorType.INT64) + np.testing.assert_array_equal(t.array, int64_data) + + +class TestFieldPreservation(unittest.TestCase): + """Test that schema fields are preserved during read-modify-write. + + These tests verify that fields not explicitly handled by model_editor + are still preserved when reading a model, modifying it, and writing + it back. This catches regressions where adding wrapper classes might + accidentally drop fields. + """ + + def _build_model_with_schema(self, model_t): + """Build a flatbuffer from a ModelT using the low-level schema.""" + import flatbuffers + builder = flatbuffers.Builder(1024) + builder.Finish(model_t.Pack(builder)) + return bytes(builder.Output()) + + def _create_base_model(self): + """Create a minimal valid model for testing.""" + model_t = tflite.ModelT() + model_t.version = 3 + model_t.description = b"test" + + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + buf1.data = [1, 2, 3, 4] + model_t.buffers = [buf0, buf1] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + t0 = tflite.TensorT() + t0.name = b"input" + t0.type = tflite.TensorType.INT8 + t0.buffer = 1 + t0.shape = [4] + + t1 = tflite.TensorT() + t1.name = b"output" + t1.type = tflite.TensorType.INT8 + t1.buffer = 0 + t1.shape = [4] + + sg.tensors = [t0, t1] + sg.inputs = [0] + sg.outputs = [1] + + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = [0] + op.outputs = [1] + sg.operators = [op] + + model_t.subgraphs = [sg] + return model_t + + def test_tensor_is_variable_preserved(self): + """Verify Tensor.isVariable is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].isVariable = True + + fb = self._build_model_with_schema(model_t) + + # Read, modify, write + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + # Verify field preserved + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertTrue(model_t2.subgraphs[0].tensors[0].isVariable) + + def test_tensor_shape_signature_preserved(self): + """Verify Tensor.shapeSignature is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].shapeSignature = [-1, 4] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].tensors[0].shapeSignature), + [-1, 4]) + + def test_operator_builtin_options_preserved(self): + """Verify Operator.builtinOptions is preserved through read-modify-write.""" + model_t = self._create_base_model() + + # Use ADD operator with AddOptions (must also set builtinOptionsType for union) + add_options = tflite.AddOptionsT() + add_options.fusedActivationFunction = tflite.ActivationFunctionType.RELU + model_t.subgraphs[0].operators[0].builtinOptions = add_options + model_t.subgraphs[0].operators[ + 0].builtinOptionsType = tflite.BuiltinOptions.AddOptions + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertIsNotNone(model_t2.subgraphs[0].operators[0].builtinOptions) + self.assertEqual( + model_t2.subgraphs[0].operators[0].builtinOptions. + fusedActivationFunction, tflite.ActivationFunctionType.RELU) + + def test_operator_custom_options_preserved(self): + """Verify Operator.customOptions is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].customOptions = [0xDE, 0xAD, 0xBE, 0xEF] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].operators[0].customOptions), + [0xDE, 0xAD, 0xBE, 0xEF]) + + def test_operator_intermediates_preserved(self): + """Verify Operator.intermediates is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].intermediates = [0, 1] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(list(model_t2.subgraphs[0].operators[0].intermediates), + [0, 1]) + + def test_operator_debug_metadata_index_preserved(self): + """Verify Operator.debugMetadataIndex is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].operators[0].debugMetadataIndex = 7 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.subgraphs[0].operators[0].debugMetadataIndex, 7) + + def test_operator_code_deprecated_builtin_code_preserved(self): + """Verify OperatorCode.deprecatedBuiltinCode is preserved.""" + model_t = self._create_base_model() + # Set deprecated code to a value different from the new builtin code + model_t.operatorCodes[0].deprecatedBuiltinCode = 42 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.operatorCodes[0].deprecatedBuiltinCode, 42) + + def test_subgraph_debug_metadata_index_preserved(self): + """Verify SubGraph.debugMetadataIndex is preserved.""" + model_t = self._create_base_model() + model_t.subgraphs[0].debugMetadataIndex = 5 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.subgraphs[0].debugMetadataIndex, 5) + + def test_model_version_preserved(self): + """Verify Model.version is preserved (not hardcoded to 3).""" + model_t = self._create_base_model() + model_t.version = 42 + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertEqual(model_t2.version, 42) + + def test_model_signature_defs_preserved(self): + """Verify Model.signatureDefs is preserved.""" + model_t = self._create_base_model() + + sig_def = tflite.SignatureDefT() + sig_def.signatureKey = b"serving_default" + sig_def.subgraphIndex = 0 + model_t.signatureDefs = [sig_def] + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertIsNotNone(model_t2.signatureDefs) + self.assertEqual(len(model_t2.signatureDefs), 1) + self.assertEqual(model_t2.signatureDefs[0].signatureKey, + b"serving_default") + + def test_quantization_min_max_preserved(self): + """Verify QuantizationParameters.min/max are preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5] + quant.zeroPoint = [128] + quant.min = [0.0] + quant.max = [1.0] + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(list(quant2.min), [0.0]) + self.assertEqual(list(quant2.max), [1.0]) + + def test_tensor_sparsity_preserved(self): + """Verify Tensor.sparsity is preserved through read-modify-write.""" + model_t = self._create_base_model() + + sparsity = tflite.SparsityParametersT() + sparsity.traversalOrder = [0, 1] + sparsity.blockMap = [0] + model_t.subgraphs[0].tensors[0].sparsity = sparsity + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + sparsity2 = model_t2.subgraphs[0].tensors[0].sparsity + self.assertIsNotNone(sparsity2) + self.assertEqual(list(sparsity2.traversalOrder), [0, 1]) + self.assertEqual(list(sparsity2.blockMap), [0]) + + def test_tensor_has_rank_preserved(self): + """Verify Tensor.hasRank is preserved through read-modify-write.""" + model_t = self._create_base_model() + model_t.subgraphs[0].tensors[0].hasRank = True + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + self.assertTrue(model_t2.subgraphs[0].tensors[0].hasRank) + + def test_operator_builtin_options_2_preserved(self): + """Verify Operator.builtinOptions2 is preserved through read-modify-write.""" + model_t = self._create_base_model() + + options2 = tflite.StablehloConcatenateOptionsT() + options2.dimension = 42 + model_t.subgraphs[0].operators[0].builtinOptions2 = options2 + model_t.subgraphs[0].operators[0].builtinOptions2Type = ( + tflite.BuiltinOptions2.StablehloConcatenateOptions) + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + options2_out = model_t2.subgraphs[0].operators[0].builtinOptions2 + self.assertIsNotNone(options2_out) + self.assertEqual(options2_out.dimension, 42) + + def test_quantization_axis_preserved(self): + """Verify QuantizationParameters.quantizedDimension is preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5, 0.25] + quant.zeroPoint = [0, 0] + quant.quantizedDimension = 1 + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(quant2.quantizedDimension, 1) + + def test_quantization_zero_point_preserved(self): + """Verify QuantizationParameters.zeroPoint is preserved.""" + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5] + quant.zeroPoint = [128] + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + self.assertEqual(list(quant2.zeroPoint), [128]) + + def test_quantization_zero_point_not_expanded(self): + """Single zeroPoint with multiple scales is preserved as-is. + + TFLite converter optimizes by storing single zeroPoint when all channels + have the same zero point. This must be preserved, not expanded. + """ + model_t = self._create_base_model() + + quant = tflite.QuantizationParametersT() + quant.scale = [0.5, 0.25, 0.125, 0.0625] # 4 scales + quant.zeroPoint = [128] # Single zero point (converter optimization) + quant.quantizedDimension = 0 + model_t.subgraphs[0].tensors[0].quantization = quant + + fb = self._build_model_with_schema(model_t) + + model = model_editor.read(fb) + model.description = "modified" + fb2 = model.build() + + model_t2 = tflite.ModelT.InitFromPackedBuf(fb2, 0) + quant2 = model_t2.subgraphs[0].tensors[0].quantization + self.assertIsNotNone(quant2) + # Should still be single-element, not expanded to 4 + self.assertEqual(len(quant2.zeroPoint), 1) + self.assertEqual(quant2.zeroPoint[0], 128) + + +if __name__ == "__main__": + unittest.main() From 8518b5e544170b63ec144763fad09a360c984bdc Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:09:24 -0500 Subject: [PATCH 02/21] refactor(compression): migrate compress.py from model_facade to model_editor Replace model_facade with model_editor in compress.py and tests. model_editor provides a cleaner API with better buffer and metadata handling. Update BUILD dependencies accordingly. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 4 +-- tensorflow/lite/micro/compression/compress.py | 36 ++++++++++--------- .../lite/micro/compression/compress_test.py | 19 +++++----- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 23e4fe93ca5..3e0141925d6 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -124,7 +124,7 @@ py_library( ], deps = [ ":metadata_py", - ":model_facade", + ":model_editor", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), @@ -160,7 +160,7 @@ py_test( deps = [ ":compress", ":metadata_py", - ":model_facade", + ":model_editor", ":spec", ":test_models", "//tensorflow/lite/python:schema_py", diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index bd67bf5637b..b6d5aef4435 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -29,7 +29,7 @@ import flatbuffers import numpy as np -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper @@ -177,7 +177,7 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression: return compression[0] -def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: +def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: """Determines the axis along which to compress. The axis along which to compress is inferred from the tensor's quantization @@ -191,16 +191,18 @@ def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: CompressionError: If the axis cannot be determined. """ q = tensor.quantization - if q is not None \ - and q.scale is not None \ - and q.quantizedDimension < len(tensor.shape): - quantization_channels = len(q.scale) + if q is not None: + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + if quantization_channels == 1: # Use one value table for the entire tensor return None - if quantization_channels == tensor.shape[q.quantizedDimension]: - return q.quantizedDimension + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis raise CompressionError( f"Invalid or no quanitzation parameters from which to " @@ -300,7 +302,7 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: Returns: A compressed flatbuffer. """ - model = model_facade.read(model_in) + model = model_editor.read(model_in) metadata = _MetadataBuilder() for spec in specs: @@ -316,12 +318,14 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) # write value buffer - value_buffer = model.add_buffer() - value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, + value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, 2**spec_bitwidth) + value_buffer = model_editor.Buffer(data=value_buffer_data) + model.buffers.append(value_buffer) # Auto-sets value_buffer.index + # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) - lut_tensor.tensor = tensor.index + lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) + lut_tensor.tensor = spec.tensor lut_tensor.valueBuffer = value_buffer.index lut_tensor.indexBitwidth = spec_bitwidth @@ -329,10 +333,10 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: raise CompressionError(f"error compressing {spec}") from e # add compression metadata to model - model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) + model.metadata[TFLITE_METADATA_KEY] = metadata.compile() - # Compile the model and apply proper alignment - unaligned_model = model.compile() + # Build the model and apply proper alignment + unaligned_model = model.build() return _apply_flatbuffer_alignment(unaligned_model) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 012957acc90..ee10a75f36d 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -19,7 +19,7 @@ from tflite_micro.tensorflow.lite.micro.compression import compress from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -368,12 +368,12 @@ class TestsCompression(unittest.TestCase): def setUpClass(cls): super().setUpClass() cls.flatbuffer = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(cls.flatbuffer) + cls.uncompressed = model_editor.read(cls.flatbuffer) def test_compression_metadata(self): """The compressed model has compression metadata.""" compressed = compress.compress(self.flatbuffer, TEST_COMPRESSION_SPEC) - model = model_facade.read(compressed) + model = model_editor.read(compressed) self.assertIn("metadata0", self.uncompressed.metadata) self.assertIn(compress.TFLITE_METADATA_KEY, model.metadata) @@ -461,16 +461,17 @@ def setUpClass(cls): super().setUpClass() # Create a model uncompressed_fb = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(uncompressed_fb) + cls.uncompressed = model_editor.read(uncompressed_fb) # Compress the model compressed_fb = compress.compress(uncompressed_fb, TEST_COMPRESSION_SPEC) - cls.compressed = model_facade.read(compressed_fb) + cls.compressed = model_editor.read(compressed_fb) # Extract the compression metadata - metadata_flatbuffer = cls.compressed.metadata[compress.TFLITE_METADATA_KEY] - cls.metadata = schema.MetadataT.InitFromPackedBuf(metadata_flatbuffer.data, - 0) + metadata_flatbuffer_bytes = cls.compressed.metadata[ + compress.TFLITE_METADATA_KEY] + cls.metadata = schema.MetadataT.InitFromPackedBuf( + metadata_flatbuffer_bytes, 0) def test_uncompressed_tensors(self): """Tensors not in compression spec are not compressed. @@ -515,7 +516,7 @@ def _get_compressed( indices = indices[:n_indices * bitwidth] # trim possible padding value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype) + values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) return bitwidth, indices, values From e1afb8c89655317dbe568feb927a53da2b5d3970 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:14:16 -0500 Subject: [PATCH 03/21] chore(compression): remove model_facade.py Remove model_facade module and its tests, now superseded by model_editor. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 20 -- .../lite/micro/compression/model_facade.py | 276 ------------------ .../micro/compression/model_facade_test.py | 144 --------- 3 files changed, 440 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/model_facade.py delete mode 100644 tensorflow/lite/micro/compression/model_facade_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 3e0141925d6..0f12cb15010 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -169,26 +169,6 @@ py_test( ], ) -tflm_py_library( - name = "model_facade", - srcs = ["model_facade.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - ], -) - -py_test( - name = "model_facade_test", - size = "small", - srcs = ["model_facade_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, - deps = [ - ":model_facade", - ":test_models", - ], -) - tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py deleted file mode 100644 index 2e58d8080f1..00000000000 --- a/tensorflow/lite/micro/compression/model_facade.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""A facade for working with tflite.Model. - -This module provides convenient navigation, data type conversions, and -utilities for working with a tflite.Model, which can be tedious and verbose to -work with directly. - -Usage: - model = model_facade.read(flatbuffer) - # manipulate - new_flatbuffer = model.compile() -""" - -from __future__ import annotations - -import flatbuffers -import numpy as np -from numpy.typing import NDArray -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from typing import ByteString, Generic, TypeVar - -_IteratorTo = TypeVar("_IteratorTo") - - -class _Iterator(Generic[_IteratorTo]): - - def __init__(self, sequence, cls, parent): - self._sequence = sequence - self._cls = cls - self._index = 0 - self._parent = parent - - def __getitem__(self, key) -> _IteratorTo: - return self._cls(self._sequence[key], key, self._parent) - - def __len__(self): - return len(self._sequence) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _IndirectIterator(Generic[_IteratorTo]): - - def __init__(self, indices, sequence): - self._indices = indices - self._index = 0 - self._sequence = sequence - - def __getitem__(self, key) -> _IteratorTo: - index = self._indices[key] - return self._sequence[index] - - def __len__(self): - return len(self._indices) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _Operator: - - def __init__(self, operator, index, subgraph): - self.operator = operator - self.index = index - self.subgraph = subgraph - - @property - def opcode(self) -> tflite.OperatorCodeT: - return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] - - @property - def inputs(self): - return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) - - -_NP_DTYPES = { - tflite.TensorType.FLOAT16: np.dtype(" _Buffer: - return self.subgraph.model.buffers[self._tensor_t.buffer] - - @property - def data(self) -> bytes: - return self.buffer.data - - @property - def dtype(self) -> np.dtype: - return _NP_DTYPES[self._tensor_t.type] - - @property - def array(self) -> np.ndarray: - """Returns an array created from the Tensor's data, type, and shape. - - Note the bytes in the data buffer and the Tensor's type and shape may be - inconsistent, and thus the returned array invalid, if the data buffer has - been altered according to the compression schema, in which the data buffer - is an array of fixed-width, integer fields. - """ - return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor_t.shape) - - @property - def quantization(self) -> tflite.QuantizationParametersT | None: - return self._tensor_t.quantization - - -class _Buffer: - - def __init__(self, buffer_t: tflite.BufferT, index, model): - self._buffer_t = buffer_t - self.index = index - self.model = model - - @property - def data(self) -> bytes: - return bytes(self._buffer_t.data) - - @data.setter - def data(self, value: ByteString): - self._buffer_t.data = list(value) - - def extend(self, values: NDArray): - self._buffer_t.data.extend(values.tobytes()) - - -class _Subgraph: - - def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): - self._subgraph_t = subgraph_t - self.index = index - self.model = model - - @property - def operators(self) -> _Iterator[_Operator]: - return _Iterator(self._subgraph_t.operators, _Operator, parent=self) - - @property - def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) - - -class _Model: - """A facade for manipulating tflite.Model. - """ - - def __init__(self, model_t: tflite.ModelT): - self._model_t = model_t - - def compile(self) -> bytearray: - """Returns a tflite.Model flatbuffer. - """ - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self._model_t.Pack(builder)) - return builder.Output() - - def add_buffer(self) -> _Buffer: - """Adds a buffer to the model. - """ - buffer = tflite.BufferT() - buffer.data = [] - self._model_t.buffers.append(buffer) - index = len(self._model_t.buffers) - 1 - return _Buffer(buffer, index, self._model_t) - - def add_metadata(self, key, value): - """Adds a key-value pair, writing value to a newly created buffer. - """ - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self._model_t.metadata.append(metadata) - - @property - def metadata(self) -> dict[str, _Buffer]: - """Returns the model's metadata as a dictionary to Buffer objects. - """ - result = {} - for m in self._model_t.metadata: - name = m.name.decode("utf-8") # type: ignore (fb library is wrong) - buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, - self._model_t) - result[name] = buffer - - return result - - @property - def operatorCodes(self): - return self._model_t.operatorCodes - - @property - def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) - - @property - def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self._model_t.buffers, _Buffer, parent=self) - - -def read(buffer: ByteString): - """Reads a tflite.Model and returns a model facade. - """ - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py deleted file mode 100644 index 87e71fa968b..00000000000 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import unittest -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from tflite_micro.tensorflow.lite.micro.compression import model_facade -from tflite_micro.tensorflow.lite.micro.compression import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - 1: { - "name": "metadata1", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (8, 1), - "type": tflite.TensorType.INT16, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (4, 1), - "type": tflite.TensorType.INT32, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (2, 1), - "type": tflite.TensorType.INT64, - "buffer": 4, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" Date: Sun, 24 May 2026 23:16:20 -0500 Subject: [PATCH 04/21] refactor(compression): replace test_models with model_editor in compress_test Replace dictionary-based test_models.build() with model_editor's declarative API for building test models. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 1 - .../lite/micro/compression/compress_test.py | 218 ++++++------------ 2 files changed, 66 insertions(+), 153 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 0f12cb15010..b098e360433 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -162,7 +162,6 @@ py_test( ":metadata_py", ":model_editor", ":spec", - ":test_models", "//tensorflow/lite/python:schema_py", requirement("bitarray"), requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index ee10a75f36d..81bbdab3293 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -21,7 +21,6 @@ from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -170,153 +169,70 @@ def test_multiple_tables_with_padding(self): self.assertEqual(result, expected_output) -# yapf: disable -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 0, - "inputs": ( - 0, - 1, - ), - "outputs": (2, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 2, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT16, - "buffer": 3, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 4, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 4: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 5, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 5: { - "shape": (4, 5), - "type": tflite.TensorType.INT16, - "buffer": 6, - "quantization": { - "quantized_dimension": 1, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 6: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 7, - "quantization": { - "quantized_dimension": 0, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 7: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 8, - "quantization": { - "quantized_dimension": 0, - "scale": (1,), - "zero_point": (0,), - }, - }, - 8: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 9, - }, - }, - }, - }, - "buffers": { - 0: None, - - 1: np.array(range(16), dtype=np.dtype(" Date: Sun, 24 May 2026 23:17:18 -0500 Subject: [PATCH 05/21] chore(compression): remove test_models.py Remove test_models module and its tests, now superseded by model_editor. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 21 -- .../lite/micro/compression/test_models.py | 190 ------------------ .../micro/compression/test_models_test.py | 32 --- 3 files changed, 243 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/test_models.py delete mode 100644 tensorflow/lite/micro/compression/test_models_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index b098e360433..ca168798525 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -205,27 +205,6 @@ py_test( ], ) -tflm_py_library( - name = "test_models", - srcs = ["test_models.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - requirement("numpy"), - ], -) - -py_test( - name = "test_models_test", - size = "small", - srcs = ["test_models_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, - deps = [ - ":test_models", - "//tensorflow/lite/python:schema_py", - ], -) - tflm_py_library( name = "model_editor", srcs = ["model_editor.py"], diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py deleted file mode 100644 index 80286d17359..00000000000 --- a/tensorflow/lite/micro/compression/test_models.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Tools for constructing flatbuffers for testing. - -This module provides tools for constructing .tflite flatbuffers from a Python -dictionary representation of a model, a prototype of which can be found in -EXAMPLE_MODEL. - -Example usage: - model_definition = {...} # use EXAMPLE_MODEL as prototype - flatbuffer: bytearray = test_models.build(model_definition) -""" - -# This module must remain low-level and independent from any helpers in this -# project which make constructing model and flatbuffers easier, because this -# module is used to define tests for those helpers. - -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - -EXAMPLE_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, - "inputs": ( - 0, - 1, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, - "inputs": ( - 3, - 2, - ), - "outputs": (4, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 0, - }, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Builds a .tflite flatbuffer from a model definition. - - Args: - model_definition: A dictionary representation of the model, a prototype of - which can be found in the EXAMPLE_MODEL attribute of this module. - - Returns: - A tflite flatbuffer. - """ - root = tflite.ModelT() - description = model_definition.get("description") - if description is not None: - root.description = description - - root.operatorCodes = [] - for id, operator_code in model_definition["operator_codes"].items(): - assert id == len(root.operatorCodes) - opcode_t = tflite.OperatorCodeT() - root.operatorCodes.append(opcode_t) - opcode_t.builtinCode = operator_code["builtin_code"] - - root.metadata = [] - if "metadata" in model_definition: - for _, metadata in model_definition["metadata"].items(): - metadata_t = tflite.MetadataT() - metadata_t.name = metadata["name"] - metadata_t.buffer = metadata["buffer"] - root.metadata.append(metadata_t) - - root.subgraphs = [] - for id, subgraph in model_definition["subgraphs"].items(): - assert id == len(root.subgraphs) - subgraph_t = tflite.SubGraphT() - root.subgraphs.append(subgraph_t) - - subgraph_t.operators = [] - for id, operator in subgraph["operators"].items(): - assert id == len(subgraph_t.operators) - operator_t = tflite.OperatorT() - operator_t.opcodeIndex = operator["opcode_index"] - operator_t.inputs = operator["inputs"] - operator_t.outputs = operator["outputs"] - subgraph_t.operators.append(operator_t) - - subgraph_t.tensors = [] - for id, tensor in subgraph["tensors"].items(): - assert id == len(subgraph_t.tensors) - tensor_t = tflite.TensorT() - tensor_t.name = tensor.get("name", None) - tensor_t.shape = tensor["shape"] - tensor_t.type = tensor["type"] - tensor_t.buffer = tensor["buffer"] - - if "quantization" in tensor: - tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = \ - tensor["quantization"].get("quantized_dimension", None) - tensor_t.quantization.scale = \ - tensor["quantization"].get("scale", None) - tensor_t.quantization.zeroPoint = \ - tensor["quantization"].get("zero_point", None) - - subgraph_t.tensors.append(tensor_t) - - root.buffers = [] - for id, data in model_definition["buffers"].items(): - assert id == len(root.buffers) - buffer_t = tflite.BufferT() - - if data is None: - buffer_t.data = [] - elif isinstance(data, np.ndarray): - array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian - buffer_t.data = list(array.tobytes()) - else: - raise TypeError(f"buffer_id {id} must be None or an np.ndarray") - - root.buffers.append(buffer_t) - - size_hint = 1 * 2**20 - builder = flatbuffers.Builder(size_hint) - builder.Finish(root.Pack(builder)) - flatbuffer = builder.Output() - return flatbuffer diff --git a/tensorflow/lite/micro/compression/test_models_test.py b/tensorflow/lite/micro/compression/test_models_test.py deleted file mode 100644 index d7e961c2dd9..00000000000 --- a/tensorflow/lite/micro/compression/test_models_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from tflite_micro.tensorflow.lite.micro.compression import test_models -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - - -class TestBuild(unittest.TestCase): - - def setUp(self): - self.flatbuffer = test_models.build(test_models.EXAMPLE_MODEL) - - def testNotDegenerate(self): - model = tflite.ModelT.InitFromPackedBuf(self.flatbuffer, 0) - self.assertEqual(model.operatorCodes[0].builtinCode, - tflite.BuiltinOperator.FULLY_CONNECTED) - - -if __name__ == "__main__": - unittest.main() From 31434d40ba14378fc73c7de82cdd43861798636b Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:19:41 -0500 Subject: [PATCH 06/21] feat(compression): add DECODE operator types and metadata Add decode module with DecodeType constants and DecodeCommonMetadata, per the TFLM DECODE Operator Design document. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 14 + tensorflow/lite/micro/compression/decode.py | 240 ++++++++++++++++++ .../lite/micro/compression/decode_test.py | 155 +++++++++++ 3 files changed, 409 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode.py create mode 100644 tensorflow/lite/micro/compression/decode_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index ca168798525..7080cc35006 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -226,6 +226,20 @@ tflm_py_test( ], ) +tflm_py_library( + name = "decode", + srcs = ["decode.py"], +) + +tflm_py_test( + name = "decode_test", + size = "small", + srcs = ["decode_test.py"], + deps = [ + ":decode", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/decode.py b/tensorflow/lite/micro/compression/decode.py new file mode 100644 index 00000000000..df8428310a3 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode.py @@ -0,0 +1,240 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE compression module.""" + +# Implements the DECODE operator compression scheme described in the +# "TFLM DECODE Operator Design" document, revised May 20, 2025. +# +# The DECODE operator transforms an encoded tensor, alongside a paired +# ancillary data tensor, into a tensor ready for use as input to any +# operator. For example, an encoded tensor might contain compressed +# data, while the paired ancillary data tensor holds the information +# necessary for decompression. The DECODE operator's output is a fully +# decompressed tensor. +# +# DECODE operators are inserted into the TfLite model subgraph +# immediately before each operation that uses a decodable tensor as +# input. +# +# Ancillary Data Tensor +# +# The ancillary data tensor contains the information necessary for +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) +# header, followed by decode-type-specific ancillary data. +# +# DECODE Common Metadata (DCM) +# +# Byte 0: Decode type +# 0-127: TFLM-supported decode operations (see below) +# 128-255: Custom operations requiring application-registered +# handlers +# +# Supported decode types: +# +# 0: LUT decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 1: Huffman decompression using Xtensa format decode tables +# INT8 and INT16 tensor types only, in reference and optimized +# code. +# +# 2: Pruning decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 3-127: Reserved +# +# 128-255: Custom decode types +# Requires user-supplied encoding module and decoding ancillary +# data. +# +# Byte 1: DCM version (currently 1) +# +# Bytes 2-3: Reserved +# +# Bytes 4-15: User-defined +# Used by TFLM decode types to avoid requiring additional alignment +# of metadata or ancillary data. +# +# The 16-byte DCM size ensures that subsequent metadata and ancillary +# data are 128-bit aligned, which is required for some optimized +# decoding operations such as Xtensa LUT decompression. +# +# For TFLM decode types, ancillary data starts immediately after the +# DCM. For custom decode types, the location is determined by +# user-defined metadata. + +from dataclasses import dataclass +from typing import Protocol + + +class DecodeType: + """Decode operation type (0-255). + + Use predefined constants for built-in types or DecodeType.custom() + for custom types: + DecodeType.LUT # 0 + DecodeType.HUFFMAN # 1 + DecodeType.PRUNING # 2 + DecodeType.custom(200) # Custom type 128-255 + """ + + # Built-in decode types (class variables set after class definition) + LUT: 'DecodeType' + HUFFMAN: 'DecodeType' + PRUNING: 'DecodeType' + + def __init__(self, code: int, name: str = None): + """Initialize DecodeType. + + Args: + code: Integer code 0-255 + name: Optional name for the type. If not provided: + - Codes 0-127: Named "TYPE_{code}" + - Codes 128-255: Named "CUSTOM_{code}" + """ + if not 0 <= code <= 255: + raise ValueError(f"Decode type must be 0-255, got {code}") + self.code = code + + # Auto-generate name if not provided + if name is None: + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" + else: + self.name = name + + self._is_custom = code >= 128 + + @property + def is_custom(self) -> bool: + """True if this is a custom decode type (128-255).""" + return self._is_custom + + @classmethod + def custom(cls, code: int) -> 'DecodeType': + """Create custom decode type (128-255). + + Args: + code: Integer code 128-255 + + Returns: + DecodeType with name CUSTOM_{code} + """ + if not 128 <= code <= 255: + raise ValueError(f"Custom decode type must be 128-255, got {code}") + return cls(code) + + def __int__(self): + """Convert to integer for serialization.""" + return self.code + + def __eq__(self, other): + if isinstance(other, DecodeType): + return self.code == other.code + return self.code == other + + def __repr__(self): + return f"DecodeType.{self.name}({self.code})" + + +# Define built-in decode type constants +DecodeType.LUT = DecodeType(0, "LUT") +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") +DecodeType.PRUNING = DecodeType(2, "PRUNING") + + +@dataclass +class DecodeCommonMetadata: + """16-byte DECODE Common Metadata (DCM) header. + + Attributes: + decode_type: Decode operation type. Use DecodeType constants or + DecodeType.custom(code) for custom types. + version: DCM version (currently 1). + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM + decode types to avoid requiring additional alignment of metadata + or ancillary data. + """ + decode_type: DecodeType + version: int = 1 + user_data: bytes = b'\x00' * 12 + + def to_bytes(self) -> bytes: + """Serialize DCM to 16-byte sequence.""" + decode_code = int(self.decode_type) + if not 0 <= self.version <= 255: + raise ValueError(f"version must be 0-255, got {self.version}") + if len(self.user_data) < 12: + # Pad with zeros if user_data is too short + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) + else: + user_data = self.user_data[:12] + + result = bytearray(16) + result[0] = decode_code + result[1] = self.version + # bytes 2-3 remain zero (reserved) + result[4:16] = user_data + return bytes(result) + + +class AncillaryDataSerializer(Protocol): + """Protocol for objects that can serialize ancillary data.""" + + def to_bytes(self) -> bytes: + ... + + +@dataclass +class AncillaryDataTensor: + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. + + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte + DCM header, followed by decode-type-specific ancillary data. + + Attributes: + dcm: The DECODE Common Metadata header. + ancillary_data: The decode-type-specific ancillary data, either as raw bytes + or as an object implementing the AncillaryDataSerializer + protocol. May be None if only the DCM is needed. + """ + dcm: DecodeCommonMetadata + ancillary_data: AncillaryDataSerializer | bytes | None = None + + def with_ancillary_data( + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': + """Create new ADT with ancillary data added. + + Args: + data: Ancillary data to add, either as raw bytes or as an object + implementing AncillaryDataSerializer. + + Returns: + New AncillaryDataTensor with the specified ancillary data. + """ + return AncillaryDataTensor(self.dcm, data) + + def to_bytes(self) -> bytes: + """Serialize entire ADT to bytes. + + Returns: + Byte sequence containing DCM followed by ancillary data (if present). + """ + dcm_bytes = self.dcm.to_bytes() + if self.ancillary_data is None: + return dcm_bytes + if isinstance(self.ancillary_data, bytes): + return dcm_bytes + self.ancillary_data + return dcm_bytes + self.ancillary_data.to_bytes() diff --git a/tensorflow/lite/micro/compression/decode_test.py b/tensorflow/lite/micro/compression/decode_test.py new file mode 100644 index 00000000000..eca3b42b2b4 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_test.py @@ -0,0 +1,155 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import decode + + +class TestDecodeCommonMetadata(unittest.TestCase): + + def testBasicSerialization(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + result = dcm.to_bytes() + + # Should be exactly 16 bytes + self.assertEqual(len(result), 16) + + # Byte 0: decode_type + self.assertEqual(result[0], 0) + + # Byte 1: version (default 1) + self.assertEqual(result[1], 1) + + # Bytes 2-3: reserved (should be zero) + self.assertEqual(result[2], 0) + self.assertEqual(result[3], 0) + + # Bytes 4-15: user_data (default all zeros) + self.assertEqual(result[4:16], b'\x00' * 12) + + def testCustomVersion(self): + dcm = decode.DecodeCommonMetadata(decode_type=1, version=2) + result = dcm.to_bytes() + + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 2) + + def testUserData(self): + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data) + + def testUserDataPadding(self): + # User data shorter than 12 bytes should be padded with zeros + user_data = b'\x01\x02\x03' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + expected = b'\x01\x02\x03' + b'\x00' * 9 + self.assertEqual(result[4:16], expected) + + def testUserDataTruncation(self): + # User data longer than 12 bytes should be truncated + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data[:12]) + + def testDecodeTypeRange(self): + # Valid decode types: 0-255 + decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT).to_bytes() + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(127)).to_bytes() + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.custom(255)).to_bytes() + + # Invalid decode types should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(-1)).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType(256)).to_bytes() + + def testVersionRange(self): + # Valid versions: 0-255 + decode.DecodeCommonMetadata(decode_type=0, version=0).to_bytes() + decode.DecodeCommonMetadata(decode_type=0, version=255).to_bytes() + + # Invalid versions should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=-1).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=256).to_bytes() + + +class TestAncillaryDataTensor(unittest.TestCase): + + def testDcmOnly(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + adt = decode.AncillaryDataTensor(dcm) + result = adt.to_bytes() + + # Should be just the 16-byte DCM + self.assertEqual(len(result), 16) + self.assertEqual(result, dcm.to_bytes()) + + def testWithBytesAncillaryData(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.HUFFMAN) + ancillary = b'\xaa\xbb\xcc\xdd' + adt = decode.AncillaryDataTensor(dcm, ancillary) + result = adt.to_bytes() + + # Should be DCM + ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithAncillaryDataMethod(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.PRUNING) + adt = decode.AncillaryDataTensor(dcm) + + ancillary = b'\x11\x22\x33\x44' + adt_with_data = adt.with_ancillary_data(ancillary) + result = adt_with_data.to_bytes() + + # Original ADT should be unchanged + self.assertEqual(adt.to_bytes(), dcm.to_bytes()) + + # New ADT should have ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithSerializerProtocol(self): + # Test with an object that implements AncillaryDataSerializer + class MockSerializer: + + def to_bytes(self): + return b'\xff\xee\xdd\xcc' + + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType(3)) + serializer = MockSerializer() + adt = decode.AncillaryDataTensor(dcm, serializer) + result = adt.to_bytes() + + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], b'\xff\xee\xdd\xcc') + + +if __name__ == '__main__': + unittest.main() From 5be61ceaf064c443c96b365e64f3454b2238de2f Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:20:50 -0500 Subject: [PATCH 07/21] feat(compression): add Compressor protocol Define the plugin interface for compression methods. Each compressor implements the Compressor protocol with a compress() method that returns encoded data and ancillary data. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 10 +++ .../lite/micro/compression/compressor.py | 80 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compressor.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 7080cc35006..adf04e5ad35 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -240,6 +240,16 @@ tflm_py_test( ], ) +tflm_py_library( + name = "compressor", + srcs = ["compressor.py"], + deps = [ + ":decode", + ":model_editor", + ":spec", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/compressor.py b/tensorflow/lite/micro/compression/compressor.py new file mode 100644 index 00000000000..3d5a635eb09 --- /dev/null +++ b/tensorflow/lite/micro/compression/compressor.py @@ -0,0 +1,80 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compression plugin interface.""" + +from dataclasses import dataclass +from typing import Protocol + +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + if wrapped_exception: + super().__init__(f"{message}: {str(wrapped_exception)}") + else: + super().__init__(message) + self.original_exception = wrapped_exception + + +@dataclass +class CompressionResult: + """Result of compressing a tensor. + + Attributes: + encoded_data: The compressed tensor data (e.g., packed indices for LUT). + ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific + data). This is the full buffer contents for the ancillary + tensor. + """ + encoded_data: bytes + ancillary_data: bytes + + +class Compressor(Protocol): + """Protocol that compression plugins must implement. + + Each compression method (LUT, Huffman, Pruning) provides a class implementing + this protocol. The compress() function uses duck typing to call the plugin. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """The DecodeType constant for this compression method.""" + ... + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> CompressionResult: + """Compress a tensor according to the specified method. + + Args: + tensor: The tensor to compress. Must have data (tensor.array is not None) + and quantization parameters for axis inference. + method: The compression method spec (e.g., LookUpTableCompression). + + Returns: + CompressionResult with encoded tensor data and ancillary data bytes. + + Raises: + CompressionError: If compression fails (e.g., too many unique values + for specified bitwidth, missing quantization, etc.). + """ + ... From 40f28c52ff117b0587957b1d6e2a271601c63091 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:22:36 -0500 Subject: [PATCH 08/21] feat(compression): add LUT compression plugin Implement LutCompressor using the Compressor protocol. Lookup table compression replaces tensor values with indices into a table of unique values, producing packed indices and ancillary data in the format expected by the TFLM DECODE kernel. Supports per-tensor and per-channel compression, sizes value tables to actual unique count, and handles unquantized tensors. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 33 ++ tensorflow/lite/micro/compression/lut.py | 318 ++++++++++++++ tensorflow/lite/micro/compression/lut_test.py | 405 ++++++++++++++++++ 3 files changed, 756 insertions(+) create mode 100644 tensorflow/lite/micro/compression/lut.py create mode 100644 tensorflow/lite/micro/compression/lut_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index adf04e5ad35..5a01c2ef474 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -250,6 +250,39 @@ tflm_py_library( ], ) +tflm_py_library( + name = "lut", + srcs = ["lut.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + requirement("bitarray"), + requirement("numpy"), + ], +) + +tflm_py_test( + name = "lut_test", + size = "small", + srcs = ["lut_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":lut", + ":model_editor", + ":spec", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py new file mode 100644 index 00000000000..def34059ac5 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut.py @@ -0,0 +1,318 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUT (Look-Up Table) compression plugin.""" + +import sys +from dataclasses import dataclass, field +from typing import Optional + +import bitarray +import bitarray.util +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +@dataclass +class LutCompressedArray: + """Intermediate representation of LUT-compressed data. + + Attributes: + compression_axis: The axis along which compression was performed, or None + for per-tensor compression. + lookup_tables: List of value lookup tables. One table for per-tensor + compression, or one per channel for per-channel compression. + indices: Array of indices into the lookup tables, same shape as original. + """ + compression_axis: Optional[int] = None + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None or self.indices.size == 0: + raise ValueError("No indices to compute bitwidth from") + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +@dataclass +class LutAncillaryData: + """LUT-specific ancillary data matching C++ decode_state_lut.cc format. + + The LUT ancillary data uses the DCM user_data bytes (4-15) plus value tables: + - Byte 4: LUT version (currently 1) + - Byte 5: Params (lower 3 bits = bitwidth, 1-7) + - Byte 6: Value table channel stride (elements per channel) + - Bytes 7-15: Reserved (zeros) + - Bytes 16+: Value tables (concatenated, stride elements per channel) + + Attributes: + lut_version: LUT format version (currently 1). + bitwidth: Number of bits per index (1-7). + value_table_stride: Number of elements per channel in value tables. + value_tables: Packed value table data following the DCM. + """ + lut_version: int = 1 + bitwidth: int = 4 + value_table_stride: int = 16 + value_tables: bytes = b'' + + def __post_init__(self): + if not 1 <= self.bitwidth <= 7: + raise ValueError(f"bitwidth must be 1-7, got {self.bitwidth}") + if not 0 <= self.value_table_stride <= 128: + raise ValueError( + f"value_table_stride must be 0-128, got {self.value_table_stride}") + + def to_user_data(self) -> bytes: + """Serialize to 12-byte user_data for DCM bytes 4-15.""" + user_data = bytearray(12) + user_data[0] = self.lut_version + user_data[1] = self.bitwidth & 0x07 + user_data[2] = self.value_table_stride + # Bytes 3-11 (DCM bytes 7-15) remain zero (reserved) + return bytes(user_data) + + def to_bytes(self) -> bytes: + """Serialize for use as AncillaryDataTensor.ancillary_data.""" + # This returns the type-specific data that follows the DCM header. + # For LUT, that's just the value tables since user_data is in DCM. + return self.value_tables + + +def compress_array(tensor: np.ndarray, + axis: Optional[int]) -> LutCompressedArray: + """Compresses the given tensor using lookup tables. + + Args: + tensor: The tensor to be compressed. + axis: The axis along which to compress. If an axis is given, a lookup table + is created for each slice along the axis. If axis is None, a single + lookup table is used for the entire tensor. + + Compressing a tensor with a lookup table per slice along a particular + axis is analogous to quantizing a tensor with different quantization + parameters per slice along a particular axis (dimension). + + Returns: + LutCompressedArray containing lookup tables and indices. + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) + compressed.lookup_tables.append(values) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. Unquantized tensors use per-tensor compression. + + Args: + tensor: The tensor to analyze. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. + + Raises: + CompressionError: If the axis cannot be determined from quantization. + """ + q = tensor.quantization + if q is None: + return None + + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + + if quantization_channels == 1: + return None + + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis + + raise compressor.CompressionError( + "Invalid or no quantization parameters from which to " + "infer the axis along which tensor should be compressed.") + + +def check_bitwidth(compressed: int, specified: int, tensor_spec: spec.Tensor): + """Validates that the specified bitwidth is sufficient. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + + Args: + compressed: The bitwidth required by the compressed data. + specified: The bitwidth specified in the compression spec. + tensor_spec: The tensor spec, for error messages. + + Raises: + CompressionError: If specified bitwidth is too small. + """ + if compressed > specified: + raise compressor.CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {tensor_spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in " + f"{tensor_spec}", + file=sys.stderr) + + +def pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + + Args: + indices: Array of indices to pack. + bitwidth: Number of bits per index. + + Returns: + Packed bytes with indices in big-endian bit order. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables are concatenated. + + Args: + tables: List of numpy arrays containing lookup table values. + table_len: Length to pad each table to (typically 2**bitwidth). + + Returns: + Packed bytes containing all tables concatenated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return bytes(buffer) + + +class LutCompressor: + """LUT compression plugin implementing the Compressor protocol.""" + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.LUT.""" + return decode.DecodeType.LUT + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using LUT compression. + + Args: + tensor: The tensor to compress. + method: Must be a LookUpTableCompression instance. + + Returns: + CompressionResult with packed indices and ancillary data. + + Raises: + CompressionError: If compression fails. + """ + if not isinstance(method, spec.LookUpTableCompression): + raise compressor.CompressionError( + f"LutCompressor requires LookUpTableCompression, got {type(method)}") + + if tensor.array is None: + raise compressor.CompressionError("Tensor has no data to compress") + + spec_bitwidth = method.index_bitwidth + axis = identify_compression_axis(tensor) + compressed = compress_array(tensor.array, axis) + # Note: check_bitwidth requires a spec.Tensor but we don't have it here. + # We'll do a simpler check. + actual_bitwidth = compressed.index_bitwidth + if actual_bitwidth > spec_bitwidth: + raise compressor.CompressionError( + f"index_bitwidth too small: {actual_bitwidth} bits needed, " + f"but only {spec_bitwidth} specified") + elif actual_bitwidth < spec_bitwidth: + print( + f"warning: index_bitwidth larger than necessary: only " + f"{actual_bitwidth} bits needed, but {spec_bitwidth} specified", + file=sys.stderr) + + # Pack indices into bytes + encoded_data = pack_indices(compressed.indices, spec_bitwidth) + + # Pack value tables + table_len = max(len(t) for t in compressed.lookup_tables) + value_tables_bytes = pack_lookup_tables(compressed.lookup_tables, + table_len) + + # Build ancillary data + lut_data = LutAncillaryData( + lut_version=1, + bitwidth=spec_bitwidth, + value_table_stride=table_len, + value_tables=value_tables_bytes, + ) + + # Build complete ancillary data tensor bytes: DCM header + value tables + dcm = decode.DecodeCommonMetadata( + decode_type=self.decode_type, + user_data=lut_data.to_user_data(), + ) + ancillary_data = dcm.to_bytes() + lut_data.to_bytes() + + return compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=ancillary_data, + ) diff --git a/tensorflow/lite/micro/compression/lut_test.py b/tensorflow/lite/micro/compression/lut_test.py new file mode 100644 index 00000000000..d01dcfd4260 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut_test.py @@ -0,0 +1,405 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for LUT compression plugin.""" + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestCompressArray(unittest.TestCase): + """Tests for the compress_array function.""" + + def test_per_tensor_basic(self): + """Per-tensor compression extracts unique values.""" + array = np.array([1, 2, 1, 2, 3, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertIsNone(compressed.compression_axis) + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1, 2, 3]) + # Indices should map back to original values + reconstructed = compressed.lookup_tables[0][compressed.indices] + np.testing.assert_array_equal(reconstructed, array) + + def test_per_tensor_preserves_shape(self): + """Indices array has same shape as input.""" + # yapf: disable + array = np.array([[1, 2], + [3, 1], + [2, 3]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(compressed.indices.shape, array.shape) + + def test_per_channel_axis0(self): + """Per-channel compression along axis 0.""" + # Each row gets its own value table + # yapf: disable + array = np.array([[1, 1, 1], + [5, 5, 5], + [9, 9, 9]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=0) + + self.assertEqual(compressed.compression_axis, 0) + self.assertEqual(len(compressed.lookup_tables), 3) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + np.testing.assert_array_equal(compressed.lookup_tables[2], [9]) + + def test_per_channel_axis1(self): + """Per-channel compression along axis 1.""" + # Each column gets its own value table + # yapf: disable + array = np.array([[1, 5], + [1, 5], + [1, 5]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=1) + + self.assertEqual(compressed.compression_axis, 1) + self.assertEqual(len(compressed.lookup_tables), 2) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + + def test_single_value(self): + """Array with single unique value.""" + array = np.array([7, 7, 7, 7], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [7]) + np.testing.assert_array_equal(compressed.indices, [0, 0, 0, 0]) + + def test_bitwidth_calculation(self): + """Index bitwidth is computed correctly.""" + # 3 unique values -> 2 bits needed + array = np.array([0, 1, 2], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 4 unique values -> 2 bits needed + array = np.array([0, 1, 2, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 5 unique values -> 3 bits needed + array = np.array([0, 1, 2, 3, 4], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 3) + + def test_bitwidth_single_value(self): + """Single unique value requires 1 bit.""" + array = np.array([42, 42, 42], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 1) + + +class TestPackIndices(unittest.TestCase): + """Tests for the pack_indices function.""" + + def test_4bit_packing(self): + """Pack indices into 4-bit fields.""" + indices = np.array([1, 2, 3, 0]) + result = lut.pack_indices(indices, bitwidth=4) + # Big-endian: 0001 0010 | 0011 0000 = 0x12 0x30 + self.assertEqual(result, bytes([0x12, 0x30])) + + def test_2bit_packing(self): + """Pack indices into 2-bit fields.""" + indices = np.array([0, 1, 2, 3]) + result = lut.pack_indices(indices, bitwidth=2) + # Big-endian: 00 01 10 11 = 0x1B + self.assertEqual(result, bytes([0x1B])) + + def test_3bit_packing(self): + """Pack indices into 3-bit fields.""" + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + result = lut.pack_indices(indices, bitwidth=3) + # 000 001 010 011 | 100 101 110 111 + # 00000101 | 00111001 | 01110111 = 0x05 0x39 0x77 + self.assertEqual(result, bytes([0x05, 0x39, 0x77])) + + def test_1bit_packing(self): + """Pack indices into 1-bit fields.""" + indices = np.array([0, 1, 0, 1, 1, 0, 1, 0]) + result = lut.pack_indices(indices, bitwidth=1) + # 0 1 0 1 1 0 1 0 = 0x5A + self.assertEqual(result, bytes([0x5A])) + + def test_multidimensional_flattens(self): + """Multidimensional indices are flattened row-major.""" + # yapf: disable + indices = np.array([[0, 1], + [2, 3]]) + # yapf: enable + result = lut.pack_indices(indices, bitwidth=4) + # 0000 0001 | 0010 0011 = 0x01 0x23 + self.assertEqual(result, bytes([0x01, 0x23])) + + +class TestPackLookupTables(unittest.TestCase): + """Tests for the pack_lookup_tables function.""" + + def test_single_table_int8(self): + """Pack single INT8 lookup table.""" + tables = [np.array([10, 20, 30], dtype=np.int8)] + result = lut.pack_lookup_tables(tables, table_len=4) + # Values: 10, 20, 30, 0 (padding) + self.assertEqual(result, bytes([10, 20, 30, 0])) + + def test_multiple_tables(self): + """Pack multiple lookup tables.""" + tables = [ + np.array([1, 2], dtype=np.int8), + np.array([3, 4], dtype=np.int8), + ] + result = lut.pack_lookup_tables(tables, table_len=4) + # Table 1: 1, 2, 0, 0 | Table 2: 3, 4, 0, 0 + self.assertEqual(result, bytes([1, 2, 0, 0, 3, 4, 0, 0])) + + def test_int16_little_endian(self): + """INT16 values are packed in native byte order.""" + tables = [np.array([0x1234, 0x5678], dtype=' Date: Sun, 24 May 2026 23:24:04 -0500 Subject: [PATCH 09/21] feat(compression): add Huffman and Pruning compression support Add spec types, YAML parser support, and plugin stubs for Huffman and Pruning compression methods. The plugins raise CompressionError when invoked, to be replaced with working implementations later. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 22 +++++++ tensorflow/lite/micro/compression/huffman.py | 60 ++++++++++++++++++++ tensorflow/lite/micro/compression/pruning.py | 59 +++++++++++++++++++ tensorflow/lite/micro/compression/spec.py | 51 +++++++++++++++-- 4 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 tensorflow/lite/micro/compression/huffman.py create mode 100644 tensorflow/lite/micro/compression/pruning.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 5a01c2ef474..f40b530034d 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -283,6 +283,28 @@ tflm_py_test( ], ) +tflm_py_library( + name = "huffman", + srcs = ["huffman.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "pruning", + srcs = ["pruning.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py new file mode 100644 index 00000000000..40d0be9284a --- /dev/null +++ b/tensorflow/lite/micro/compression/huffman.py @@ -0,0 +1,60 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Huffman compression plugin (stub). + +This module provides a placeholder for Huffman compression using Xtensa-format +decode tables. The actual implementation is not yet available. + +Supported tensor types (when implemented): INT8, INT16 +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class HuffmanCompressor: + """Huffman compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual Huffman + compression algorithm using Xtensa-format decode tables is not yet + implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.HUFFMAN.""" + return decode.DecodeType.HUFFMAN + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using Huffman encoding. + + Args: + tensor: The tensor to compress. + method: Must be a HuffmanCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Huffman compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py new file mode 100644 index 00000000000..2181b73e34a --- /dev/null +++ b/tensorflow/lite/micro/compression/pruning.py @@ -0,0 +1,59 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pruning compression plugin (stub). + +This module provides a placeholder for pruning (sparsity) compression. +The actual implementation is not yet available. + +Supported tensor types (when implemented): All TFLM tensor types +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class PruningCompressor: + """Pruning compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual pruning + compression algorithm for sparse tensors is not yet implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.PRUNING.""" + return decode.DecodeType.PRUNING + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using pruning (sparsity) encoding. + + Args: + tensor: The tensor to compress. + method: Must be a PruningCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Pruning compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index 6f782e92d7a..5c0f81885bc 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -58,10 +58,32 @@ class Tensor: @dataclass class LookUpTableCompression(CompressionMethod): + """LUT compression using lookup tables. + Attributes: + index_bitwidth: Number of bits per index (1-7). + """ index_bitwidth: int +@dataclass +class HuffmanCompression(CompressionMethod): + """Huffman compression using Xtensa-format decode tables. + + Supported tensor types: INT8, INT16 only. + """ + pass + + +@dataclass +class PruningCompression(CompressionMethod): + """Pruning (sparsity) compression. + + Supported tensor types: All TFLM tensor types. + """ + pass + + class ParseError(Exception): "Raised when the spec string cannot be parsed." @@ -70,6 +92,18 @@ def __init__(self, message="error parsing spec", wrapped_exception=None): self.original_exception = wrapped_exception +def _parse_compression_method(comp: dict) -> CompressionMethod: + """Parse a single compression method from YAML dict.""" + if "lut" in comp: + return LookUpTableCompression(index_bitwidth=comp["lut"]["index_bitwidth"]) + elif "huffman" in comp: + return HuffmanCompression() + elif "pruning" in comp: + return PruningCompression() + else: + raise ParseError(f"Unknown compression method: {list(comp.keys())}") + + def parse_yaml(y: str) -> list[Tensor]: "Parses a compression spec in a YAML string into its Python representation." try: @@ -77,14 +111,19 @@ def parse_yaml(y: str) -> list[Tensor]: tensors = [] for item in config["tensors"]: - bitwidth = item["compression"][0]["lut"]["index_bitwidth"] - tensor = Tensor(subgraph=item["subgraph"], - tensor=item["tensor"], - compression=[ - LookUpTableCompression(index_bitwidth=bitwidth), - ]) + methods = [] + for comp in item["compression"]: + methods.append(_parse_compression_method(comp)) + + tensor = Tensor( + subgraph=item["subgraph"], + tensor=item["tensor"], + compression=methods, + ) tensors.append(tensor) + except ParseError: + raise except Exception as e: raise ParseError() from e From 963ce9782c96f8e9ae2f3e3faf25b6ac62dfbf6b Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:26:33 -0500 Subject: [PATCH 10/21] feat(python): add alt decompression memory parameter to interpreter Add alt_decompression_memory_size parameter to the Python interpreter API. When non-zero, allocates a separate memory region for DECODE operator outputs and calls SetDecompressionMemory before AllocateTensors. BUG=part of #3256 --- python/tflite_micro/_runtime.cc | 9 +++++---- python/tflite_micro/interpreter_wrapper.cc | 18 ++++++++++++++++-- python/tflite_micro/interpreter_wrapper.h | 6 +++++- python/tflite_micro/python_ops_resolver.cc | 2 ++ python/tflite_micro/runtime.py | 12 ++++++++++++ 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/python/tflite_micro/_runtime.cc b/python/tflite_micro/_runtime.cc index 246545fd016..53825f14f0d 100644 --- a/python/tflite_micro/_runtime.cc +++ b/python/tflite_micro/_runtime.cc @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) { .def(py::init([](const py::bytes& data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - tflite::InterpreterConfig config) { - return std::unique_ptr( - new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size, - num_resource_variables, config)); + tflite::InterpreterConfig config, + size_t alt_decompression_memory_size) { + return std::unique_ptr(new InterpreterWrapper( + data.ptr(), registerers_by_name, arena_size, num_resource_variables, + config, alt_decompression_memory_size)); })) .def("PrintAllocations", &InterpreterWrapper::PrintAllocations) .def("Invoke", &InterpreterWrapper::Invoke) diff --git a/python/tflite_micro/interpreter_wrapper.cc b/python/tflite_micro/interpreter_wrapper.cc index 669589890ad..c74ab84736b 100644 --- a/python/tflite_micro/interpreter_wrapper.cc +++ b/python/tflite_micro/interpreter_wrapper.cc @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() { InterpreterWrapper::InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, - size_t arena_size, int num_resource_variables, InterpreterConfig config) { + size_t arena_size, int num_resource_variables, InterpreterConfig config, + size_t alt_decompression_memory_size) + : memory_arena_(new uint8_t[arena_size]), + alt_decompression_memory_(alt_decompression_memory_size > 0 + ? new uint8_t[alt_decompression_memory_size] + : nullptr), + alt_decompression_region_{alt_decompression_memory_.get(), + alt_decompression_memory_size} { interpreter_ = nullptr; // `model_data` is used as a raw pointer beyond the scope of this @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper( "--//:with_compression=true to enable compression support."); } - memory_arena_ = std::unique_ptr(new uint8_t[arena_size]); for (const std::string& registerer : registerers_by_name) { if (!AddCustomOpRegistererByName(registerer.c_str(), &python_ops_resolver_)) { @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper( interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_, resource_variables_); + if (alt_decompression_memory_size > 0) { + TfLiteStatus status = + interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1); + if (status != kTfLiteOk) { + ThrowRuntimeError("TFLM failed to set decompression memory"); + } + } + TfLiteStatus status = interpreter_->AllocateTensors(); if (status != kTfLiteOk) { ThrowRuntimeError("TFLM failed to allocate tensors"); diff --git a/python/tflite_micro/interpreter_wrapper.h b/python/tflite_micro/interpreter_wrapper.h index 9bb31b067fe..d3a156b337a 100644 --- a/python/tflite_micro/interpreter_wrapper.h +++ b/python/tflite_micro/interpreter_wrapper.h @@ -19,6 +19,7 @@ limitations under the License. #include "python/tflite_micro/python_ops_resolver.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_context.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/recording_micro_allocator.h" @@ -40,7 +41,8 @@ class InterpreterWrapper { InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - InterpreterConfig config = InterpreterConfig::kAllocationRecording); + InterpreterConfig config = InterpreterConfig::kAllocationRecording, + size_t alt_decompression_memory_size = 0); ~InterpreterWrapper(); void PrintAllocations(); @@ -57,6 +59,8 @@ class InterpreterWrapper { tflite::RecordingMicroAllocator* recording_allocator_ = nullptr; const PyObject* model_; std::unique_ptr memory_arena_; + std::unique_ptr alt_decompression_memory_; + tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_; tflite::PythonOpsResolver python_ops_resolver_; tflite::MicroInterpreter* interpreter_; }; diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 5f7d40fb74e..19f324bdf2f 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,7 +40,9 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); +#ifdef USE_TFLM_COMPRESSION AddDecode(); +#endif AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index d895f8c4993..7052972b4a6 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -100,6 +100,7 @@ def __init__( custom_op_registerers, arena_size, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): if model_data is None: raise ValueError("Model must not be None") @@ -122,6 +123,7 @@ def __init__( arena_size, num_resource_variables, _ENUM_TRANSLATOR[intrepreter_config], + alt_decompression_memory_size, ) @classmethod @@ -131,6 +133,7 @@ def from_file( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model .tflite filepath. @@ -140,6 +143,9 @@ def from_file( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -155,6 +161,7 @@ def from_file( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) @classmethod @@ -164,6 +171,7 @@ def from_bytes( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model in byte array. @@ -173,6 +181,9 @@ def from_bytes( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -183,6 +194,7 @@ def from_bytes( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) def print_allocations(self): From 702dfb5970745d57ebbf7990f29a934af9014b83 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:28:09 -0500 Subject: [PATCH 11/21] feat(compression): add DECODE operator insertion Insert DECODE operators before consumers of compressed tensors. Each consumer gets its own DECODE operator to support alternate decompression memory, which resets allocations between DECODE invocations. After insertion, compressed tensors are rewritten to hold encoded data as UINT8 with shape matching byte count. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 29 ++ .../lite/micro/compression/decode_insert.py | 268 +++++++++++ .../micro/compression/decode_insert_test.py | 417 ++++++++++++++++++ 3 files changed, 714 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode_insert.py create mode 100644 tensorflow/lite/micro/compression/decode_insert_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index f40b530034d..c5a0b6c2753 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -305,6 +305,35 @@ tflm_py_library( ], ) +tflm_py_library( + name = "decode_insert", + srcs = ["decode_insert.py"], + deps = [ + ":compressor", + ":model_editor", + "//tensorflow/lite/python:schema_py", + ], +) + +tflm_py_test( + name = "decode_insert_test", + size = "small", + srcs = ["decode_insert_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":decode_insert", + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py new file mode 100644 index 00000000000..43dffce46f0 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -0,0 +1,268 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE operator insertion into TFLite model graphs. + +This module inserts DECODE operators into a compressed model. DECODE operators +transform encoded tensors (with their paired ancillary data tensors) into +tensors ready for use by downstream operators. + +The DECODE operator is registered as a custom operator named "TFLM_DECODE". +Each DECODE output requires two inputs: the encoded tensor and the ancillary +data tensor (containing the DCM header and decode-type-specific data). +""" + +import warnings +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# Custom operator name for DECODE +DECODE_CUSTOM_OP_NAME = "TFLM_DECODE" + + +@dataclass +class _CompressedTensorInfo: + """Information about a compressed tensor for DECODE insertion.""" + subgraph_idx: int + tensor_idx: int + tensor: model_editor.Tensor + encoded_data: bytes + ancillary_data: bytes + consumers: list[model_editor.Operator] + + +def _find_tensor_consumers( + subgraph: model_editor.Subgraph, + tensor: model_editor.Tensor, +) -> list[model_editor.Operator]: + """Find all operators in subgraph that use tensor as an input.""" + consumers = [] + for op in subgraph.operators: + if tensor in op.inputs: + consumers.append(op) + return consumers + + +def _create_ancillary_tensor( + ancillary_data: bytes, + original_tensor: model_editor.Tensor, +) -> model_editor.Tensor: + """Create an ancillary data tensor for a compressed tensor. + + Args: + ancillary_data: The complete ancillary data (DCM + type-specific data). + original_tensor: The original tensor being decoded, for naming. + + Returns: + A new Tensor containing the ancillary data. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_ancillary" + + return model_editor.Tensor( + shape=(len(ancillary_data), ), + dtype=tflite.TensorType.UINT8, + data=ancillary_data, + name=name, + ) + + +def _create_output_tensor( + original_tensor: model_editor.Tensor, ) -> model_editor.Tensor: + """Create the output tensor for a DECODE operator. + + The output tensor has the same shape, dtype, and quantization as the + original tensor would have when decoded. It has no data---the DECODE + operator produces its values at runtime. + + Args: + original_tensor: The original tensor being decoded. + + Returns: + A new Tensor for the DECODE output. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_decoded" + + return model_editor.Tensor( + shape=original_tensor.shape, + dtype=original_tensor.dtype, + quantization=original_tensor.quantization, + name=name, + ) + + +def _rewire_consumers( + consumers: list[model_editor.Operator], + old_tensor: model_editor.Tensor, + new_tensor: model_editor.Tensor, +) -> None: + """Replace old_tensor with new_tensor in all consumer inputs.""" + for consumer in consumers: + consumer.inputs = [ + new_tensor if t is old_tensor else t for t in consumer.inputs + ] + + +def _rewrite_encoded_tensor( + tensor: model_editor.Tensor, + encoded_data: bytes, +) -> None: + """Rewrite a compressed tensor to hold encoded data. + + The original tensor contained uncompressed values with quantization. After + compression, it holds packed indices (or other encoded form) as raw bytes. + This function updates the tensor in place to reflect its new role. + + Args: + tensor: The tensor to rewrite. + encoded_data: The compressed/encoded data bytes. + """ + tensor.shape = (len(encoded_data), ) + tensor.dtype = tflite.TensorType.UINT8 + tensor.quantization = None + tensor.buffer.data = encoded_data + + +def insert_decode_operators( + model: model_editor.Model, + compression_results: dict[tuple[int, int], compressor.CompressionResult], +) -> None: + """Insert DECODE operators for all compressed tensors. + + This function modifies the model in-place, inserting DECODE operators + before any operator that uses a compressed tensor as input. + + A separate DECODE is inserted before each consumer, rather than sharing one + DECODE output among all consumers. This is required because the interpreter's + alternate decompression memory resets its allocation offset for each DECODE's + Prepare, causing all DECODE outputs to be allocated at the same address. If + two consumers share one DECODE and another DECODE runs between them, the + intervening DECODE overwrites the shared output, corrupting data for the + second consumer. + + For each consumer of a compressed tensor: + 1. Create an ancillary data tensor containing DCM + type-specific data + 2. Create an output tensor with the same shape/dtype as the decoded tensor + 3. Insert a DECODE operator immediately before the consumer + 4. Rewire the consumer to use the DECODE output + + Args: + model: The model to modify in-place. + compression_results: Map from (subgraph_idx, tensor_idx) to the + CompressionResult containing ancillary_data. + """ + # Group compressed tensors by subgraph + by_subgraph: dict[int, list[_CompressedTensorInfo]] = defaultdict(list) + + for (sg_idx, tensor_idx), result in compression_results.items(): + subgraph = model.subgraphs[sg_idx] + tensor = subgraph.tensors[tensor_idx] + consumers = _find_tensor_consumers(subgraph, tensor) + + if not consumers: + # Check if tensor is a subgraph output + is_output = tensor in subgraph.outputs + if is_output: + # TODO: Handle compressed tensors that are subgraph outputs. + # This occurs in multi-subgraph models using IF/WHILE where a + # compressed tensor flows out of a subgraph. + raise NotImplementedError( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) is a subgraph output with no consumers. " + "Compressed subgraph outputs are not yet supported.") + else: + warnings.warn( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) has no consumers and is not a subgraph " + "output. No DECODE operator will be inserted.", + stacklevel=2) + continue + + info = _CompressedTensorInfo( + subgraph_idx=sg_idx, + tensor_idx=tensor_idx, + tensor=tensor, + encoded_data=result.encoded_data, + ancillary_data=result.ancillary_data, + consumers=consumers, + ) + by_subgraph[sg_idx].append(info) + + # Process each subgraph + for sg_idx, tensor_infos in by_subgraph.items(): + subgraph = model.subgraphs[sg_idx] + + # Collect all (consumer, tensor_info) pairs and sort by consumer position + # in reverse order so insertions don't invalidate positions + consumer_pairs = [] + for info in tensor_infos: + for consumer in info.consumers: + consumer_pairs.append((consumer, info)) + + consumer_pairs.sort( + key=lambda pair: subgraph.operators.index(pair[0]), + reverse=True, + ) + + # Cache ancillary tensors by original tensor to avoid duplicates. Each + # DECODE needs its own output tensor, but ancillary data is identical for + # all DECODEs of the same compressed tensor. + ancillary_cache: dict[model_editor.Tensor, model_editor.Tensor] = {} + + # Track tensors to rewrite after all output tensors are created, since + # _create_output_tensor reads the original tensor's shape/dtype/quantization. + tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} + + for consumer, info in consumer_pairs: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per DECODE) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + # Create DECODE operator + decode_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code=DECODE_CUSTOM_OP_NAME, + inputs=[info.tensor, ancillary_tensor], + outputs=[output_tensor], + ) + + # Insert DECODE immediately before this consumer + insert_pos = subgraph.operators.index(consumer) + subgraph.operators.insert(insert_pos, decode_op) + + # Rewire only this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Rewrite encoded tensors after all output tensors are created + for tensor, encoded_data in tensors_to_rewrite.items(): + _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py new file mode 100644 index 00000000000..11be81963d9 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -0,0 +1,417 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for DECODE operator insertion.""" + +import warnings + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_simple_fc_model(): + """Build a simple model with one FC operator and compressible weights.""" + # yapf: disable + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + # yapf: enable + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model + + +def _build_shared_weights_model(): + """Build model where one tensor is used by multiple operators.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="shared_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights], + outputs=[output1], + ), + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights], + outputs=[output2], + ), + ], + ) + ]) + return model + + +def _make_dummy_ancillary_data() -> bytes: + """Create dummy ancillary data for testing.""" + dcm = decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.LUT, + user_data=b'\x01\x04\x10' + b'\x00' * 9, # lut_version, bitwidth, stride + ) + value_tables = bytes([1, 2, 3, 4] + [0] * 12) # 16-byte padded table + return dcm.to_bytes() + value_tables + + +class TestDecodeInsertion(unittest.TestCase): + """Tests for insert_decode_operators function.""" + + def test_insert_single_decode_operator(self): + """DECODE operator inserted before FC that uses compressed weights.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Create compression result + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + # Insert DECODE operators + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 2 operators: DECODE then FC + self.assertEqual(len(sg.operators), 2) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[0].custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + def test_decode_inputs_structure(self): + """DECODE operator has correct inputs: encoded tensor + ancillary.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + + # DECODE has 2 inputs + self.assertEqual(len(decode_op.inputs), 2) + # First input is the encoded tensor (original weights) + self.assertIs(decode_op.inputs[0], weights_tensor) + # Second input is ancillary tensor + self.assertEqual(decode_op.inputs[1].dtype, tflite.TensorType.UINT8) + + def test_decode_output_structure(self): + """DECODE operator output has correct shape and dtype.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Save original properties before rewrite + original_shape = weights_tensor.shape + original_dtype = weights_tensor.dtype + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + output = decode_op.outputs[0] + + # Output matches original (pre-rewrite) tensor shape and dtype + self.assertEqual(output.shape, original_shape) + self.assertEqual(output.dtype, original_dtype) + + def test_consumer_rewired_to_decode_output(self): + """FC operator input rewired to use DECODE output.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + fc_op = model.subgraphs[0].operators[1] + + # FC's second input (weights) should now be DECODE's output + self.assertIs(fc_op.inputs[1], decode_op.outputs[0]) + # Original weights tensor should NOT be in FC inputs + self.assertNotIn(weights_tensor, fc_op.inputs) + + def test_shared_tensor_decode_per_consumer(self): + """Tensor used by multiple ops gets separate DECODE for each consumer.""" + model = _build_shared_weights_model() + weights_tensor = model.subgraphs[0].tensor_by_name("shared_weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC) + self.assertEqual(len(sg.operators), 4) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[3].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + decode_op1 = sg.operators[0] + fc_op1 = sg.operators[1] + decode_op2 = sg.operators[2] + fc_op2 = sg.operators[3] + + # Each FC should use its own DECODE's output + self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0]) + # The two DECODEs should have different outputs + self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0]) + # The two DECODEs should share the same ancillary tensor + self.assertIs(decode_op1.inputs[1], decode_op2.inputs[1]) + + def test_ancillary_tensor_contains_dcm(self): + """Ancillary tensor data contains valid DCM header.""" + model = _build_simple_fc_model() + + ancillary_data = _make_dummy_ancillary_data() + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=ancillary_data, + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary_tensor = decode_op.inputs[1] + + # Ancillary tensor data should match what we provided + self.assertEqual(bytes(ancillary_tensor.array), ancillary_data) + + # Verify DCM header + dcm_bytes = ancillary_tensor.array[:16] + self.assertEqual(dcm_bytes[0], 0) # decode_type = LUT + self.assertEqual(dcm_bytes[1], 1) # DCM version + + def test_no_consumers_no_decode(self): + """Tensor with no consumers gets no DECODE operator and emits warning.""" + # Create model where compressed tensor is not used as input + unused_tensor = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="unused", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + other_weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="other_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[unused_tensor, other_weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, other_weights], + outputs=[output_t], + ) + ], + ) + ]) + + # Compress the unused tensor + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + decode_insert.insert_decode_operators(model, compression_results) + + # Should emit a warning about no consumers + self.assertEqual(len(w), 1) + self.assertIn("no consumers", str(w[0].message)) + self.assertIn("unused", str(w[0].message)) + + # Should still have just 1 operator (no DECODE inserted) + self.assertEqual(len(model.subgraphs[0].operators), 1) + + def test_tensor_naming(self): + """Output and ancillary tensors get appropriate names.""" + model = _build_simple_fc_model() + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary = decode_op.inputs[1] + output = decode_op.outputs[0] + + self.assertEqual(ancillary.name, "weights_ancillary") + self.assertEqual(output.name, "weights_decoded") + + def test_encoded_tensor_rewritten(self): + """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + encoded_data = b'\xAB\xCD\xEF' + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + # Original tensor should be rewritten + self.assertEqual(weights_tensor.shape, (len(encoded_data), )) + self.assertEqual(weights_tensor.dtype, tflite.TensorType.UINT8) + self.assertIsNone(weights_tensor.quantization) + self.assertEqual(weights_tensor.buffer.data, encoded_data) + + +class TestHelperFunctions(unittest.TestCase): + """Tests for internal helper functions.""" + + def test_find_tensor_consumers(self): + """_find_tensor_consumers finds all ops using a tensor.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + weights = sg.tensor_by_name("shared_weights") + + consumers = decode_insert._find_tensor_consumers(sg, weights) + + self.assertEqual(len(consumers), 2) + + +if __name__ == "__main__": + unittest.main() From b10c147d4767a24a6e2cd89a3ca2b44cadac436c Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:30:48 -0500 Subject: [PATCH 12/21] refactor(compression): use plugin architecture in compress.py Replace monolithic compression logic with a dispatch table that routes compression requests to plugin modules based on the spec's compression method type. After compressing tensors, insert DECODE operators into the model graph. Warn when compression expands data, helping users identify tensors that don't benefit from compression. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 13 +- tensorflow/lite/micro/compression/compress.py | 310 ++------ .../lite/micro/compression/compress_test.py | 700 ++++++++---------- 3 files changed, 389 insertions(+), 634 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index c5a0b6c2753..09a77ce407a 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -123,14 +123,15 @@ py_library( "compress.py", ], deps = [ - ":metadata_py", + ":compressor", + ":decode_insert", + ":huffman", + ":lut", ":model_editor", + ":pruning", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), - requirement("flatbuffers"), - requirement("bitarray"), - requirement("numpy"), ], ) @@ -159,11 +160,11 @@ py_test( target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress", - ":metadata_py", + ":compressor", + ":decode_insert", ":model_editor", ":spec", "//tensorflow/lite/python:schema_py", - requirement("bitarray"), requirement("numpy"), ], ) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index b6d5aef4435..270951fecf8 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -16,22 +16,22 @@ See USAGE. """ -import bitarray -import bitarray.util -from dataclasses import dataclass, field import os import sys import tempfile -from typing import ByteString, Iterable, Optional +import warnings +from typing import ByteString, Iterable, Type import absl.app import absl.flags -import flatbuffers -import numpy as np +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import huffman +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import pruning from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper USAGE = f"""\ @@ -49,221 +49,48 @@ {spec.EXAMPLE_YAML_SPEC} --- -The only compression method currently implemented is "lut", i.e., -Look-Up-Table. This method requires the tensor in the input model to have a -small number of unique values, fewer than or equal to 2**index_bitwidth. LUT -compression collects these values into a lookup table, and rewrites the tensor -as bitwidth-wide integer indices into that lookup table. Presumably, the input -model has been trained or preprocessed in a way that the tensor values -are binned into a meaningful, limited set. -""" - -# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of -# its own containing compression metadata, stored at the buffer index stored at -# the following key in the .tflite flatbuffer's metadata map. -TFLITE_METADATA_KEY = "COMPRESSION_METADATA" - - -class CompressionError(Exception): - """Raised when compression fails for the reason documented in the message.""" - - def __init__(self, message, wrapped_exception=None): - super().__init__(f"{message}: {str(wrapped_exception)}") - self.original_exception = wrapped_exception - - -class _MetadataBuilder: - """Builder for the compression metadata flatbuffer.""" - - def __init__(self): - self._metadata = schema.MetadataT() - self._metadata.subgraphs = [] - - def compile(self) -> bytearray: - """Packs the metadata into a binary array and returns it. - """ - builder = flatbuffers.Builder(1 * 2**10) - root = self._metadata.Pack(builder) - builder.Finish(root) - return builder.Output() - - def subgraph(self, index: int): - """Return subgraph at index, adding subgraphs if necessary. - """ - while len(self._metadata.subgraphs) <= index: - self._add_subgraph() - return self._metadata.subgraphs[index] - - def add_lut_tensor(self, subgraph_id: int): - """Add LUT tensor to the given subgraph and return it. - """ - tensor = schema.LutTensorT() - self.subgraph(subgraph_id).lutTensors.append(tensor) - return tensor - - def _add_subgraph(self): - subgraph = schema.SubgraphT() - subgraph.lutTensors = [] - self._metadata.subgraphs.append(subgraph) - return subgraph - - -@dataclass -class _LutCompressedArray: - compression_axis: Optional[int] = None - lookup_tables: list[np.ndarray] = field(default_factory=list) - indices: np.ndarray = field(default_factory=lambda: np.array([])) - - @property - def index_bitwidth(self) -> int: - """Returns the number of bits required to encode the indices.""" - if self.indices is None: - raise ValueError - - max_index = int(np.max(self.indices)) - return max_index.bit_length() or 1 - - -def _lut_compress_array(tensor: np.ndarray, - axis: Optional[int]) -> _LutCompressedArray: - """Compresses the given tensor using lookup tables. - - Args: - tensor (np.ndarray): The tensor to be compressed. - - axis (Optional[int]): The axis along which to compress the tensor. If an - axis is given, a lookup table is created for each slice along the - axis. If axis is None, a single lookup table is used for the entire - tensor. - - Compressing a tensor with a lookup table per slice along a - particular axis is analogous to quantizing a tensor with different - quantization parameters per slice along a particular axis (dimension). - - Returns: - _LutCompressedArray: An object containing the compressed tensor data, - including the lookup tables and indices. - """ - compressed = _LutCompressedArray() - compressed.compression_axis = axis - - if axis is None: - # Compute unique values and indices for the entire tensor - values, indices = np.unique(tensor, return_inverse=True) - compressed.lookup_tables.append(values) - compressed.indices = indices.reshape(tensor.shape) - else: - # Iterate over slices along the compression axis - slice_indices = [] - for slice in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(slice, return_inverse=True) - compressed.lookup_tables.append(values) - indices = indices.reshape(slice.shape) - slice_indices.append(indices) - - # Reconstruct a tensor of indices from the slices - stacked = np.stack(slice_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) - - return compressed - - -def _check_lut_compression(compression) -> spec.LookUpTableCompression: - if len(compression) != 1: - raise CompressionError("Each tensor must have exactly one compression") - if not isinstance(compression[0], spec.LookUpTableCompression): - raise CompressionError('Only "lut" compression may be specified') - - return compression[0] - - -def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: - """Determines the axis along which to compress. - - The axis along which to compress is inferred from the tensor's quantization - parameters. - - Returns: - The axis along which to compress, or None to indicate one value table for - the entire tensor. - - Raises: - CompressionError: If the axis cannot be determined. - """ - q = tensor.quantization - if q is not None: - # model_editor wraps quantization, access scales/axis from wrapper - scales = q.scales if isinstance(q.scales, list) else [q.scales] - quantization_channels = len(scales) +Supported compression methods: - if quantization_channels == 1: - # Use one value table for the entire tensor - return None + lut: Look-Up-Table compression. Requires the tensor to have a small number of + unique values, fewer than or equal to 2**index_bitwidth. LUT compression + collects these values into a lookup table, and rewrites the tensor as + bitwidth-wide integer indices into that lookup table. - if q.axis is not None and q.axis < len(tensor.shape): - if quantization_channels == tensor.shape[q.axis]: - return q.axis + huffman: Huffman compression using Xtensa-format decode tables. (Not yet + implemented.) - raise CompressionError( - f"Invalid or no quanitzation parameters from which to " - f"infer the axis along which tensor should be compressed.") - - -def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): - """Applies business logic regarding specified bitwidth. - - It is an error if the bitwidth required to compress a tensor exceeds the - specified bitwith, and a warning if the tensor can be compressed in less than - the specified bitwidth. The latter is allowed, and is not an error, to permit - testing with larger bitwidths without re-binning a model. - """ - if compressed > specified: - raise CompressionError( - f"index_bitwidth too small: {compressed} bits needed to " - f"enumerate unique values in tensor specified in {spec}") - elif compressed < specified: - print( - f"warning: index_bitwidth too large: only {compressed} " - f"bits needed to enumerate unique values in tensor specified in {spec}", - file=sys.stderr) - - -def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: - """Packs indices into a bytearray using bitwidth-sized fields. - """ - endianness = "big" - bits = bitarray.bitarray(endian=endianness) - for i in indices.ravel(): - bits.extend( - bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) - return bits.tobytes() + pruning: Pruning (sparsity) compression for sparse tensors. (Not yet + implemented.) +Compressed models use DECODE operators to decompress tensors at runtime. +""" -def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: - """Packs the value tables of a LutCompressedArray. +# Plugin dispatch table: maps CompressionMethod subclasses to compressor instances +_COMPRESSORS: dict[Type[spec.CompressionMethod], compressor.Compressor] = { + spec.LookUpTableCompression: lut.LutCompressor(), + spec.HuffmanCompression: huffman.HuffmanCompressor(), + spec.PruningCompression: pruning.PruningCompressor(), +} - Pack the value tables of a LutCompressedArray into a bytes object in the - format writable to a value_table buffer in the .tflite flatbuffer. The - tables are concatenated. - """ - buffer = bytearray() - for t in tables: - padding_needed = table_len - len(t) - padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) - buffer.extend(padded.tobytes()) - return buffer +def _get_compressor(method: spec.CompressionMethod) -> compressor.Compressor: + """Get the compressor plugin for a given compression method.""" + compressor_instance = _COMPRESSORS.get(type(method)) + if compressor_instance is None: + raise compressor.CompressionError( + f"No compressor registered for {type(method).__name__}") + return compressor_instance def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: """Applies proper FlatBuffer alignment to a model. - + The Python flatbuffers library doesn't respect `force_align` schema attributes, so we use the C++ wrapper which properly handles alignment requirements. - + Args: model_bytes: The model flatbuffer to align - + Returns: The properly aligned model flatbuffer """ @@ -295,45 +122,58 @@ def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: """Compresses a model .tflite flatbuffer. + Compresses tensors according to the given specs and inserts DECODE operators + to decompress them at runtime. + Args: model_in: the original, uncompressed .tflite flatbuffer specs: an iterable of compression specs, see module spec.py Returns: - A compressed flatbuffer. + A compressed flatbuffer with DECODE operators inserted. """ model = model_editor.read(model_in) - metadata = _MetadataBuilder() + compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} - for spec in specs: + for tensor_spec in specs: try: - tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] - lut_compression = _check_lut_compression(spec.compression) - spec_bitwidth = lut_compression.index_bitwidth - axis = _identify_compression_axis(tensor) - compressed = _lut_compress_array(tensor.array, axis) - _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - - # overwrite tensor data with indices - tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) - - # write value buffer - value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, - 2**spec_bitwidth) - value_buffer = model_editor.Buffer(data=value_buffer_data) - model.buffers.append(value_buffer) # Auto-sets value_buffer.index - - # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) - lut_tensor.tensor = spec.tensor - lut_tensor.valueBuffer = value_buffer.index - lut_tensor.indexBitwidth = spec_bitwidth - + tensor = model.subgraphs[tensor_spec.subgraph].tensors[ + tensor_spec.tensor] + + # Currently only one compression method per tensor + if len(tensor_spec.compression) != 1: + raise compressor.CompressionError( + "Each tensor must have exactly one compression method") + + method = tensor_spec.compression[0] + plugin = _get_compressor(method) + original_size = len(tensor.buffer.data) if tensor.buffer.data else 0 + result = plugin.compress(tensor, method) + + compressed_size = len(result.encoded_data) + len(result.ancillary_data) + if compressed_size > original_size: + warnings.warn( + f"Compression of tensor {tensor.name!r} (subgraph " + f"{tensor_spec.subgraph}, tensor {tensor_spec.tensor}) resulted " + f"in expansion: {original_size} bytes -> {compressed_size} bytes " + f"(encoded: {len(result.encoded_data)}, " + f"ancillary: {len(result.ancillary_data)})", + stacklevel=2) + + # Replace tensor data with encoded data + tensor.buffer.data = result.encoded_data + + # Store result for DECODE insertion + compression_results[(tensor_spec.subgraph, tensor_spec.tensor)] = result + + except compressor.CompressionError: + raise except Exception as e: - raise CompressionError(f"error compressing {spec}") from e + raise compressor.CompressionError( + f"error compressing {tensor_spec}") from e - # add compression metadata to model - model.metadata[TFLITE_METADATA_KEY] = metadata.compile() + # Insert DECODE operators into the graph + decode_insert.insert_decode_operators(model, compression_results) # Build the model and apply proper alignment unaligned_model = model.build() diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 81bbdab3293..cb241c2c62f 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -11,164 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Integration tests for the compression system.""" + +import warnings -import bitarray -import bitarray.util import numpy as np import unittest from tflite_micro.tensorflow.lite.micro.compression import compress -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class TestPackIndices(unittest.TestCase): - - def test_basic_case(self): - indices = np.array([1, 2, 3]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0000]) - self.assertEqual(result, expected_bytes) - - def test_single_element(self): - indices = np.array([10]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_1010]) - self.assertEqual(result, expected_bytes) - - def test_different_bitwidth(self): - indices = np.array([1, 2, 3]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) - self.assertEqual(result, expected_bytes) - - def test_large_numbers(self): - indices = np.array([255, 128, 64]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) - self.assertEqual(result, expected_bytes) - - def test_multidimensional_array(self): - indices = np.array([[1, 2], [3, 4]]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0100]) - self.assertEqual(result, expected_bytes) - - def test_zero_bitwidth(self): - indices = np.array([0, 1, 2]) - bitwidth = 0 - with self.assertRaises(ValueError): - compress._pack_indices(indices, bitwidth) - - def test_empty_array(self): - indices = np.array([]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = b"" - self.assertEqual(result, expected_bytes) - - def test_bitwidth_1(self): - indices = np.array([1, 0, 1, 1, 0, 1]) - bitwidth = 1 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b101101_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_2(self): - indices = np.array([1, 2, 3, 0]) - bitwidth = 2 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b01_10_11_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_3(self): - indices = np.array([1, 3, 5, 7]) - bitwidth = 3 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_5(self): - indices = np.array([1, 2, 16, 31]) - bitwidth = 5 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_7(self): - indices = np.array([1, 64, 127, 32]) - bitwidth = 7 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes( - [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) - self.assertEqual(result, expected_bytes) - - -class TestPackLookupTables(unittest.TestCase): - - def test_int16_positive(self): - tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: - """Helper: extracts the compressed tensor parts for a given spec. - - Returns: - bitwidth - indices - values - """ - subgraph_obj = self.compressed.subgraphs[subgraph] - tensor_obj = subgraph_obj.tensors[tensor] - lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors - lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) - bitwidth = lut_tensor.indexBitwidth - - indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") - n_indices = np.prod(tensor_obj.shape) - indices = indices[:n_indices * bitwidth] # trim possible padding - - value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) - - return bitwidth, indices, values - - def _make_indices(self, s: str) -> bitarray.bitarray: - """Helper: makes indices from "01" strings for use as expected values.""" - return bitarray.bitarray(s, endian="big") - - def test_compressed_uint8(self): - bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) - self.assertEqual(bitwidth, 4) - - # yapf: disable - expected_indices = self._make_indices(""" - 0000 0001 0010 0011 - 0100 0101 0110 0111 - 1000 1001 1010 1011 - 1100 1101 1110 1111 - """) - # yapf: enable - self.assertEqual(indices, expected_indices) - - expected_values = np.array(range(16), dtype=" Date: Sun, 24 May 2026 23:33:10 -0500 Subject: [PATCH 13/21] test(compression): add integration tests with TFLM interpreter Add tests that compress models with LUT compression, run them through the TFLM Python interpreter, and verify outputs match uncompressed originals. Cover per-tensor and per-channel quantization, various index bitwidths, unquantized weights, and alternate decompression memory. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 25 + .../compression_integration_test.py | 505 ++++++++++++++++++ 2 files changed, 530 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compression_integration_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 09a77ce407a..c7262c6f366 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -169,6 +169,31 @@ py_test( ], ) +tflm_py_test( + name = "compression_integration_test", + size = "small", + srcs = ["compression_integration_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + # Only run when compression IS enabled + target_compatible_with = select({ + "//:with_compression_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":compress_lib", + ":decode_insert", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/compression_integration_test.py b/tensorflow/lite/micro/compression/compression_integration_test.py new file mode 100644 index 00000000000..0e92a527f6a --- /dev/null +++ b/tensorflow/lite/micro/compression/compression_integration_test.py @@ -0,0 +1,505 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for compression with TFLM interpreter. + +These tests verify that compressed models produce correct inference results +when run through the TFLM Python interpreter. Tests compress models and +compare outputs against uncompressed originals. + +These tests only run when compression is enabled (--//:with_compression). +""" + +import os +import unittest +import numpy as np + +from tflite_micro.python.tflite_micro import runtime +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_compressible_model(weight_shape=(4, 4), + index_bitwidth=2, + per_channel=False, + unquantized=False): + """Build a model with clustered weights for compression testing. + + Args: + weight_shape: Shape of the weight tensor as (rows, cols). + index_bitwidth: Bits per index. Determines unique value count (2^bitwidth). + per_channel: If True, use per-channel quantization (one scale per row). + unquantized: If True, omit quantization from weights. + + Returns: + A TFLite flatbuffer (bytes) containing a simple FULLY_CONNECTED model + with weights that have limited unique values per channel. + """ + rows, cols = weight_shape + unique_count = 2**index_bitwidth + + # Create weights with limited unique values per channel + pattern = np.arange(1, unique_count + 1, dtype=np.int8) + weight_data = np.resize(pattern, (rows, cols)) + + if unquantized: + quantization = None + elif per_channel: + # Per-channel: one scale per output channel (row in FC weights) + scales = [0.5 + 0.1 * i for i in range(rows)] + zero_points = [0] * rows + quantization = model_editor.Quantization( + scales=scales, + zero_points=zero_points, + axis=0, + ) + else: + quantization = model_editor.Quantization(scales=0.5, zero_points=0) + + weights = model_editor.Tensor( + shape=weight_shape, + dtype=tflite.TensorType.INT8, + data=weight_data, + name="weights", + quantization=quantization, + ) + + input_t = model_editor.Tensor( + shape=(1, cols), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, rows), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model.build() + + +class LutCompressionTest(unittest.TestCase): + """Integration tests for LUT (lookup table) compression.""" + + def test_lut_compressed_model_matches_uncompressed(self): + """LUT-compressed model produces same outputs as uncompressed.""" + flatbuffer = _build_compressible_model() + + # Create compression spec for weights tensor (index 0 in tensors list) + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + # Compress + compressed_fb = compress.compress(flatbuffer, specs) + + # Run inference on both (convert bytearray to bytes for interpreter) + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Test with multiple random inputs + np.random.seed(42) + for _ in range(10): + test_input = np.random.randint(-128, 127, (1, 4), dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_decode_operators_present(self): + """DECODE operators are inserted for LUT-compressed tensors.""" + flatbuffer = _build_compressible_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + model = model_editor.read(compressed_fb) + sg = model.subgraphs[0] + + # Find DECODE operators + decode_ops = [ + op for op in sg.operators if op.opcode == tflite.BuiltinOperator.CUSTOM + and op.custom_code == decode_insert.DECODE_CUSTOM_OP_NAME + ] + + self.assertEqual(len(decode_ops), 1) + + def test_lut_compressed_model_is_smaller(self): + """LUT-compressed model is smaller than original. + + Uses a large enough weight tensor (64x64 = 4096 bytes) that compression + savings outweigh the overhead from lookup tables and DECODE operators. + With 2-bit indices, 4096 bytes becomes 1024 bytes of indices. + """ + flatbuffer = _build_compressible_model(weight_shape=(64, 64)) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + original_size = len(flatbuffer) + compressed_size = len(compressed_fb) + + self.assertLess( + compressed_size, original_size, + f"Compressed model ({compressed_size} bytes) should be smaller than " + f"original ({original_size} bytes)") + + def test_lut_4bit_compression(self): + """4-bit LUT compression produces correct inference results.""" + flatbuffer = _build_compressible_model(index_bitwidth=4) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=4)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_per_channel_quantization(self): + """Per-channel quantized weights compress and decompress correctly.""" + flatbuffer = _build_compressible_model(per_channel=True) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_unquantized_weights(self): + """Unquantized weights compress and decompress correctly.""" + flatbuffer = _build_compressible_model(unquantized=True) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + +def _build_shared_weights_model(): + """Build a model where one compressed tensor is shared between two operators. + + Model structure: + input1 -> [FC1 with weights1] -> output1 + input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2 + + weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which + runs between the two consumers of weights1. + """ + # 4 unique values per tensor for 2-bit LUT compression. Small values avoid + # saturation in chained layers. Different row sums produce varied outputs. + weights1_data = np.array([ + [-1, 0, 0, 1], + [-1, 0, 1, 1], + [-1, 1, 1, 1], + [0, 1, 1, 1], + ], + dtype=np.int8) + weights1 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights1_data, + name="weights1", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + weights2_data = np.array([ + [1, 1, 1, 1], + [1, 1, 2, 2], + [1, 2, 2, 3], + [2, 2, 3, 3], + ], + dtype=np.int8) + weights2 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights2_data, + name="weights2", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + # All tensors need matching quantization for FULLY_CONNECTED + quant = model_editor.Quantization(scales=1.0, zero_points=0) + + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + quantization=quant, + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + quantization=quant, + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + quantization=quant, + ) + intermediate = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="intermediate", + quantization=quant, + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + quantization=quant, + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights1, weights2], + inputs=[input1, input2], + outputs=[output1, output2], + operators=[ + # FC1: uses weights1 + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights1], + outputs=[output1], + ), + # FC2: uses weights2 (runs between FC1 and FC3) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights2], + outputs=[intermediate], + ), + # FC3: uses weights1 (second consumer, after DECODE(weights2)) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[intermediate, weights1], + outputs=[output2], + ), + ], + ) + ]) + return model.build() + + +class AltDecompressionMemoryTest(unittest.TestCase): + """Tests for alternate decompression memory with shared compressed tensors. + + These tests verify correct behavior when compressed tensors are shared + between multiple operators and alternate decompression memory is enabled. + """ + + def test_shared_compressed_tensor_with_alt_memory(self): + """Verify correct results when a shared compressed tensor is used with alt + decompression memory. + + This test uses a graph where a compressed tensor (weights1) is consumed by + two operators (FC1 and FC3), with an intervening DECODE of a different + compressed tensor (weights2) between them. + + The interpreter's alternate decompression memory has a limitation: each + DECODE's Prepare resets the allocation offset to zero. This means all + DECODE outputs are allocated at the same address, so they overwrite each + other. A DECODE output can only be used until the next DECODE runs. + + To work around this limitation, the DECODE insertion code inserts a + separate DECODE immediately before each consumer of a compressed tensor, + rather than sharing one DECODE output among all consumers. + """ + flatbuffer = _build_shared_weights_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, # weights1 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + spec.Tensor( + subgraph=0, + tensor=1, # weights2 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + # Run without alt decompression memory (baseline) + interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Run with alt decompression memory + interp_with_alt = runtime.Interpreter.from_bytes( + bytes(compressed_fb), + alt_decompression_memory_size=256, + ) + + test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8) + test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8) + + interp_no_alt.set_input(test_input1, 0) + interp_no_alt.set_input(test_input2, 1) + interp_no_alt.invoke() + expected1 = interp_no_alt.get_output(0) + expected2 = interp_no_alt.get_output(1) + + interp_with_alt.set_input(test_input1, 0) + interp_with_alt.set_input(test_input2, 1) + interp_with_alt.invoke() + actual1 = interp_with_alt.get_output(0) + actual2 = interp_with_alt.get_output(1) + + np.testing.assert_array_equal( + expected1, actual1, "Output 1 mismatch with alt decompression memory") + np.testing.assert_array_equal( + expected2, actual2, "Output 2 mismatch with alt decompression memory") + + +class HuffmanCompressionTest(unittest.TestCase): + """Integration tests for Huffman compression.""" + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_matches_uncompressed(self): + """Huffman-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_decode_operators_present(self): + """DECODE operators are inserted for Huffman-compressed tensors.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_is_smaller(self): + """Huffman-compressed model is smaller than original.""" + pass + + +class PruningCompressionTest(unittest.TestCase): + """Integration tests for pruning compression.""" + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_matches_uncompressed(self): + """Pruning-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_decode_operators_present(self): + """DECODE operators are inserted for pruning-compressed tensors.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_is_smaller(self): + """Pruning-compressed model is smaller than original.""" + pass + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + unittest.main() From 6791fba701851666ef6fa763a5bc29ba1ae3a0a5 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:36:40 -0500 Subject: [PATCH 14/21] test(compression): add proprietary model integration test Add a manual test for verifying compression on proprietary models that can't be checked into the repository. See the module docstring for usage instructions. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 24 ++ .../proprietary_integration_test.py | 211 ++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 tensorflow/lite/micro/compression/proprietary_integration_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index c7262c6f366..8a518e0fc8e 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -194,6 +194,30 @@ tflm_py_test( ], ) +tflm_py_test( + name = "proprietary_integration_test", + size = "small", + srcs = ["proprietary_integration_test.py"], + tags = [ + "manual", + "noasan", + "nomsan", + "noubsan", + ], + target_compatible_with = select({ + "//:with_compression_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":compress_lib", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/proprietary_integration_test.py b/tensorflow/lite/micro/compression/proprietary_integration_test.py new file mode 100644 index 00000000000..684805d0f56 --- /dev/null +++ b/tensorflow/lite/micro/compression/proprietary_integration_test.py @@ -0,0 +1,211 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for compression using proprietary models. + +These tests verify that compressed models produce correct inference results +when run through the TFLM Python interpreter. Tests compress models and +compare outputs against uncompressed originals using random inputs. + +This test is tagged `manual` and requires a path to a directory containing +.tflite model files. + +Usage: + bazel test //tensorflow/lite/micro/compression:proprietary_integration_test \ + --//:with_compression \ + --test_arg=/path/to/models + +Required files: + Each model requires a compression spec file: + model.spec.yaml (replacing .tflite extension) + + See spec.py for the YAML format. Example: + tensors: + - subgraph: 0 + tensor: 2 + compression: + - lut: + index_bitwidth: 4 + +Optional files: + model.config.json (replacing .tflite extension) + Tolerance overrides: {"rtol": 1e-5, "atol": 1e-6} + Default is exact match (rtol=0, atol=0). +""" + +import glob +import json +import os +import sys +import unittest + +import numpy as np + +from tflite_micro.python.tflite_micro import runtime +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _dtype_to_numpy(dtype: tflite.TensorType) -> np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.UINT16: np.uint16, + tflite.TensorType.UINT32: np.uint32, + tflite.TensorType.FLOAT16: np.float16, + tflite.TensorType.FLOAT32: np.float32, + tflite.TensorType.FLOAT64: np.float64, + tflite.TensorType.BOOL: np.bool_, + } + return type_map.get(dtype, np.uint8) + + +class ProprietaryModelTest(unittest.TestCase): + """Integration tests using proprietary models.""" + + # Parsed from command line in main() + models_dir = None + + @classmethod + def setUpClass(cls): + if not cls.models_dir: + raise unittest.SkipTest( + "No models directory provided. " + "Usage: bazel test ... --test_arg=/path/to/models") + + cls.model_paths = sorted( + glob.glob(os.path.join(cls.models_dir, '*.tflite'))) + if not cls.model_paths: + raise unittest.SkipTest(f"No .tflite files found in {cls.models_dir}") + + def test_all_models(self): + """Run compression test on each discovered model.""" + for model_path in self.model_paths: + with self.subTest(model=os.path.basename(model_path)): + self._test_model_compression(model_path) + + def _test_model_compression(self, model_path): + """Test that a compressed model produces same outputs as original.""" + with open(model_path, 'rb') as f: + flatbuffer = f.read() + + # Load compression spec from sidecar file + specs = self._load_compression_spec(model_path) + + # Load tolerance config + rtol, atol = self._load_tolerance(model_path) + + # Compress the model + compressed_fb = compress.compress(flatbuffer, specs) + + # Create interpreters + original_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Generate random inputs and compare outputs + np.random.seed(42) + model = model_editor.read(flatbuffer) + sg = model.subgraphs[0] + + for trial in range(5): + # Set inputs + for i, input_tensor in enumerate(sg.inputs): + test_input = self._generate_input(input_tensor) + original_interp.set_input(test_input, i) + compressed_interp.set_input(test_input, i) + + # Run inference + original_interp.invoke() + compressed_interp.invoke() + + # Compare outputs + for i in range(len(sg.outputs)): + expected = original_interp.get_output(i) + actual = compressed_interp.get_output(i) + self._compare_outputs(expected, actual, rtol, atol, + f"trial {trial}, output {i}") + + def _generate_input(self, tensor): + """Generate random input respecting tensor dtype.""" + shape = tensor.shape + dtype = _dtype_to_numpy(tensor.dtype) + + if np.issubdtype(dtype, np.floating): + return np.random.uniform(-1.0, 1.0, shape).astype(dtype) + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return np.random.randint(info.min, info.max + 1, shape, dtype=dtype) + elif dtype == np.bool_: + return np.random.choice([False, True], shape) + return np.zeros(shape, dtype=dtype) + + def _load_compression_spec(self, model_path): + """Load compression spec from sidecar YAML file. + + Raises: + FileNotFoundError: If no spec file is found. + """ + spec_path = model_path.replace('.tflite', '.spec.yaml') + if os.path.exists(spec_path): + with open(spec_path) as f: + return spec.parse_yaml(f.read()) + + raise FileNotFoundError( + f"No compression spec file found for {model_path}. " + f"Expected: {spec_path}") + + def _load_tolerance(self, model_path): + """Load tolerance from sidecar config if present. + + Returns (0, 0) for exact match if no config file exists. + """ + config_path = model_path.replace('.tflite', '.config.json') + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config.get('rtol', 0), config.get('atol', 0) + return 0, 0 + + def _compare_outputs(self, expected, actual, rtol, atol, context=""): + """Compare outputs with optional tolerance.""" + msg = f"Output mismatch ({context})" if context else "Output mismatch" + if rtol == 0 and atol == 0: + np.testing.assert_array_equal(expected, actual, err_msg=msg) + else: + np.testing.assert_allclose(expected, + actual, + rtol=rtol, + atol=atol, + err_msg=msg) + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + + # Parse models directory from args, then strip it so tf.test doesn't see it + for arg in sys.argv[1:]: + if not arg.startswith('-') and os.path.isdir(arg): + ProprietaryModelTest.models_dir = arg + sys.argv.remove(arg) + break + + unittest.main() From b264245686ed9b10b1e57974544c2e484721bcba Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:38:01 -0500 Subject: [PATCH 15/21] refactor(compression): compressors inherit from Compressor protocol Explicit inheritance from Protocol enables static type checking at definition time and makes the interface self-documenting. BUG=part of #3256 --- tensorflow/lite/micro/compression/huffman.py | 2 +- tensorflow/lite/micro/compression/lut.py | 2 +- tensorflow/lite/micro/compression/pruning.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py index 40d0be9284a..e539827eae4 100644 --- a/tensorflow/lite/micro/compression/huffman.py +++ b/tensorflow/lite/micro/compression/huffman.py @@ -25,7 +25,7 @@ from tflite_micro.tensorflow.lite.micro.compression import spec -class HuffmanCompressor: +class HuffmanCompressor(compressor.Compressor): """Huffman compression plugin (stub). This stub exists to validate the plugin architecture. The actual Huffman diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py index def34059ac5..991288f54cc 100644 --- a/tensorflow/lite/micro/compression/lut.py +++ b/tensorflow/lite/micro/compression/lut.py @@ -241,7 +241,7 @@ def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: return bytes(buffer) -class LutCompressor: +class LutCompressor(compressor.Compressor): """LUT compression plugin implementing the Compressor protocol.""" @property diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py index 2181b73e34a..5c95e3e87e9 100644 --- a/tensorflow/lite/micro/compression/pruning.py +++ b/tensorflow/lite/micro/compression/pruning.py @@ -25,7 +25,7 @@ from tflite_micro.tensorflow.lite.micro.compression import spec -class PruningCompressor: +class PruningCompressor(compressor.Compressor): """Pruning compression plugin (stub). This stub exists to validate the plugin architecture. The actual pruning From 1d74e402854bad577dd00fa237bc4268492af5b9 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:05:51 -0500 Subject: [PATCH 16/21] test(python): rewrite unsupported-compression test for legacy path An upcoming change registers the DECODE operator unconditionally in the Python ops resolver, after which compress() emits DECODE-based models that load successfully. That breaks this test's original approach, which ran a model through compress() and expected the load to fail. Rewrite it to instead inject a raw COMPRESSION_METADATA entry into the flatbuffer via model_editor, directly exercising the HasCompressionMetadata() detection path for legacy-compressed models. Decoupling the test from compress() output lets it verify the legacy-rejection behavior independently of whether the DECODE operator is registered, so it passes both before and after that upcoming change. BUG=part of #3256 --- python/tflite_micro/BUILD | 2 +- .../test_compression_unsupported.py | 96 +++++++++---------- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index b358fd12adc..812cf7092fd 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -125,7 +125,7 @@ py_test( ":runtime", requirement("numpy"), requirement("tensorflow"), - "//tensorflow/lite/micro/compression", + "//tensorflow/lite/micro/compression:model_editor", ], ) diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index 3692dd0a43a..edd47808298 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -12,84 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test compression metadata detection when compression is disabled.""" +"""Test legacy compression metadata detection when compression is disabled.""" import os import numpy as np import tensorflow as tf from tflite_micro.python.tflite_micro import runtime -from tflite_micro.tensorflow.lite.micro import compression +from tflite_micro.tensorflow.lite.micro.compression import model_editor -class CompressionDetectionTest(tf.test.TestCase): - """Test compression metadata detection when compression is disabled.""" +def _create_test_model(): + """Create a simple quantized model for testing.""" + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), + tf.keras.layers.Dense(5, activation='softmax') + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - def _create_test_model(self): - """Create a simple quantized model for testing.""" - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), - tf.keras.layers.Dense(5, activation='softmax') - ]) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] - # Convert to quantized TFLite - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] + def representative_dataset(): + for _ in range(10): + yield [np.random.randn(1, 5).astype(np.float32)] - def representative_dataset(): - for _ in range(10): - yield [np.random.randn(1, 5).astype(np.float32)] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model = converter.convert() + return bytes(tflite_model) if isinstance(tflite_model, + bytearray) else tflite_model - tflite_model = converter.convert() - return bytes(tflite_model) if isinstance(tflite_model, - bytearray) else tflite_model + +def _inject_compression_metadata(model_data): + """Inject raw COMPRESSION_METADATA into a model's flatbuffer metadata. + + This simulates a legacy-compressed model (one that uses the + COMPRESSION_METADATA metadata entry and kernel-level decompression) without + going through compress(), which now produces DECODE-based output. + """ + model = model_editor.read(model_data) + model.metadata["COMPRESSION_METADATA"] = b"\x00" + return bytes(model.build()) + + +class LegacyCompressionDetectionTest(tf.test.TestCase): + """Test that legacy COMPRESSION_METADATA is rejected without the flag.""" def test_regular_model_loads_successfully(self): """Non-compressed models should load without issues.""" - model_data = self._create_test_model() + model_data = _create_test_model() interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) - def test_compressed_model_raises_runtime_error(self): - """Compressed models should raise RuntimeError when compression is disabled.""" - # Create and compress a model - model_data = self._create_test_model() + def test_legacy_compressed_model_raises_runtime_error(self): + """Models with COMPRESSION_METADATA should raise RuntimeError.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - - compressed_model = compression.compress(model_data, spec) - if isinstance(compressed_model, bytearray): - compressed_model = bytes(compressed_model) - - # Should raise RuntimeError with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(compressed_model) - - def test_can_load_regular_after_compressed_failure(self): - """Verify we can still load regular models after compressed model fails.""" - model_data = self._create_test_model() + runtime.Interpreter.from_bytes(legacy_model) - # First try compressed model (should fail) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - compressed_model = compression.compress(model_data, spec) + def test_can_load_regular_after_legacy_failure(self): + """Verify regular models still load after a legacy-compressed failure.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(bytes(compressed_model)) + runtime.Interpreter.from_bytes(legacy_model) - # Then load regular model (should succeed) interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) if __name__ == '__main__': - # Set TF environment variables to suppress warnings os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main() From 594e149d85402229372f29514ebac9da97ed21e4 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:06:40 -0500 Subject: [PATCH 17/21] feat(python): register DECODE op unconditionally The DECODE kernel and its dependencies are already compiled unconditionally -- none are guarded by USE_TFLM_COMPRESSION. Remove the #ifdef around AddDecode() in PythonOpsResolver so DECODE-based compressed models work in a default Python build. Remove the with_compression_enabled gating from compression and proprietary integration tests, since they use DECODE-based models that no longer require the flag. BUG=part of #3256 --- python/tflite_micro/python_ops_resolver.cc | 2 -- tensorflow/lite/micro/compression/BUILD | 11 ++--------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 19f324bdf2f..5f7d40fb74e 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,9 +40,7 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); -#ifdef USE_TFLM_COMPRESSION AddDecode(); -#endif AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 8a518e0fc8e..f85977a0d70 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -178,11 +178,7 @@ tflm_py_test( "nomsan", "noubsan", ], - # Only run when compression IS enabled - target_compatible_with = select({ - "//:with_compression_enabled": [], - "//conditions:default": ["@platforms//:incompatible"], - }), + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress_lib", ":decode_insert", @@ -204,10 +200,7 @@ tflm_py_test( "nomsan", "noubsan", ], - target_compatible_with = select({ - "//:with_compression_enabled": [], - "//conditions:default": ["@platforms//:incompatible"], - }), + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress_lib", ":model_editor", From 79fdf0066fc151d84c297d9acadd02a89455efbc Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:09:21 -0500 Subject: [PATCH 18/21] test(compression): add tests for batched DECODE insertion Add test_multiple_compressed_inputs_batched: a CONCATENATION with two compressed tensor inputs, each with a different bitwidth, should produce a single DECODE with 4 inputs and 2 outputs, each ancillary tensor carrying its own distinct data. Marked expectedFailure until the implementation lands. Add test_mixed_compressed_and_uncompressed_inputs: a CONCATENATION with one compressed and one plain input leaves the plain input untouched. This already passes with the current code. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 1 + .../micro/compression/decode_insert_test.py | 153 +++++++++++++++++- 2 files changed, 149 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index f85977a0d70..ef1880f5abd 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -371,6 +371,7 @@ tflm_py_test( ":compressor", ":decode", ":decode_insert", + ":lut", ":model_editor", "//tensorflow/lite/python:schema_py", requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py index 11be81963d9..a7e1fb25e8d 100644 --- a/tensorflow/lite/micro/compression/decode_insert_test.py +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -13,14 +13,15 @@ # limitations under the License. """Unit tests for DECODE operator insertion.""" +import unittest import warnings import numpy as np -import unittest from tflite_micro.tensorflow.lite.micro.compression import compressor from tflite_micro.tensorflow.lite.micro.compression import decode from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -115,14 +116,22 @@ def _build_shared_weights_model(): return model -def _make_dummy_ancillary_data() -> bytes: +def _make_dummy_ancillary_data(bitwidth=4) -> bytes: """Create dummy ancillary data for testing.""" + n_entries = 1 << bitwidth + value_tables = bytes(range(1, n_entries + 1)) + value_tables += b'\x00' * ((-len(value_tables)) % 16) + + lut_data = lut.LutAncillaryData( + bitwidth=bitwidth, + value_table_stride=n_entries, + value_tables=value_tables, + ) dcm = decode.DecodeCommonMetadata( decode_type=decode.DecodeType.LUT, - user_data=b'\x01\x04\x10' + b'\x00' * 9, # lut_version, bitwidth, stride + user_data=lut_data.to_user_data(), ) - value_tables = bytes([1, 2, 3, 4] + [0] * 12) # 16-byte padded table - return dcm.to_bytes() + value_tables + return dcm.to_bytes() + lut_data.to_bytes() class TestDecodeInsertion(unittest.TestCase): @@ -376,6 +385,140 @@ def test_tensor_naming(self): self.assertEqual(ancillary.name, "weights_ancillary") self.assertEqual(output.name, "weights_decoded") + @unittest.expectedFailure + def test_multiple_compressed_inputs_batched(self): + """CONCATENATION with two compressed inputs gets one batched DECODE.""" + weights_a = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_a", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + weights_b = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_b", + quantization=model_editor.Quantization(scales=0.25, zero_points=0), + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights_a, weights_b], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights_a, weights_b], + operators=[concat_op], + ) + ]) + + ancillary_a = _make_dummy_ancillary_data(bitwidth=2) + ancillary_b = _make_dummy_ancillary_data(bitwidth=4) + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=ancillary_a, + ), + (0, 1): + compressor.CompressionResult( + encoded_data=b'\x02\x03', + ancillary_data=ancillary_b, + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + self.assertEqual(decode_op.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(decode_op.custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + + # DECODE has 4 inputs (enc_a, anc_a, enc_b, anc_b) and 2 outputs + self.assertEqual(len(decode_op.inputs), 4) + self.assertEqual(len(decode_op.outputs), 2) + + # Each ancillary tensor carries its own distinct data + self.assertNotEqual(ancillary_a, ancillary_b) + self.assertEqual(bytes(decode_op.inputs[1].array), ancillary_a) + self.assertEqual(bytes(decode_op.inputs[3].array), ancillary_b) + + # CONCATENATION rewired to DECODE outputs + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], decode_op.outputs[1]) + + def test_mixed_compressed_and_uncompressed_inputs(self): + """CONCATENATION with one compressed and one plain input.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + plain = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.zeros((4, 4), dtype=np.int8), + name="plain", + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights, plain], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights, plain], + operators=[concat_op], + ) + ]) + + # Only compress weights, not plain + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=_make_dummy_ancillary_data(), + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + + # DECODE has 2 inputs and 1 output (only the compressed tensor) + self.assertEqual(len(decode_op.inputs), 2) + self.assertEqual(len(decode_op.outputs), 1) + + # CONCATENATION: first input rewired to DECODE output, second unchanged + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], plain) + def test_encoded_tensor_rewritten(self): """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" model = _build_simple_fc_model() From c147e7c5c770a64214203a3fa5b35a9a5dbdd2cb Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:39:13 -0500 Subject: [PATCH 19/21] feat(compression): batch multiple compressed tensors per DECODE When a single operator (e.g., CONCATENATION) has multiple compressed tensor inputs, group them into one DECODE instead of creating a separate DECODE for each. Grouping is per-consumer, so a tensor shared across different consumers still gets a separate DECODE before each one to avoid clobbering the alternate decompression memory. BUG=part of #3256 --- .../lite/micro/compression/decode_insert.py | 72 +++++++++++-------- .../micro/compression/decode_insert_test.py | 1 - 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py index 43dffce46f0..fa91896e538 100644 --- a/tensorflow/lite/micro/compression/decode_insert.py +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -210,15 +210,20 @@ def insert_decode_operators( for sg_idx, tensor_infos in by_subgraph.items(): subgraph = model.subgraphs[sg_idx] - # Collect all (consumer, tensor_info) pairs and sort by consumer position - # in reverse order so insertions don't invalidate positions - consumer_pairs = [] + # Group tensor infos by consumer so multiple compressed inputs to the + # same operator get batched into a single DECODE. + consumer_to_infos: dict[model_editor.Operator, list[_CompressedTensorInfo]] + consumer_to_infos = defaultdict(list) for info in tensor_infos: for consumer in info.consumers: - consumer_pairs.append((consumer, info)) - - consumer_pairs.sort( - key=lambda pair: subgraph.operators.index(pair[0]), + if info not in consumer_to_infos[consumer]: + consumer_to_infos[consumer].append(info) + + # Sort consumers by position in reverse so insertions don't invalidate + # earlier positions. + sorted_consumers = sorted( + consumer_to_infos.keys(), + key=lambda op: subgraph.operators.index(op), reverse=True, ) @@ -231,38 +236,45 @@ def insert_decode_operators( # _create_output_tensor reads the original tensor's shape/dtype/quantization. tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} - for consumer, info in consumer_pairs: - # Reuse or create ancillary data tensor - if info.tensor not in ancillary_cache: - ancillary_tensor = _create_ancillary_tensor( - info.ancillary_data, - info.tensor, - ) - subgraph.tensors.append(ancillary_tensor) - ancillary_cache[info.tensor] = ancillary_tensor - tensors_to_rewrite[info.tensor] = info.encoded_data - else: - ancillary_tensor = ancillary_cache[info.tensor] - - # Create output tensor (one per DECODE) - output_tensor = _create_output_tensor(info.tensor) - subgraph.tensors.append(output_tensor) - - # Create DECODE operator + for consumer in sorted_consumers: + decode_inputs = [] + decode_outputs = [] + + for info in consumer_to_infos[consumer]: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per compressed input) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + decode_inputs.extend([info.tensor, ancillary_tensor]) + decode_outputs.append(output_tensor) + + # Rewire this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Create single DECODE operator for all compressed inputs decode_op = model_editor.Operator( opcode=tflite.BuiltinOperator.CUSTOM, custom_code=DECODE_CUSTOM_OP_NAME, - inputs=[info.tensor, ancillary_tensor], - outputs=[output_tensor], + inputs=decode_inputs, + outputs=decode_outputs, ) # Insert DECODE immediately before this consumer insert_pos = subgraph.operators.index(consumer) subgraph.operators.insert(insert_pos, decode_op) - # Rewire only this consumer to use the decoded output - _rewire_consumers([consumer], info.tensor, output_tensor) - # Rewrite encoded tensors after all output tensors are created for tensor, encoded_data in tensors_to_rewrite.items(): _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py index a7e1fb25e8d..60965b46676 100644 --- a/tensorflow/lite/micro/compression/decode_insert_test.py +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -385,7 +385,6 @@ def test_tensor_naming(self): self.assertEqual(ancillary.name, "weights_ancillary") self.assertEqual(output.name, "weights_decoded") - @unittest.expectedFailure def test_multiple_compressed_inputs_batched(self): """CONCATENATION with two compressed inputs gets one batched DECODE.""" weights_a = model_editor.Tensor( From 92300fafef38e82eb906be27ea679e38e1df8c87 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:40:14 -0500 Subject: [PATCH 20/21] feat(compression): reject empty compression spec An empty spec list passed to compress() previously returned an unmodified model silently. Fail early with a clear error instead, since an empty spec is almost certainly a mistake. BUG=part of #3256 --- tensorflow/lite/micro/compression/compress.py | 5 +++++ tensorflow/lite/micro/compression/compress_test.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 270951fecf8..96b55d94fd7 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -132,6 +132,11 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: Returns: A compressed flatbuffer with DECODE operators inserted. """ + specs = list(specs) + if not specs: + raise compressor.CompressionError( + "Compression spec is empty; no tensors to compress") + model = model_editor.read(model_in) compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index cb241c2c62f..6ee80f200d5 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -313,6 +313,11 @@ def test_ancillary_data_format(self): self.assertEqual(dcm_bytes[5] & 0x07, 4) # bitwidth = 4 self.assertEqual(dcm_bytes[6], 4) # stride = num unique values + def test_empty_spec_raises(self): + """Empty compression spec is an error, not a silent no-op.""" + self.assertRaisesRegex(compressor.CompressionError, "empty", + lambda: compress.compress(self.flatbuffer, [])) + def test_smaller_bitwidth_raises(self): """Specifying LUT compression with too small a bitwidth fails.""" specs = [ From 8001af06ea770ef4f25c352d70e4708b9cdb5f6c Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:41:20 -0500 Subject: [PATCH 21/21] docs(python): explain env vars in test runner --- python/tflite_micro/test_compression_unsupported.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index edd47808298..01c598374ce 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -88,6 +88,8 @@ def test_can_load_regular_after_legacy_failure(self): if __name__ == '__main__': + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + # Disable oneDNN to avoid non-deterministic floating point results os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main()