Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 24 additions & 34 deletions python/tflite_micro/runtime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
import numpy as np
import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
import unittest
from tflite_micro.python.tflite_micro import runtime
from tflite_micro.tensorflow.lite.micro.examples.recipes import add_four_numbers
from tflite_micro.tensorflow.lite.micro.testing import generate_test_models


class PeserveAllTensorsTest(test_util.TensorFlowTestCase):
class PeserveAllTensorsTest(unittest.TestCase):

def AddFourNumbersTestInterpreterMaker(self, inputs):
"""Returns a tflm interpreter with a simple model that loads 4 numbers loaded
Expand Down Expand Up @@ -95,14 +94,13 @@ def testGetTensorAllUniqueTensors(self):
self.assertEqual(len(set(tensors)), 7)


class ConvModelTests(test_util.TensorFlowTestCase):
class ConvModelTests(unittest.TestCase):
filename = "/tmp/interpreter_test_conv_model.tflite"
input_shape = (1, 16, 16, 1)
output_shape = (1, 10)

def testInitErrorHandling(self):
with self.assertRaisesWithPredicateMatch(ValueError,
"Invalid model file path"):
with self.assertRaisesRegex(ValueError, "Invalid model file path"):
runtime.Interpreter.from_file("wrong.tflite")

def testInput(self):
Expand All @@ -114,7 +112,7 @@ def testInput(self):

# Test input tensor details
input_details = tflm_interpreter.get_input_details(0)
self.assertAllEqual(input_details["shape"], self.input_shape)
np.testing.assert_array_equal(input_details["shape"], self.input_shape)
# Single channel int8 quantization
self.assertEqual(input_details["dtype"], np.int8)
self.assertEqual(len(input_details["quantization_parameters"]["scales"]),
Expand All @@ -134,26 +132,22 @@ def testInputErrorHandling(self):

data_x = np.random.randint(-127, 127, self.input_shape, dtype=np.int8)
# Try to access out of bound data
with self.assertRaisesWithPredicateMatch(IndexError,
"Tensor is out of bound"):
with self.assertRaisesRegex(IndexError, "Tensor is out of bound"):
tflm_interpreter.set_input(data_x, 1)
# Pass data with wrong dimension
with self.assertRaisesWithPredicateMatch(ValueError,
"Dimension mismatch."):
with self.assertRaisesRegex(ValueError, "Dimension mismatch."):
reshaped_data = data_x.reshape((1, 16, 16, 1, 1))
tflm_interpreter.set_input(reshaped_data, 0)
# Pass data with wrong dimension in one axis
with self.assertRaisesWithPredicateMatch(ValueError,
"Dimension mismatch."):
with self.assertRaisesRegex(ValueError, "Dimension mismatch."):
reshaped_data = data_x.reshape((1, 2, 128, 1))
tflm_interpreter.set_input(reshaped_data, 0)
# Pass data with wrong type
with self.assertRaisesWithPredicateMatch(ValueError, "Got value of type"):
with self.assertRaisesRegex(ValueError, "Got value of type"):
float_data = data_x.astype(np.float32)
tflm_interpreter.set_input(float_data, 0)
# Reach wrong details
with self.assertRaisesWithPredicateMatch(IndexError,
"Tensor is out of bound"):
with self.assertRaisesRegex(IndexError, "Tensor is out of bound"):
tflm_interpreter.get_input_details(1)

def testOutput(self):
Expand All @@ -162,7 +156,7 @@ def testOutput(self):

# Test the output tensor details
output_details = tflm_interpreter.get_output_details(0)
self.assertAllEqual(output_details["shape"], self.output_shape)
np.testing.assert_array_equal(output_details["shape"], self.output_shape)
# Single channel int8 quantization
self.assertEqual(output_details["dtype"], np.int8)
self.assertEqual(len(output_details["quantization_parameters"]["scales"]),
Expand All @@ -180,11 +174,9 @@ def testOutputErrorHandling(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
# Try to access out of bound data
with self.assertRaisesWithPredicateMatch(IndexError,
"Tensor is out of bound"):
with self.assertRaisesRegex(IndexError, "Tensor is out of bound"):
tflm_interpreter.get_output(1)
with self.assertRaisesWithPredicateMatch(IndexError,
"Tensor is out of bound"):
with self.assertRaisesRegex(IndexError, "Tensor is out of bound"):
tflm_interpreter.get_output_details(1)

def testCompareWithTFLite(self):
Expand Down Expand Up @@ -219,9 +211,9 @@ def testCompareWithTFLite(self):
tflm_output = tflm_interpreter.get_output(0)

# Check that TFLM output has correct metadata
self.assertDTypeEqual(tflm_output, np.int8)
self.assertEqual(tflm_output.dtype, np.int8)
self.assertEqual(tflm_output.shape, self.output_shape)
self.assertAllEqual(tflite_output, tflm_output)
np.testing.assert_allclose(tflite_output, tflm_output, atol=1)

def _helperModelFromFileAndBufferEqual(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
Expand All @@ -241,12 +233,12 @@ def _helperModelFromFileAndBufferEqual(self):
bytes_interpreter.invoke()
bytes_output = bytes_interpreter.get_output(0)

self.assertDTypeEqual(file_output, np.int8)
self.assertEqual(file_output.dtype, np.int8)
self.assertEqual(file_output.shape, self.output_shape)
self.assertDTypeEqual(bytes_output, np.int8)
self.assertEqual(bytes_output.dtype, np.int8)
self.assertEqual(bytes_output.shape, self.output_shape)
# Same interpreter and model, should expect all equal
self.assertAllEqual(file_output, bytes_output)
np.testing.assert_array_equal(file_output, bytes_output)

def testModelFromFileAndBufferEqual(self):
self._helperModelFromFileAndBufferEqual()
Expand All @@ -270,9 +262,9 @@ def testMultipleInterpreters(self):
if prev_output is None:
prev_output = output

self.assertDTypeEqual(output, np.int8)
self.assertEqual(output.dtype, np.int8)
self.assertEqual(output.shape, self.output_shape)
self.assertAllEqual(output, prev_output)
np.testing.assert_array_equal(output, prev_output)

def _helperNoop(self):
pass
Expand Down Expand Up @@ -305,25 +297,23 @@ def testOutputTensorMemoryLeak(self):
def testMalformedCustomOps(self):
model_data = generate_test_models.generate_conv_model(False)
custom_op_registerers = [("wrong", "format")]
with self.assertRaisesWithPredicateMatch(ValueError,
"must be a list of strings"):
with self.assertRaisesRegex(ValueError, "must be a list of strings"):
interpreter = runtime.Interpreter.from_bytes(model_data,
custom_op_registerers)

custom_op_registerers = "WrongFormat"
with self.assertRaisesWithPredicateMatch(ValueError,
"must be a list of strings"):
with self.assertRaisesRegex(ValueError, "must be a list of strings"):
interpreter = runtime.Interpreter.from_bytes(model_data,
custom_op_registerers)

def testNonExistentCustomOps(self):
model_data = generate_test_models.generate_conv_model(False)
custom_op_registerers = ["SomeRandomOp"]
with self.assertRaisesWithPredicateMatch(
with self.assertRaisesRegex(
RuntimeError, "TFLM could not register custom op via SomeRandomOp"):
interpreter = runtime.Interpreter.from_bytes(model_data,
custom_op_registerers)


if __name__ == "__main__":
test.main()
unittest.main()
1 change: 0 additions & 1 deletion python/tflite_micro/signal/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
# Empty file required by setuptools.find_packages to recognize this as a package
1 change: 0 additions & 1 deletion python/tflite_micro/signal/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
# Empty file required by setuptools.find_packages to recognize this as a package
16 changes: 7 additions & 9 deletions python/tflite_micro/signal/ops/delay_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import numpy as np
import tensorflow as tf
import unittest

from tflite_micro.python.tflite_micro.signal.ops import delay_op
from tflite_micro.python.tflite_micro.signal.utils import util


class DelayOpTest(tf.test.TestCase):
class DelayOpTest(unittest.TestCase):

def TestHelper(self, input_signal, delay_length, frame_size):
inner_dim_size = input_signal.shape[-1]
Expand All @@ -48,17 +49,14 @@ def TestHelper(self, input_signal, delay_length, frame_size):
interpreter = util.get_tflm_interpreter(concrete_function, func)

for i in range(frame_num):
in_frame = input_signal_padded[..., i * frame_size:(i + 1) * frame_size]
in_frame = np.copy(input_signal_padded[..., i * frame_size:(i + 1) *
frame_size])
# TFLM
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame_tflm = interpreter.get_output(0)
# TF
out_frame = self.evaluate(
delay_op.delay(in_frame, delay_length=delay_length))
delay_out[..., i * frame_size:(i + 1) * frame_size] = out_frame
self.assertAllEqual(out_frame, out_frame_tflm)
self.assertAllEqual(delay_out, delay_exp)
delay_out[..., i * frame_size:(i + 1) * frame_size] = out_frame_tflm
np.testing.assert_array_equal(delay_out, delay_exp)

def testFrameLargerThanDelay(self):
self.TestHelper(np.arange(0, 30, dtype=np.int16), 7, 10)
Expand All @@ -82,4 +80,4 @@ def testMultiDimensionalDelay(self):


if __name__ == '__main__':
tf.test.main()
unittest.main()
14 changes: 4 additions & 10 deletions python/tflite_micro/signal/ops/energy_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@

import numpy as np
import tensorflow as tf
import unittest

from tensorflow.python.platform import resource_loader
from tflite_micro.python.tflite_micro.signal.ops import energy_op
from tflite_micro.python.tflite_micro.signal.utils import util


class EnergyOpTest(tf.test.TestCase):
class EnergyOpTest(unittest.TestCase):

_PREFIX_PATH = resource_loader.get_path_to_datafile('')
_PREFIX_PATH = os.path.dirname(__file__)

def GetResource(self, filepath):
full_path = os.path.join(self._PREFIX_PATH, filepath)
Expand Down Expand Up @@ -56,13 +57,6 @@ def SingleEnergyTest(self, filename):
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
for j in range(start_index, end_index):
self.assertEqual(out_frame_exp[j], out_frame[j])
# TF
out_frame = self.evaluate(
energy_op.energy(in_frame,
start_index=start_index,
end_index=end_index))
for j in range(start_index, end_index):
self.assertEqual(out_frame_exp[j], out_frame[j])
i += 2
Expand Down Expand Up @@ -134,4 +128,4 @@ def testEnergy(self):


if __name__ == '__main__':
tf.test.main()
unittest.main()
Loading
Loading