diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 2352b4b508a..36725fac63c 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -124,7 +124,7 @@ py_library( ], deps = [ ":metadata_py", - ":model_facade", + ":model_editor", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), @@ -160,7 +160,7 @@ py_test( deps = [ ":compress", ":metadata_py", - ":model_facade", + ":model_editor", ":spec", ":test_models", "//tensorflow/lite/python:schema_py", diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index bd67bf5637b..b6d5aef4435 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -29,7 +29,7 @@ import flatbuffers import numpy as np -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper @@ -177,7 +177,7 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression: return compression[0] -def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: +def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: """Determines the axis along which to compress. The axis along which to compress is inferred from the tensor's quantization @@ -191,16 +191,18 @@ def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: CompressionError: If the axis cannot be determined. """ q = tensor.quantization - if q is not None \ - and q.scale is not None \ - and q.quantizedDimension < len(tensor.shape): - quantization_channels = len(q.scale) + if q is not None: + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + if quantization_channels == 1: # Use one value table for the entire tensor return None - if quantization_channels == tensor.shape[q.quantizedDimension]: - return q.quantizedDimension + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis raise CompressionError( f"Invalid or no quanitzation parameters from which to " @@ -300,7 +302,7 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: Returns: A compressed flatbuffer. """ - model = model_facade.read(model_in) + model = model_editor.read(model_in) metadata = _MetadataBuilder() for spec in specs: @@ -316,12 +318,14 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) # write value buffer - value_buffer = model.add_buffer() - value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, + value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, 2**spec_bitwidth) + value_buffer = model_editor.Buffer(data=value_buffer_data) + model.buffers.append(value_buffer) # Auto-sets value_buffer.index + # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) - lut_tensor.tensor = tensor.index + lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) + lut_tensor.tensor = spec.tensor lut_tensor.valueBuffer = value_buffer.index lut_tensor.indexBitwidth = spec_bitwidth @@ -329,10 +333,10 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: raise CompressionError(f"error compressing {spec}") from e # add compression metadata to model - model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) + model.metadata[TFLITE_METADATA_KEY] = metadata.compile() - # Compile the model and apply proper alignment - unaligned_model = model.compile() + # Build the model and apply proper alignment + unaligned_model = model.build() return _apply_flatbuffer_alignment(unaligned_model) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 012957acc90..ee10a75f36d 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -19,7 +19,7 @@ from tflite_micro.tensorflow.lite.micro.compression import compress from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -368,12 +368,12 @@ class TestsCompression(unittest.TestCase): def setUpClass(cls): super().setUpClass() cls.flatbuffer = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(cls.flatbuffer) + cls.uncompressed = model_editor.read(cls.flatbuffer) def test_compression_metadata(self): """The compressed model has compression metadata.""" compressed = compress.compress(self.flatbuffer, TEST_COMPRESSION_SPEC) - model = model_facade.read(compressed) + model = model_editor.read(compressed) self.assertIn("metadata0", self.uncompressed.metadata) self.assertIn(compress.TFLITE_METADATA_KEY, model.metadata) @@ -461,16 +461,17 @@ def setUpClass(cls): super().setUpClass() # Create a model uncompressed_fb = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(uncompressed_fb) + cls.uncompressed = model_editor.read(uncompressed_fb) # Compress the model compressed_fb = compress.compress(uncompressed_fb, TEST_COMPRESSION_SPEC) - cls.compressed = model_facade.read(compressed_fb) + cls.compressed = model_editor.read(compressed_fb) # Extract the compression metadata - metadata_flatbuffer = cls.compressed.metadata[compress.TFLITE_METADATA_KEY] - cls.metadata = schema.MetadataT.InitFromPackedBuf(metadata_flatbuffer.data, - 0) + metadata_flatbuffer_bytes = cls.compressed.metadata[ + compress.TFLITE_METADATA_KEY] + cls.metadata = schema.MetadataT.InitFromPackedBuf( + metadata_flatbuffer_bytes, 0) def test_uncompressed_tensors(self): """Tensors not in compression spec are not compressed. @@ -515,7 +516,7 @@ def _get_compressed( indices = indices[:n_indices * bitwidth] # trim possible padding value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype) + values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) return bitwidth, indices, values