Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f537167
feat(compression): implement model_editor for TFLite model manipulation
rkuester May 25, 2026
8518b5e
refactor(compression): migrate compress.py from model_facade to model…
rkuester May 25, 2026
e1afb8c
chore(compression): remove model_facade.py
rkuester May 25, 2026
f384f98
refactor(compression): replace test_models with model_editor in compr…
rkuester May 25, 2026
2901f55
chore(compression): remove test_models.py
rkuester May 25, 2026
31434d4
feat(compression): add DECODE operator types and metadata
rkuester May 25, 2026
5be61ce
feat(compression): add Compressor protocol
rkuester May 25, 2026
40f28c5
feat(compression): add LUT compression plugin
rkuester May 25, 2026
b842fbb
feat(compression): add Huffman and Pruning compression support
rkuester May 25, 2026
963ce97
feat(python): add alt decompression memory parameter to interpreter
rkuester May 25, 2026
702dfb5
feat(compression): add DECODE operator insertion
rkuester May 25, 2026
b10c147
refactor(compression): use plugin architecture in compress.py
rkuester May 25, 2026
cd4b0b3
test(compression): add integration tests with TFLM interpreter
rkuester May 25, 2026
6791fba
test(compression): add proprietary model integration test
rkuester May 25, 2026
b264245
refactor(compression): compressors inherit from Compressor protocol
rkuester May 25, 2026
1d74e40
test(python): rewrite unsupported-compression test for legacy path
rkuester May 25, 2026
594e149
feat(python): register DECODE op unconditionally
rkuester May 25, 2026
79fdf00
test(compression): add tests for batched DECODE insertion
rkuester May 25, 2026
c147e7c
feat(compression): batch multiple compressed tensors per DECODE
rkuester May 25, 2026
92300fa
feat(compression): reject empty compression spec
rkuester May 25, 2026
8001af0
docs(python): explain env vars in test runner
rkuester May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ py_test(
":runtime",
requirement("numpy"),
requirement("tensorflow"),
"//tensorflow/lite/micro/compression",
"//tensorflow/lite/micro/compression:model_editor",
],
)

Expand Down
9 changes: 5 additions & 4 deletions python/tflite_micro/_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) {
.def(py::init([](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
tflite::InterpreterConfig config) {
return std::unique_ptr<InterpreterWrapper>(
new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size,
num_resource_variables, config));
tflite::InterpreterConfig config,
size_t alt_decompression_memory_size) {
return std::unique_ptr<InterpreterWrapper>(new InterpreterWrapper(
data.ptr(), registerers_by_name, arena_size, num_resource_variables,
config, alt_decompression_memory_size));
}))
.def("PrintAllocations", &InterpreterWrapper::PrintAllocations)
.def("Invoke", &InterpreterWrapper::Invoke)
Expand Down
18 changes: 16 additions & 2 deletions python/tflite_micro/interpreter_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() {

InterpreterWrapper::InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables, InterpreterConfig config) {
size_t arena_size, int num_resource_variables, InterpreterConfig config,
size_t alt_decompression_memory_size)
: memory_arena_(new uint8_t[arena_size]),
alt_decompression_memory_(alt_decompression_memory_size > 0
? new uint8_t[alt_decompression_memory_size]
: nullptr),
alt_decompression_region_{alt_decompression_memory_.get(),
alt_decompression_memory_size} {
interpreter_ = nullptr;

// `model_data` is used as a raw pointer beyond the scope of this
Expand Down Expand Up @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper(
"--//:with_compression=true to enable compression support.");
}

memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
for (const std::string& registerer : registerers_by_name) {
if (!AddCustomOpRegistererByName(registerer.c_str(),
&python_ops_resolver_)) {
Expand Down Expand Up @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper(
interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_,
resource_variables_);

if (alt_decompression_memory_size > 0) {
TfLiteStatus status =
interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1);
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to set decompression memory");
}
}

TfLiteStatus status = interpreter_->AllocateTensors();
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to allocate tensors");
Expand Down
6 changes: 5 additions & 1 deletion python/tflite_micro/interpreter_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "python/tflite_micro/python_ops_resolver.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/recording_micro_allocator.h"

Expand All @@ -40,7 +41,8 @@ class InterpreterWrapper {
InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
InterpreterConfig config = InterpreterConfig::kAllocationRecording);
InterpreterConfig config = InterpreterConfig::kAllocationRecording,
size_t alt_decompression_memory_size = 0);
~InterpreterWrapper();

void PrintAllocations();
Expand All @@ -57,6 +59,8 @@ class InterpreterWrapper {
tflite::RecordingMicroAllocator* recording_allocator_ = nullptr;
const PyObject* model_;
std::unique_ptr<uint8_t[]> memory_arena_;
std::unique_ptr<uint8_t[]> alt_decompression_memory_;
tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_;
tflite::PythonOpsResolver python_ops_resolver_;
tflite::MicroInterpreter* interpreter_;
};
Expand Down
12 changes: 12 additions & 0 deletions python/tflite_micro/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
custom_op_registerers,
arena_size,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
if model_data is None:
raise ValueError("Model must not be None")
Expand All @@ -122,6 +123,7 @@ def __init__(
arena_size,
num_resource_variables,
_ENUM_TRANSLATOR[intrepreter_config],
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -131,6 +133,7 @@ def from_file(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model .tflite filepath.

Expand All @@ -140,6 +143,9 @@ def from_file(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -155,6 +161,7 @@ def from_file(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -164,6 +171,7 @@ def from_bytes(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model in byte array.

Expand All @@ -173,6 +181,9 @@ def from_bytes(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -183,6 +194,7 @@ def from_bytes(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

def print_allocations(self):
Expand Down
98 changes: 49 additions & 49 deletions python/tflite_micro/test_compression_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,84 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test compression metadata detection when compression is disabled."""
"""Test legacy compression metadata detection when compression is disabled."""

import os
import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro import runtime
from tflite_micro.tensorflow.lite.micro import compression
from tflite_micro.tensorflow.lite.micro.compression import model_editor


class CompressionDetectionTest(tf.test.TestCase):
"""Test compression metadata detection when compression is disabled."""
def _create_test_model():
"""Create a simple quantized model for testing."""
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'),
tf.keras.layers.Dense(5, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

def _create_test_model(self):
"""Create a simple quantized model for testing."""
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'),
tf.keras.layers.Dense(5, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Convert to quantized TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset():
for _ in range(10):
yield [np.random.randn(1, 5).astype(np.float32)]

def representative_dataset():
for _ in range(10):
yield [np.random.randn(1, 5).astype(np.float32)]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
return bytes(tflite_model) if isinstance(tflite_model,
bytearray) else tflite_model

tflite_model = converter.convert()
return bytes(tflite_model) if isinstance(tflite_model,
bytearray) else tflite_model

def _inject_compression_metadata(model_data):
"""Inject raw COMPRESSION_METADATA into a model's flatbuffer metadata.

This simulates a legacy-compressed model (one that uses the
COMPRESSION_METADATA metadata entry and kernel-level decompression) without
going through compress(), which now produces DECODE-based output.
"""
model = model_editor.read(model_data)
model.metadata["COMPRESSION_METADATA"] = b"\x00"
return bytes(model.build())


class LegacyCompressionDetectionTest(tf.test.TestCase):
"""Test that legacy COMPRESSION_METADATA is rejected without the flag."""

def test_regular_model_loads_successfully(self):
"""Non-compressed models should load without issues."""
model_data = self._create_test_model()
model_data = _create_test_model()
interpreter = runtime.Interpreter.from_bytes(model_data)
self.assertIsNotNone(interpreter)

def test_compressed_model_raises_runtime_error(self):
"""Compressed models should raise RuntimeError when compression is disabled."""
# Create and compress a model
model_data = self._create_test_model()
def test_legacy_compressed_model_raises_runtime_error(self):
"""Models with COMPRESSION_METADATA should raise RuntimeError."""
model_data = _create_test_model()
legacy_model = _inject_compression_metadata(model_data)

spec = (compression.SpecBuilder().add_tensor(
subgraph=0, tensor=1).with_lut(index_bitwidth=4).build())

compressed_model = compression.compress(model_data, spec)
if isinstance(compressed_model, bytearray):
compressed_model = bytes(compressed_model)

# Should raise RuntimeError
with self.assertRaises(RuntimeError):
runtime.Interpreter.from_bytes(compressed_model)

def test_can_load_regular_after_compressed_failure(self):
"""Verify we can still load regular models after compressed model fails."""
model_data = self._create_test_model()
runtime.Interpreter.from_bytes(legacy_model)

# First try compressed model (should fail)
spec = (compression.SpecBuilder().add_tensor(
subgraph=0, tensor=1).with_lut(index_bitwidth=4).build())
compressed_model = compression.compress(model_data, spec)
def test_can_load_regular_after_legacy_failure(self):
"""Verify regular models still load after a legacy-compressed failure."""
model_data = _create_test_model()
legacy_model = _inject_compression_metadata(model_data)

with self.assertRaises(RuntimeError):
runtime.Interpreter.from_bytes(bytes(compressed_model))
runtime.Interpreter.from_bytes(legacy_model)

# Then load regular model (should succeed)
interpreter = runtime.Interpreter.from_bytes(model_data)
self.assertIsNotNone(interpreter)


if __name__ == '__main__':
# Set TF environment variables to suppress warnings
# Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Disable oneDNN to avoid non-deterministic floating point results
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
tf.test.main()
Loading