Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ py_library(
],
deps = [
":metadata_py",
":model_facade",
":model_editor",
":spec",
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
requirement("absl_py"),
Expand Down Expand Up @@ -160,7 +160,7 @@ py_test(
deps = [
":compress",
":metadata_py",
":model_facade",
":model_editor",
":spec",
":test_models",
"//tensorflow/lite/python:schema_py",
Expand Down
36 changes: 20 additions & 16 deletions tensorflow/lite/micro/compression/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import flatbuffers
import numpy as np

from tflite_micro.tensorflow.lite.micro.compression import model_facade
from tflite_micro.tensorflow.lite.micro.compression import model_editor
from tflite_micro.tensorflow.lite.micro.compression import spec
from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper
Expand Down Expand Up @@ -177,7 +177,7 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression:
return compression[0]


def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]:
def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]:
"""Determines the axis along which to compress.

The axis along which to compress is inferred from the tensor's quantization
Expand All @@ -191,16 +191,18 @@ def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]:
CompressionError: If the axis cannot be determined.
"""
q = tensor.quantization
if q is not None \
and q.scale is not None \
and q.quantizedDimension < len(tensor.shape):
quantization_channels = len(q.scale)
if q is not None:
# model_editor wraps quantization, access scales/axis from wrapper
scales = q.scales if isinstance(q.scales, list) else [q.scales]
quantization_channels = len(scales)

if quantization_channels == 1:
# Use one value table for the entire tensor
return None

if quantization_channels == tensor.shape[q.quantizedDimension]:
return q.quantizedDimension
if q.axis is not None and q.axis < len(tensor.shape):
if quantization_channels == tensor.shape[q.axis]:
return q.axis

raise CompressionError(
f"Invalid or no quanitzation parameters from which to "
Expand Down Expand Up @@ -300,7 +302,7 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
Returns:
A compressed flatbuffer.
"""
model = model_facade.read(model_in)
model = model_editor.read(model_in)
metadata = _MetadataBuilder()

for spec in specs:
Expand All @@ -316,23 +318,25 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth)

# write value buffer
value_buffer = model.add_buffer()
value_buffer.data = _pack_lookup_tables(compressed.lookup_tables,
value_buffer_data = _pack_lookup_tables(compressed.lookup_tables,
2**spec_bitwidth)
value_buffer = model_editor.Buffer(data=value_buffer_data)
model.buffers.append(value_buffer) # Auto-sets value_buffer.index

# add compression metadata for tensor
lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index)
lut_tensor.tensor = tensor.index
lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph)
lut_tensor.tensor = spec.tensor
lut_tensor.valueBuffer = value_buffer.index
lut_tensor.indexBitwidth = spec_bitwidth

except Exception as e:
raise CompressionError(f"error compressing {spec}") from e

# add compression metadata to model
model.add_metadata(TFLITE_METADATA_KEY, metadata.compile())
model.metadata[TFLITE_METADATA_KEY] = metadata.compile()

# Compile the model and apply proper alignment
unaligned_model = model.compile()
# Build the model and apply proper alignment
unaligned_model = model.build()
return _apply_flatbuffer_alignment(unaligned_model)


Expand Down
19 changes: 10 additions & 9 deletions tensorflow/lite/micro/compression/compress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from tflite_micro.tensorflow.lite.micro.compression import compress
from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
from tflite_micro.tensorflow.lite.micro.compression import model_facade
from tflite_micro.tensorflow.lite.micro.compression import model_editor
from tflite_micro.tensorflow.lite.micro.compression import spec
from tflite_micro.tensorflow.lite.micro.compression import test_models
from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite
Expand Down Expand Up @@ -368,12 +368,12 @@ class TestsCompression(unittest.TestCase):
def setUpClass(cls):
super().setUpClass()
cls.flatbuffer = test_models.build(TEST_MODEL)
cls.uncompressed = model_facade.read(cls.flatbuffer)
cls.uncompressed = model_editor.read(cls.flatbuffer)

def test_compression_metadata(self):
"""The compressed model has compression metadata."""
compressed = compress.compress(self.flatbuffer, TEST_COMPRESSION_SPEC)
model = model_facade.read(compressed)
model = model_editor.read(compressed)
self.assertIn("metadata0", self.uncompressed.metadata)
self.assertIn(compress.TFLITE_METADATA_KEY, model.metadata)

Expand Down Expand Up @@ -461,16 +461,17 @@ def setUpClass(cls):
super().setUpClass()
# Create a model
uncompressed_fb = test_models.build(TEST_MODEL)
cls.uncompressed = model_facade.read(uncompressed_fb)
cls.uncompressed = model_editor.read(uncompressed_fb)

# Compress the model
compressed_fb = compress.compress(uncompressed_fb, TEST_COMPRESSION_SPEC)
cls.compressed = model_facade.read(compressed_fb)
cls.compressed = model_editor.read(compressed_fb)

# Extract the compression metadata
metadata_flatbuffer = cls.compressed.metadata[compress.TFLITE_METADATA_KEY]
cls.metadata = schema.MetadataT.InitFromPackedBuf(metadata_flatbuffer.data,
0)
metadata_flatbuffer_bytes = cls.compressed.metadata[
compress.TFLITE_METADATA_KEY]
cls.metadata = schema.MetadataT.InitFromPackedBuf(
metadata_flatbuffer_bytes, 0)

def test_uncompressed_tensors(self):
"""Tensors not in compression spec are not compressed.
Expand Down Expand Up @@ -515,7 +516,7 @@ def _get_compressed(
indices = indices[:n_indices * bitwidth] # trim possible padding

value_buffer = self.compressed.buffers[lut_tensor.valueBuffer]
values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype)
values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype)

return bitwidth, indices, values

Expand Down