-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Tests][TFLite] Add SQUEEZE, REVERSE_SEQUENCE, UNPACK, and ZEROS_LIKE coverage #19814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1456,6 +1456,29 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3 | |||||||||||||||||||||
| verify(ReverseV2, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_reverse_sequence(): | ||||||||||||||||||||||
| mod = _load_model_from_buffer(_build_tflite_reverse_sequence_model()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @I.ir_module | ||||||||||||||||||||||
| class Expected: | ||||||||||||||||||||||
| @R.function | ||||||||||||||||||||||
| def main( | ||||||||||||||||||||||
| tvmgen_tensor_0: R.Tensor((2, 4, 3), dtype="float32"), | ||||||||||||||||||||||
| tvmgen_tensor_1: R.Tensor((2,), dtype="int32"), | ||||||||||||||||||||||
| ) -> R.Tensor((2, 4, 3), dtype="float32"): | ||||||||||||||||||||||
| R.func_attr({"num_input": 2}) | ||||||||||||||||||||||
| with R.dataflow(): | ||||||||||||||||||||||
| gv: R.Tensor((2, 4, 3), dtype="float32") = R.call_dps_packed( | ||||||||||||||||||||||
| "topi.reverse_sequence", | ||||||||||||||||||||||
| (tvmgen_tensor_0, tvmgen_tensor_1, 1), | ||||||||||||||||||||||
| out_sinfo=R.Tensor((2, 4, 3), dtype="float32"), | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+1471
to
+1475
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the expected IR module to match the explicit passing of
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. The current |
||||||||||||||||||||||
| R.output(gv) | ||||||||||||||||||||||
| return gv | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tvm.ir.assert_structural_equal(mod, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_gather(): | ||||||||||||||||||||||
| class Gather(tf.Module): | ||||||||||||||||||||||
| @tf.function( | ||||||||||||||||||||||
|
|
@@ -1513,6 +1536,73 @@ def main( | |||||||||||||||||||||
| verify(GatherND, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_squeeze(): | ||||||||||||||||||||||
| mod = _load_model_from_buffer(_build_tflite_squeeze_model()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @I.ir_module | ||||||||||||||||||||||
| class Expected: | ||||||||||||||||||||||
| @R.function | ||||||||||||||||||||||
| def main(tvmgen_tensor_0: R.Tensor((1, 2, 1, 3), dtype="float32")) -> R.Tensor( | ||||||||||||||||||||||
| (2, 3), dtype="float32" | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| R.func_attr({"num_input": 1}) | ||||||||||||||||||||||
| with R.dataflow(): | ||||||||||||||||||||||
| gv: R.Tensor((2, 3), dtype="float32") = R.squeeze(tvmgen_tensor_0, axis=[0, 2]) | ||||||||||||||||||||||
| R.output(gv) | ||||||||||||||||||||||
| return gv | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tvm.ir.assert_structural_equal(mod, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_unpack(): | ||||||||||||||||||||||
| mod = _load_model_from_buffer(_build_tflite_unpack_model()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @I.ir_module | ||||||||||||||||||||||
| class Expected: | ||||||||||||||||||||||
| @R.function | ||||||||||||||||||||||
| def main(tvmgen_tensor_0: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple( | ||||||||||||||||||||||
| R.Tensor((2, 4), dtype="float32"), | ||||||||||||||||||||||
| R.Tensor((2, 4), dtype="float32"), | ||||||||||||||||||||||
| R.Tensor((2, 4), dtype="float32"), | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| R.func_attr({"num_input": 1}) | ||||||||||||||||||||||
| with R.dataflow(): | ||||||||||||||||||||||
| lv: R.Tuple( | ||||||||||||||||||||||
| R.Tensor((2, 1, 4), dtype="float32"), | ||||||||||||||||||||||
| R.Tensor((2, 1, 4), dtype="float32"), | ||||||||||||||||||||||
| R.Tensor((2, 1, 4), dtype="float32"), | ||||||||||||||||||||||
| ) = R.split(tvmgen_tensor_0, indices_or_sections=3, axis=1) | ||||||||||||||||||||||
| lv1: R.Tensor((2, 1, 4), dtype="float32") = lv[0] | ||||||||||||||||||||||
| lv2: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[1]) | ||||||||||||||||||||||
| lv3: R.Tensor((2, 1, 4), dtype="float32") = lv[1] | ||||||||||||||||||||||
| lv4: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv3, axis=[1]) | ||||||||||||||||||||||
| lv5: R.Tensor((2, 1, 4), dtype="float32") = lv[2] | ||||||||||||||||||||||
| lv6: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv5, axis=[1]) | ||||||||||||||||||||||
| gv = (lv2, lv4, lv6) | ||||||||||||||||||||||
| R.output(gv) | ||||||||||||||||||||||
| return gv | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tvm.ir.assert_structural_equal(mod, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_zeros_like(): | ||||||||||||||||||||||
| mod = _load_model_from_buffer(_build_tflite_zeros_like_model()) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @I.ir_module | ||||||||||||||||||||||
| class Expected: | ||||||||||||||||||||||
| @R.function | ||||||||||||||||||||||
| def main(tvmgen_tensor_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor( | ||||||||||||||||||||||
| (2, 3), dtype="float32" | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| R.func_attr({"num_input": 1}) | ||||||||||||||||||||||
| with R.dataflow(): | ||||||||||||||||||||||
| gv: R.Tensor((2, 3), dtype="float32") = R.zeros_like(tvmgen_tensor_0) | ||||||||||||||||||||||
| R.output(gv) | ||||||||||||||||||||||
| return gv | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tvm.ir.assert_structural_equal(mod, Expected) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): | ||||||||||||||||||||||
| class Conv2DModule(tf.Module): | ||||||||||||||||||||||
| @tf.function( | ||||||||||||||||||||||
|
|
@@ -3711,7 +3801,11 @@ def _get_tflite_schema_enum(enum_name): | |||||||||||||||||||||
| _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") | ||||||||||||||||||||||
| _tfl_subgraph = _get_tflite_schema_module("SubGraph") | ||||||||||||||||||||||
| _tfl_tensor = _get_tflite_schema_module("Tensor") | ||||||||||||||||||||||
| _tfl_reverse_sequence_options = _get_tflite_schema_module("ReverseSequenceOptions") | ||||||||||||||||||||||
| _tfl_squeeze_options = _get_tflite_schema_module("SqueezeOptions") | ||||||||||||||||||||||
| _tfl_unpack_options = _get_tflite_schema_module("UnpackOptions") | ||||||||||||||||||||||
| _tfl_while_options = _get_tflite_schema_module("WhileOptions") | ||||||||||||||||||||||
| _tfl_zeros_like_options = _get_tflite_schema_module("ZerosLikeOptions") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") | ||||||||||||||||||||||
| _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") | ||||||||||||||||||||||
|
|
@@ -3967,6 +4061,31 @@ def _build_call_once_options(builder, init_subgraph_index): | |||||||||||||||||||||
| return _tfl_call_once_options.CallOnceOptionsEnd(builder) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_squeeze_options(builder, squeeze_dims): | ||||||||||||||||||||||
| squeeze_dims_vec = _tflite_int32_vector( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| _tfl_squeeze_options.SqueezeOptionsStartSqueezeDimsVector, | ||||||||||||||||||||||
| squeeze_dims, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| _tfl_squeeze_options.SqueezeOptionsStart(builder) | ||||||||||||||||||||||
| _tfl_squeeze_options.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims_vec) | ||||||||||||||||||||||
| return _tfl_squeeze_options.SqueezeOptionsEnd(builder) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_reverse_sequence_options(builder, seq_dim, batch_dim): | ||||||||||||||||||||||
| _tfl_reverse_sequence_options.ReverseSequenceOptionsStart(builder) | ||||||||||||||||||||||
| _tfl_reverse_sequence_options.ReverseSequenceOptionsAddSeqDim(builder, seq_dim) | ||||||||||||||||||||||
| _tfl_reverse_sequence_options.ReverseSequenceOptionsAddBatchDim(builder, batch_dim) | ||||||||||||||||||||||
| return _tfl_reverse_sequence_options.ReverseSequenceOptionsEnd(builder) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_unpack_options(builder, num, axis): | ||||||||||||||||||||||
| _tfl_unpack_options.UnpackOptionsStart(builder) | ||||||||||||||||||||||
| _tfl_unpack_options.UnpackOptionsAddNum(builder, num) | ||||||||||||||||||||||
| _tfl_unpack_options.UnpackOptionsAddAxis(builder, axis) | ||||||||||||||||||||||
| return _tfl_unpack_options.UnpackOptionsEnd(builder) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _get_builtin_options_type(options_name): | ||||||||||||||||||||||
| if not hasattr(_tfl_builtin_options, options_name): | ||||||||||||||||||||||
| pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}") | ||||||||||||||||||||||
|
|
@@ -4066,6 +4185,146 @@ def _get_builtin_operator(builtin_name): | |||||||||||||||||||||
| return getattr(_tfl_builtin_operator, builtin_name) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_tflite_squeeze_model(): | ||||||||||||||||||||||
| builder = flatbuffers.Builder(1024) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| squeeze_opts = _build_squeeze_options(builder, [0, 2]) | ||||||||||||||||||||||
| squeeze_op_code = _build_operator_code(builder, _tfl_builtin_operator.SQUEEZE) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tensors = [ | ||||||||||||||||||||||
| _build_tensor(builder, 0, [1, 2, 1, 3]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 3]), | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| squeeze_op = _build_operator( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| 0, | ||||||||||||||||||||||
| [0], | ||||||||||||||||||||||
| [1], | ||||||||||||||||||||||
| builtin_options_type=_tfl_builtin_options.SqueezeOptions, | ||||||||||||||||||||||
| builtin_options=squeeze_opts, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| subgraph = _build_subgraph( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| tensors=tensors, | ||||||||||||||||||||||
| operators=[squeeze_op], | ||||||||||||||||||||||
| inputs=[0], | ||||||||||||||||||||||
| outputs=[1], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| buffers = [_build_buffer(builder)] | ||||||||||||||||||||||
| return _finish_tflite_model( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| subgraph=subgraph, | ||||||||||||||||||||||
| operator_codes=[squeeze_op_code], | ||||||||||||||||||||||
| buffers=buffers, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_tflite_reverse_sequence_model(): | ||||||||||||||||||||||
| builder = flatbuffers.Builder(1024) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| reverse_sequence_opts = _build_reverse_sequence_options(builder, seq_dim=1, batch_dim=0) | ||||||||||||||||||||||
| reverse_sequence_op_code = _build_operator_code(builder, _tfl_builtin_operator.REVERSE_SEQUENCE) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tensors = [ | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 4, 3]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT32), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 4, 3]), | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| reverse_sequence_op = _build_operator( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| 0, | ||||||||||||||||||||||
| [0, 1], | ||||||||||||||||||||||
| [2], | ||||||||||||||||||||||
| builtin_options_type=_tfl_builtin_options.ReverseSequenceOptions, | ||||||||||||||||||||||
| builtin_options=reverse_sequence_opts, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| subgraph = _build_subgraph( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| tensors=tensors, | ||||||||||||||||||||||
| operators=[reverse_sequence_op], | ||||||||||||||||||||||
| inputs=[0, 1], | ||||||||||||||||||||||
| outputs=[2], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| buffers = [_build_buffer(builder)] | ||||||||||||||||||||||
| return _finish_tflite_model( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| subgraph=subgraph, | ||||||||||||||||||||||
| operator_codes=[reverse_sequence_op_code], | ||||||||||||||||||||||
| buffers=buffers, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_tflite_unpack_model(): | ||||||||||||||||||||||
| builder = flatbuffers.Builder(1024) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| unpack_opts = _build_unpack_options(builder, num=3, axis=1) | ||||||||||||||||||||||
| unpack_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNPACK) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tensors = [ | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 3, 4]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 4]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 4]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 4]), | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| unpack_op = _build_operator( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| 0, | ||||||||||||||||||||||
| [0], | ||||||||||||||||||||||
| [1, 2, 3], | ||||||||||||||||||||||
| builtin_options_type=_tfl_builtin_options.UnpackOptions, | ||||||||||||||||||||||
| builtin_options=unpack_opts, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| subgraph = _build_subgraph( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| tensors=tensors, | ||||||||||||||||||||||
| operators=[unpack_op], | ||||||||||||||||||||||
| inputs=[0], | ||||||||||||||||||||||
| outputs=[1, 2, 3], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| buffers = [_build_buffer(builder)] | ||||||||||||||||||||||
| return _finish_tflite_model( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| subgraph=subgraph, | ||||||||||||||||||||||
| operator_codes=[unpack_op_code], | ||||||||||||||||||||||
| buffers=buffers, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _build_tflite_zeros_like_model(): | ||||||||||||||||||||||
| builder = flatbuffers.Builder(1024) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _tfl_zeros_like_options.ZerosLikeOptionsStart(builder) | ||||||||||||||||||||||
| zeros_like_opts = _tfl_zeros_like_options.ZerosLikeOptionsEnd(builder) | ||||||||||||||||||||||
| zeros_like_op_code = _build_operator_code(builder, _tfl_builtin_operator.ZEROS_LIKE) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| tensors = [ | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 3]), | ||||||||||||||||||||||
| _build_tensor(builder, 0, [2, 3]), | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| zeros_like_op = _build_operator( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| 0, | ||||||||||||||||||||||
| [0], | ||||||||||||||||||||||
| [1], | ||||||||||||||||||||||
| builtin_options_type=_tfl_builtin_options.ZerosLikeOptions, | ||||||||||||||||||||||
| builtin_options=zeros_like_opts, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| subgraph = _build_subgraph( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| tensors=tensors, | ||||||||||||||||||||||
| operators=[zeros_like_op], | ||||||||||||||||||||||
| inputs=[0], | ||||||||||||||||||||||
| outputs=[1], | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| buffers = [_build_buffer(builder)] | ||||||||||||||||||||||
| return _finish_tflite_model( | ||||||||||||||||||||||
| builder, | ||||||||||||||||||||||
| subgraph=subgraph, | ||||||||||||||||||||||
| operator_codes=[zeros_like_op_code], | ||||||||||||||||||||||
| buffers=buffers, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _run_module(mod, *inputs): | ||||||||||||||||||||||
| tgt = tvm.target.Target("c") | ||||||||||||||||||||||
| ex = tvm.compile(mod, tgt) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
topi.reverse_sequenceoperator expects 4 arguments:(data, seq_lengths, seq_axis, batch_axis). Omittingbatch_axiscan lead to runtime errors (such as arity mismatch) when calling the packed function, especially in non-Python runtimes or environments where default arguments are not automatically resolved. Sincebatch_axisis validated to be0, we should explicitly pass it as the fourth argument.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the careful check. I looked at the current packed registration in
src/topi/transform.cc, andtopi.reverse_sequenceis registered as a3-argument packed function:
So passing
(data, seq_lengths, seq_axis, batch_axis)from Relax would not matchthe current packed API. The frontend already rejects
batch_dim != 0before thecall, and the 3-argument packed wrapper therefore intentionally uses TOPI's
default
batch_axis=0.I think the current 3-argument call and the matching Expected IR are the safer
form for this PR. If we want to support non-zero
batch_dimlater, the rightfollow-up would be to first extend the
topi.reverse_sequencepackedregistration to consume
args[3], and then update the TFLite frontend and testexpectation to pass
batch_axisexplicitly.