Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4e5bf4d
feat(compression): implement model_editor for TFLite model manipulation
rkuester Feb 2, 2026
1581db6
refactor(compression): migrate compress.py from model_facade to model…
rkuester Feb 2, 2026
1367adf
chore(compression): remove model_facade.py
rkuester Feb 2, 2026
890c89c
refactor(compression): replace test_models with model_editor in compr…
rkuester Feb 2, 2026
97b89d8
chore(compression): remove test_models.py
rkuester Feb 2, 2026
b06d064
feat(compression): add DECODE operator types and metadata
rkuester Feb 2, 2026
4786bb4
feat(compression): add Compressor protocol
rkuester Feb 2, 2026
9d2c66e
feat(compression): add LUT compression plugin
rkuester Feb 2, 2026
0286a21
feat(compression): add Huffman and Pruning compression support
rkuester Feb 2, 2026
279067d
feat(python): add alt decompression memory parameter to interpreter
rkuester Feb 2, 2026
eb54f0e
feat(compression): add DECODE operator insertion
rkuester Feb 2, 2026
1182972
refactor(compression): use plugin architecture in compress.py
rkuester Feb 2, 2026
049a5ee
test(compression): add integration tests with TFLM interpreter
rkuester Feb 2, 2026
1e5651f
test(compression): add proprietary model integration test
rkuester Feb 2, 2026
66fae7c
refactor(compression): compressors inherit from Compressor protocol
rkuester Feb 3, 2026
d2ac3ce
feat(python): register DECODE op unconditionally
rkuester Feb 10, 2026
84df17e
test(python): rewrite unsupported-compression test for legacy path
rkuester Feb 10, 2026
6005ce4
test(compression): add tests for batched DECODE insertion
rkuester Feb 10, 2026
5bfdd9c
feat(compression): batch multiple compressed tensors per DECODE
rkuester Feb 10, 2026
aa5679d
feat(compression): reject empty compression spec
rkuester Feb 10, 2026
d40a84e
docs(python): explain env vars in test runner
rkuester Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ py_test(
":runtime",
requirement("numpy"),
requirement("tensorflow"),
"//tensorflow/lite/micro/compression",
"//tensorflow/lite/micro/compression:model_editor",
],
)

Expand Down
9 changes: 5 additions & 4 deletions python/tflite_micro/_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) {
.def(py::init([](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
tflite::InterpreterConfig config) {
return std::unique_ptr<InterpreterWrapper>(
new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size,
num_resource_variables, config));
tflite::InterpreterConfig config,
size_t alt_decompression_memory_size) {
return std::unique_ptr<InterpreterWrapper>(new InterpreterWrapper(
data.ptr(), registerers_by_name, arena_size, num_resource_variables,
config, alt_decompression_memory_size));
}))
.def("PrintAllocations", &InterpreterWrapper::PrintAllocations)
.def("Invoke", &InterpreterWrapper::Invoke)
Expand Down
18 changes: 16 additions & 2 deletions python/tflite_micro/interpreter_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() {

InterpreterWrapper::InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables, InterpreterConfig config) {
size_t arena_size, int num_resource_variables, InterpreterConfig config,
size_t alt_decompression_memory_size)
: memory_arena_(new uint8_t[arena_size]),
alt_decompression_memory_(alt_decompression_memory_size > 0
? new uint8_t[alt_decompression_memory_size]
: nullptr),
alt_decompression_region_{alt_decompression_memory_.get(),
alt_decompression_memory_size} {
interpreter_ = nullptr;

// `model_data` is used as a raw pointer beyond the scope of this
Expand Down Expand Up @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper(
"--//:with_compression=true to enable compression support.");
}

memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
for (const std::string& registerer : registerers_by_name) {
if (!AddCustomOpRegistererByName(registerer.c_str(),
&python_ops_resolver_)) {
Expand Down Expand Up @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper(
interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_,
resource_variables_);

if (alt_decompression_memory_size > 0) {
TfLiteStatus status =
interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1);
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to set decompression memory");
}
}

TfLiteStatus status = interpreter_->AllocateTensors();
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to allocate tensors");
Expand Down
6 changes: 5 additions & 1 deletion python/tflite_micro/interpreter_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "python/tflite_micro/python_ops_resolver.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/recording_micro_allocator.h"

Expand All @@ -40,7 +41,8 @@ class InterpreterWrapper {
InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
InterpreterConfig config = InterpreterConfig::kAllocationRecording);
InterpreterConfig config = InterpreterConfig::kAllocationRecording,
size_t alt_decompression_memory_size = 0);
~InterpreterWrapper();

void PrintAllocations();
Expand All @@ -57,6 +59,8 @@ class InterpreterWrapper {
tflite::RecordingMicroAllocator* recording_allocator_ = nullptr;
const PyObject* model_;
std::unique_ptr<uint8_t[]> memory_arena_;
std::unique_ptr<uint8_t[]> alt_decompression_memory_;
tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_;
tflite::PythonOpsResolver python_ops_resolver_;
tflite::MicroInterpreter* interpreter_;
};
Expand Down
12 changes: 12 additions & 0 deletions python/tflite_micro/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
custom_op_registerers,
arena_size,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
if model_data is None:
raise ValueError("Model must not be None")
Expand All @@ -94,6 +95,7 @@ def __init__(
arena_size,
num_resource_variables,
_ENUM_TRANSLATOR[intrepreter_config],
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -103,6 +105,7 @@ def from_file(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model .tflite filepath.

Expand All @@ -112,6 +115,9 @@ def from_file(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -127,6 +133,7 @@ def from_file(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -136,6 +143,7 @@ def from_bytes(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model in byte array.

Expand All @@ -145,6 +153,9 @@ def from_bytes(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -155,6 +166,7 @@ def from_bytes(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

def print_allocations(self):
Expand Down
98 changes: 49 additions & 49 deletions python/tflite_micro/test_compression_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,84 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test compression metadata detection when compression is disabled."""
"""Test legacy compression metadata detection when compression is disabled."""

import os
import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro import runtime
from tflite_micro.tensorflow.lite.micro import compression
from tflite_micro.tensorflow.lite.micro.compression import model_editor


class CompressionDetectionTest(tf.test.TestCase):
"""Test compression metadata detection when compression is disabled."""
def _create_test_model():
"""Create a simple quantized model for testing."""
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(5,), activation='relu'),
tf.keras.layers.Dense(5, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

def _create_test_model(self):
"""Create a simple quantized model for testing."""
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'),
tf.keras.layers.Dense(5, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Convert to quantized TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset():
for _ in range(10):
yield [np.random.randn(1, 5).astype(np.float32)]

def representative_dataset():
for _ in range(10):
yield [np.random.randn(1, 5).astype(np.float32)]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
return bytes(tflite_model) if isinstance(tflite_model,
bytearray) else tflite_model

tflite_model = converter.convert()
return bytes(tflite_model) if isinstance(tflite_model,
bytearray) else tflite_model

def _inject_compression_metadata(model_data):
  """Tag a model with a raw COMPRESSION_METADATA flatbuffer metadata entry.

  Produces the shape of a legacy-compressed model (i.e., one relying on the
  COMPRESSION_METADATA metadata entry and kernel-level decompression) without
  invoking compress(), whose output is now DECODE-based.
  """
  edited = model_editor.read(model_data)
  edited.metadata["COMPRESSION_METADATA"] = b"\x00"
  return bytes(edited.build())


class LegacyCompressionDetectionTest(tf.test.TestCase):
"""Test that legacy COMPRESSION_METADATA is rejected without the flag."""

def test_regular_model_loads_successfully(self):
"""Non-compressed models should load without issues."""
model_data = self._create_test_model()
model_data = _create_test_model()
interpreter = runtime.Interpreter.from_bytes(model_data)
self.assertIsNotNone(interpreter)

def test_compressed_model_raises_runtime_error(self):
"""Compressed models should raise RuntimeError when compression is disabled."""
# Create and compress a model
model_data = self._create_test_model()
def test_legacy_compressed_model_raises_runtime_error(self):
"""Models with COMPRESSION_METADATA should raise RuntimeError."""
model_data = _create_test_model()
legacy_model = _inject_compression_metadata(model_data)

spec = (compression.SpecBuilder().add_tensor(
subgraph=0, tensor=1).with_lut(index_bitwidth=4).build())

compressed_model = compression.compress(model_data, spec)
if isinstance(compressed_model, bytearray):
compressed_model = bytes(compressed_model)

# Should raise RuntimeError
with self.assertRaises(RuntimeError):
runtime.Interpreter.from_bytes(compressed_model)

def test_can_load_regular_after_compressed_failure(self):
"""Verify we can still load regular models after compressed model fails."""
model_data = self._create_test_model()
runtime.Interpreter.from_bytes(legacy_model)

# First try compressed model (should fail)
spec = (compression.SpecBuilder().add_tensor(
subgraph=0, tensor=1).with_lut(index_bitwidth=4).build())
compressed_model = compression.compress(model_data, spec)
def test_can_load_regular_after_legacy_failure(self):
"""Verify regular models still load after a legacy-compressed failure."""
model_data = _create_test_model()
legacy_model = _inject_compression_metadata(model_data)

with self.assertRaises(RuntimeError):
runtime.Interpreter.from_bytes(bytes(compressed_model))
runtime.Interpreter.from_bytes(legacy_model)

# Then load regular model (should succeed)
interpreter = runtime.Interpreter.from_bytes(model_data)
self.assertIsNotNone(interpreter)


if __name__ == '__main__':
# Set TF environment variables to suppress warnings
# Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Disable oneDNN to avoid non-deterministic floating point results
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
Comment on lines 92 to 94
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add a comment on the meaning of these environment vars

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in d40a84e.

tf.test.main()
Loading