From e737c57ffa3c72764d0398f596dc6e4326ab0caa Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Thu, 21 May 2026 11:29:56 -0700 Subject: [PATCH 1/3] Fixed LSTM --- Include/arm_nn_types.h | 1 + .../arm_lstm_unidirectional_s16.c | 58 +++++++++++++++---- .../arm_lstm_unidirectional_s8.c | 58 +++++++++++++++---- 3 files changed, 93 insertions(+), 24 deletions(-) diff --git a/Include/arm_nn_types.h b/Include/arm_nn_types.h index 2846e621..64bfa3e3 100644 --- a/Include/arm_nn_types.h +++ b/Include/arm_nn_types.h @@ -272,6 +272,7 @@ typedef struct void *temp1; void *temp2; void *cell_state; + void *hidden_state; } cmsis_nn_lstm_context; /** diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c index 4d4ed021..f7910ffe 100644 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c @@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input, cmsis_nn_lstm_context *buffers) { - int16_t *hidden_in = NULL; - memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); + int16_t *hidden_in = (int16_t *)buffers->hidden_state; + + if (buffers->hidden_state == NULL) + { + memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); + } + if (params->time_major) { // First dimension is time, input/output for each time step is stored continously in memory @@ -69,22 +74,51 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input, // Output is used as recurrent input/hidden state for the next timestep. hidden_in = &hidden_out[0]; } + + if (buffers->hidden_state != NULL && params->time_steps > 0) + { + memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int16_t)); + } } else { - // First dimension is time, add batch_offset to jump in memory for each batch - for (int t = 0; t < params->time_steps; t++) + // Batch major: [batch, time, size] + // arm_nn_lstm_step_s16 expects data_in and hidden_in to have the same batch_offset. + // Since the initial hidden_state is contiguous, we must process one batch at a time. + cmsis_nn_lstm_params step_params = *params; + step_params.batch_size = 1; + + for (int b = 0; b < params->batch_size; b++) { - const int16_t *data_in = input + (t * params->input_size); - int16_t *hidden_out = output + (t * params->hidden_size); - arm_cmsis_nn_status status = - arm_nn_lstm_step_s16(data_in, hidden_in, hidden_out, params, buffers, params->time_steps); - if (status != ARM_CMSIS_NN_SUCCESS) + int16_t *step_hidden_in = (buffers->hidden_state != NULL) + ? ((int16_t *)buffers->hidden_state + b * params->hidden_size) + : NULL; + + cmsis_nn_lstm_context step_buffers = *buffers; + step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size; + + for (int t = 0; t < params->time_steps; t++) { - return status; + const int16_t *data_in = input + (b * params->time_steps + t) * params->input_size; + int16_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size; + + arm_cmsis_nn_status status = arm_nn_lstm_step_s16( + data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1); + + if (status != ARM_CMSIS_NN_SUCCESS) + { + return status; + } + + step_hidden_in = hidden_out; + } + + if (buffers->hidden_state != NULL && params->time_steps > 0) + { + memcpy((int16_t *)buffers->hidden_state + b * params->hidden_size, + step_hidden_in, + params->hidden_size * sizeof(int16_t)); } - // Output is used as recurrent input/hidden state for the next timestep. - hidden_in = &hidden_out[0]; } } return ARM_CMSIS_NN_SUCCESS; diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c index 86f404e9..d87c0862 100644 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c @@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input, cmsis_nn_lstm_context *buffers) { - int8_t *hidden_in = NULL; - memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); + int8_t *hidden_in = (int8_t *)buffers->hidden_state; + + if (buffers->hidden_state == NULL) + { + memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); + } + if (params->time_major) { // First dimension is time, input/output for each time step is stored continously in memory @@ -69,22 +74,51 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input, // Output is used as recurrent input/hidden state for the next timestep. hidden_in = &hidden_out[0]; } + + if (buffers->hidden_state != NULL && params->time_steps > 0) + { + memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int8_t)); + } } else { - // First dimension is time, add batch_offset to jump in memory for each batch - for (int t = 0; t < params->time_steps; t++) + // Batch major: [batch, time, size] + // arm_nn_lstm_step_s8 expects data_in and hidden_in to have the same batch_offset. + // Since the initial hidden_state is contiguous, we must process one batch at a time. + cmsis_nn_lstm_params step_params = *params; + step_params.batch_size = 1; + + for (int b = 0; b < params->batch_size; b++) { - const int8_t *data_in = input + (t * params->input_size); - int8_t *hidden_out = output + (t * params->hidden_size); - arm_cmsis_nn_status status = - arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps); - if (status != ARM_CMSIS_NN_SUCCESS) + int8_t *step_hidden_in = (buffers->hidden_state != NULL) + ? ((int8_t *)buffers->hidden_state + b * params->hidden_size) + : NULL; + + cmsis_nn_lstm_context step_buffers = *buffers; + step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size; + + for (int t = 0; t < params->time_steps; t++) { - return status; + const int8_t *data_in = input + (b * params->time_steps + t) * params->input_size; + int8_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size; + + arm_cmsis_nn_status status = arm_nn_lstm_step_s8( + data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1); + + if (status != ARM_CMSIS_NN_SUCCESS) + { + return status; + } + + step_hidden_in = hidden_out; + } + + if (buffers->hidden_state != NULL && params->time_steps > 0) + { + memcpy((int8_t *)buffers->hidden_state + b * params->hidden_size, + step_hidden_in, + params->hidden_size * sizeof(int8_t)); } - // Output is used as recurrent input/hidden state for the next timestep. - hidden_in = &hidden_out[0]; } } return ARM_CMSIS_NN_SUCCESS; From d1da44fdedc6d727731562d807a8c0f0b44762ae Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Tue, 26 May 2026 15:51:51 -0700 Subject: [PATCH 2/3] Added date/revision --- Include/arm_nn_types.h | 4 ++-- Source/LSTMFunctions/arm_lstm_unidirectional_s16.c | 4 ++-- Source/LSTMFunctions/arm_lstm_unidirectional_s8.c | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Include/arm_nn_types.h b/Include/arm_nn_types.h index 64bfa3e3..f6192a44 100644 --- a/Include/arm_nn_types.h +++ b/Include/arm_nn_types.h @@ -22,8 +22,8 @@ * Description: Public header file to contain the CMSIS-NN structs for the * TensorFlowLite micro compliant functions * - * $Date: 21 Oct 2024 - * $Revision: V.3.5.0 + * $Date: 21 May 2026 + * $Revision: V.3.5.1 * * Target : Arm(R) M-Profile Architecture * -------------------------------------------------------------------- */ diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c index f7910ffe..230e7bfb 100644 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s16.c @@ -21,8 +21,8 @@ * Title: arm_lstm_unidirectional_s16.c * Description: S16 LSTM function with S16 gate output * - * $Date: 26 March 2024 - * $Revision: V.1.0.0 + * $Date: 21 May 2026 + * $Revision: V.1.0.1 * * Target Processor: Cortex-M processors * diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c index d87c0862..97f48490 100644 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c @@ -21,8 +21,8 @@ * Title: arm_lstm_unidirectional_s8.c * Description: S8 LSTM function with S16 gate output * - * $Date: 08 February 2024 - * $Revision: V.1.1.0 + * $Date: 21 May 2026 + * $Revision: V.1.1.1 * * Target Processor: Cortex-M processors * From c11b2eeb73f0678fe1cb6b6d7a6b5a84037ec30a Mon Sep 17 00:00:00 2001 From: Esun Kim Date: Wed, 27 May 2026 15:32:09 -0700 Subject: [PATCH 3/3] Unit tests --- .../UnitTest/RefactoredTestGen/test_plan.json | 15 ++ .../cell_gate_bias.h | 6 + .../cell_gate_hidden_weights.h | 9 + .../cell_gate_input_weights.h | 8 + .../config_data.h | 35 ++++ .../forget_gate_bias.h | 6 + .../forget_gate_hidden_weights.h | 9 + .../forget_gate_input_weights.h | 8 + .../input_gate_bias.h | 6 + .../input_gate_hidden_weights.h | 9 + .../input_gate_input_weights.h | 8 + .../input_tensor.h | 7 + .../output.h | 8 + .../output_gate_bias.h | 6 + .../output_gate_hidden_weights.h | 9 + .../output_gate_input_weights.h | 8 + .../test_data.h | 15 ++ .../cell_gate_bias.h | 7 + .../cell_gate_hidden_weights.h | 9 + .../cell_gate_input_weights.h | 8 + .../config_data.h | 35 ++++ .../forget_gate_bias.h | 7 + .../forget_gate_hidden_weights.h | 9 + .../forget_gate_input_weights.h | 8 + .../input_gate_bias.h | 7 + .../input_gate_hidden_weights.h | 9 + .../input_gate_input_weights.h | 8 + .../input_tensor.h | 8 + .../output.h | 8 + .../output_gate_bias.h | 7 + .../output_gate_hidden_weights.h | 9 + .../output_gate_input_weights.h | 8 + .../test_data.h | 17 ++ .../unity_test_arm_lstm_unidirectional_s16.c | 1 + .../test_arm_lstm_unidirectional_s16.c | 158 +++++++++++++++++ .../unity_test_arm_lstm_unidirectional_s8.c | 1 + .../test_arm_lstm_unidirectional_s8.c | 166 ++++++++++++++++++ Tests/UnitTest/build_and_run_tests.sh | 7 +- 38 files changed, 667 insertions(+), 2 deletions(-) create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/config_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_tensor.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/test_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/config_data.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_tensor.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_bias.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_hidden_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_input_weights.h create mode 100644 Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/test_data.h diff --git a/Tests/UnitTest/RefactoredTestGen/test_plan.json b/Tests/UnitTest/RefactoredTestGen/test_plan.json index 6b0cdb0b..0dae76ec 100644 --- a/Tests/UnitTest/RefactoredTestGen/test_plan.json +++ b/Tests/UnitTest/RefactoredTestGen/test_plan.json @@ -542,6 +542,14 @@ "input_size" : 22, "hidden_size" : 3, "json_template": "lstm_s16.json" + }, + {"name" : "lstm_stateful_batch_major_multibatch_s16", + "time_major" : false, + "batch_size" : 2, + "time_steps" : 2, + "input_size" : 6, + "hidden_size" : 7, + "json_template": "lstm_s16.json" } ] }, @@ -574,6 +582,13 @@ "time_steps" : 1, "input_size" : 22, "hidden_size" : 3 + }, + {"name" : "lstm_stateful_batch_major_multibatch", + "time_major" : false, + "batch_size" : 2, + "time_steps" : 2, + "input_size" : 6, + "hidden_size" : 7 } ] }, diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_bias.h new file mode 100644 index 00000000..bba3b929 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_bias.h @@ -0,0 +1,6 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int32_t lstm_stateful_batch_major_multibatch_cell_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_hidden_weights.h new file mode 100644 index 00000000..8cb34220 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_cell_gate_hidden_weights[49] = { + -20, -88, 87, 87, 109, -54, -12, 21, 2, -112, 44, -79, -97, -15, 123, 105, -122, + -29, -83, 36, 58, 33, 59, -115, 127, -106, 101, -57, -97, -64, -39, 71, 4, -114, + -94, 74, 34, -12, -118, -64, 104, 102, -36, -114, 117, 95, -1, 67, 81}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_input_weights.h new file mode 100644 index 00000000..15bf1045 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/cell_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_cell_gate_input_weights[42] = { + 92, 100, -29, 127, 76, -74, 115, -81, 63, 4, 69, -81, 8, -25, 42, 99, 44, 101, 12, -25, 99, + -70, -88, -41, 107, 65, 67, -31, 87, -54, -104, 95, 35, 21, 125, -87, 27, 78, -113, -114, 61, -101}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/config_data.h new file mode 100644 index 00000000..3b53ecf3 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/config_data.h @@ -0,0 +1,35 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once + +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_MAJOR false +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE 2 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS 2 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE 6 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE 7 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_SCALE_POWER -15 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT 128 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT 4 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_CLIP 32767 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_MULTIPLIER 1073741824 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_SHIFT -14 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_MULTIPLIER 1073741824 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_SHIFT -14 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_MULTIPLIER 1993694592 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_SHIFT -21 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_MULTIPLIER 1143723136 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_MULTIPLIER 1164696448 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_MULTIPLIER 1134438656 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_MULTIPLIER 2044599040 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_SHIFT -4 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_MULTIPLIER 2130456576 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_MULTIPLIER 2096726656 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_MULTIPLIER 2025295488 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_SHIFT -3 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_MULTIPLIER 2146124928 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_SHIFT -3 diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_bias.h new file mode 100644 index 00000000..f28cbb4f --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_bias.h @@ -0,0 +1,6 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int32_t lstm_stateful_batch_major_multibatch_forget_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_hidden_weights.h new file mode 100644 index 00000000..a1d396f4 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_forget_gate_hidden_weights[49] = { + 8, -96, 106, 66, -101, 22, -70, 86, -37, -1, -127, 52, 9, 79, 111, -94, 126, + -21, -25, -79, 42, -57, -42, -3, 126, 51, -49, -28, -10, 50, -104, -48, -11, -78, + 36, 121, 6, -56, -24, -75, -104, -103, -119, -63, -69, -51, 11, -43, 19}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_input_weights.h new file mode 100644 index 00000000..f5593522 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/forget_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_forget_gate_input_weights[42] = { + 58, -87, 92, 94, -41, -61, 31, -86, 72, 3, 116, 68, -81, -55, -44, -122, -105, -87, -23, 94, -113, + -68, 23, 45, -96, 13, 103, -116, 87, 88, -2, 6, 58, -126, 33, 28, -113, 68, -30, -84, 127, -33}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_bias.h new file mode 100644 index 00000000..a83c2251 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_bias.h @@ -0,0 +1,6 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int32_t lstm_stateful_batch_major_multibatch_input_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_hidden_weights.h new file mode 100644 index 00000000..9be21756 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_input_gate_hidden_weights[49] = { + -120, -121, -18, 3, 74, 27, -107, 84, 43, 127, 91, 116, 77, -17, -8, 32, 89, + 111, 85, -8, 9, -115, -36, 71, 70, -10, 70, -106, -9, -71, -83, -79, -73, -88, + 39, -124, -1, -2, 108, 117, -27, 61, -90, -32, -105, 97, 41, -98, 57}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_input_weights.h new file mode 100644 index 00000000..ba0d1204 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_input_gate_input_weights[42] = { + -21, -53, 56, 127, -88, 66, -111, 94, -19, 90, 124, -67, 15, 3, 85, -34, 37, 65, 93, -86, -72, + -30, -121, -122, 5, -54, -83, -111, -6, 49, 99, -67, 0, 45, 85, -16, 54, 120, 36, 79, -119, -4}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_tensor.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_tensor.h new file mode 100644 index 00000000..c5d89f8d --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/input_tensor.h @@ -0,0 +1,7 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_input_tensor[24] = { + -62, -21, 64, 39, 55, -93, -122, -119, 14, -86, -125, 79, -67, -83, 65, -111, 59, -111, 66, -115, -86, 3, 118, 75}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output.h new file mode 100644 index 00000000..a7788696 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_output[28] = {72, 68, 89, -16, 29, 94, -68, 7, 15, 102, + 7, 2, 30, -105, 39, 90, 65, 5, 52, 115, + -27, 54, 105, 98, -48, 32, 78, -23}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_bias.h new file mode 100644 index 00000000..77688096 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_bias.h @@ -0,0 +1,6 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int32_t lstm_stateful_batch_major_multibatch_output_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_hidden_weights.h new file mode 100644 index 00000000..6ef4f153 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_output_gate_hidden_weights[49] = { + 29, -109, 63, -104, -1, -60, 94, -95, -65, 46, 52, -10, 50, -59, -29, 30, -106, + 108, 7, 88, -76, 5, 85, -121, 127, 35, 63, -97, -12, 84, -112, 109, -89, 19, + -41, 74, 83, 56, 53, 93, -127, -25, -104, 37, -75, 7, -10, 72, -111}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_input_weights.h new file mode 100644 index 00000000..e1b362dc --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/output_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1). +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_output_gate_input_weights[42] = { + -54, 67, -56, -37, -72, -6, -36, 1, 30, -117, -51, 94, -11, -95, -16, 52, -37, -99, -1, 66, -31, + 73, 52, 11, -93, 34, 9, -84, -77, -1, 94, -34, -61, -122, 115, -107, 2, -35, 127, -76, -75, -79}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/test_data.h new file mode 100644 index 00000000..ff07f2a6 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch/test_data.h @@ -0,0 +1,15 @@ +#include "cell_gate_bias.h" +#include "cell_gate_hidden_weights.h" +#include "cell_gate_input_weights.h" +#include "config_data.h" +#include "forget_gate_bias.h" +#include "forget_gate_hidden_weights.h" +#include "forget_gate_input_weights.h" +#include "input_gate_bias.h" +#include "input_gate_hidden_weights.h" +#include "input_gate_input_weights.h" +#include "input_tensor.h" +#include "output.h" +#include "output_gate_bias.h" +#include "output_gate_hidden_weights.h" +#include "output_gate_input_weights.h" diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_bias.h new file mode 100644 index 00000000..24656b7d --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_bias.h @@ -0,0 +1,7 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int64_t lstm_stateful_batch_major_multibatch_s16_cell_gate_bias[7] = + {13325, 17119, 32248, 26169, 8164, 22567, 4211}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_hidden_weights.h new file mode 100644 index 00000000..992d18ba --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_cell_gate_hidden_weights[49] = { + -6, -40, 107, -122, -109, -62, 103, 112, -125, 57, -80, -126, -93, -77, 30, 2, -32, + 30, 35, 37, 10, 45, -3, -96, 20, 31, 110, 121, -74, -10, -82, 76, -63, 117, + -111, 24, -23, -81, 49, 102, -45, 97, 48, -9, -75, -90, 5, 88, -50}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_input_weights.h new file mode 100644 index 00000000..41438865 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/cell_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_cell_gate_input_weights[42] = { + 49, 34, 68, -90, -112, -1, 22, -86, -44, -96, 84, 73, 118, 23, -40, 88, -83, 85, -61, 70, -32, + -4, -4, 109, -25, -77, -9, -122, -84, 122, 44, -15, -43, -60, -37, 64, -85, 73, 43, 74, 23, 25}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/config_data.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/config_data.h new file mode 100644 index 00000000..07f67581 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/config_data.h @@ -0,0 +1,35 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once + +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_MAJOR false +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_BATCH_SIZE 2 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_STEPS 2 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE 6 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE 7 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_SCALE_POWER -9 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT 0 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT 0 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_CLIP 32767 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_TO_CELL_MULTIPLIER 1073741824 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_TO_CELL_SHIFT -14 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_TO_CELL_MULTIPLIER 1308266999 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_TO_CELL_SHIFT -20 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_MULTIPLIER 1236528302 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_SHIFT -19 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_HIDDEN_MULTIPLIER 1604002043 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_HIDDEN_SHIFT -9 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_HIDDEN_MULTIPLIER 1741487933 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_HIDDEN_SHIFT -11 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_HIDDEN_MULTIPLIER 1405411314 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_HIDDEN_SHIFT -8 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_HIDDEN_MULTIPLIER 1789226089 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_HIDDEN_SHIFT -8 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_INPUT_MULTIPLIER 1209657853 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_INPUT_SHIFT -9 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_INPUT_MULTIPLIER 1465381247 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_INPUT_SHIFT -10 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_INPUT_MULTIPLIER 1683751785 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_INPUT_SHIFT -11 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_INPUT_MULTIPLIER 1384928943 +#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_INPUT_SHIFT -9 diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_bias.h new file mode 100644 index 00000000..6c313f64 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_bias.h @@ -0,0 +1,7 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int64_t lstm_stateful_batch_major_multibatch_s16_forget_gate_bias[7] = + {20637, 23729, 19724, 13194, 29244, 9521, 32113}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_hidden_weights.h new file mode 100644 index 00000000..35183e70 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_forget_gate_hidden_weights[49] = { + -96, 89, 37, -113, -123, 81, -28, 111, -38, 88, 78, -56, 113, -96, 44, -85, -13, + -41, 76, -126, 8, -10, 48, -84, 96, 40, -78, -40, -43, 49, -94, -122, -15, 76, + 47, 41, -10, -17, -4, 33, 9, -57, -24, -98, -23, -100, 112, -51, -74}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_input_weights.h new file mode 100644 index 00000000..eadc768b --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/forget_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_forget_gate_input_weights[42] = { + -57, -94, 41, 110, 49, -128, 110, -120, 61, -17, 26, -68, 121, -74, 76, -29, 59, -32, -67, 13, 14, + 67, 116, -31, 94, -101, -114, 60, 120, -18, -105, 6, 123, 97, -95, -100, -108, -11, 116, 73, -61, -92}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_bias.h new file mode 100644 index 00000000..2e158e34 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_bias.h @@ -0,0 +1,7 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int64_t lstm_stateful_batch_major_multibatch_s16_input_gate_bias[7] = + {22079, 6235, 31102, 7436, 17647, 2634, 7285}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_hidden_weights.h new file mode 100644 index 00000000..080b603a --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_input_gate_hidden_weights[49] = { + 81, -28, -21, 121, 117, 24, 47, -1, 110, 76, -71, -77, 66, 42, 45, 74, 90, + 56, -92, -40, -25, -74, -55, -3, -65, 39, 22, -10, 63, 123, 100, 56, -94, -123, + -18, -91, 55, -98, 8, -44, 22, -71, -11, -60, 81, 45, 55, 119, 3}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_input_weights.h new file mode 100644 index 00000000..a957a769 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_input_gate_input_weights[42] = { + 38, -47, -111, -26, 124, -69, -37, 116, 53, 5, -106, 38, 77, 14, 35, 38, -53, -44, 76, -26, 17, + 49, -85, -120, 80, 58, -56, -51, -126, 26, 17, -57, -78, -74, 18, 57, 88, -121, -25, -40, -89, -3}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_tensor.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_tensor.h new file mode 100644 index 00000000..d12fd4a8 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/input_tensor.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int16_t lstm_stateful_batch_major_multibatch_s16_input_tensor[24] = { + 28057, 25392, -13007, -4876, -9976, -24994, -29540, 16696, 12431, -17737, -20149, 5250, + 3699, 21061, -6754, 31374, 4217, 7014, 16679, -4841, 12023, 24303, 26425, -26243}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output.h new file mode 100644 index 00000000..721db8df --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int16_t lstm_stateful_batch_major_multibatch_s16_output[28] = { + 191, -140, 117, -230, -470, 40, -90, 203, -123, -37, 29, 98, 11, 5, + -89, -284, 239, 160, -320, -51, 99, -242, -313, 62, -275, -292, -170, 53}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_bias.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_bias.h new file mode 100644 index 00000000..9be4404d --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_bias.h @@ -0,0 +1,7 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int64_t lstm_stateful_batch_major_multibatch_s16_output_gate_bias[7] = + {823, 25824, 10847, 19946, 24942, 28970, 26737}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_hidden_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_hidden_weights.h new file mode 100644 index 00000000..1df57bff --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_hidden_weights.h @@ -0,0 +1,9 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_output_gate_hidden_weights[49] = { + -44, -33, -65, 37, -118, 53, 61, -2, -106, -110, 106, 90, -56, -19, 44, 6, 71, + -43, -12, -47, 75, -25, -62, 18, -104, -29, -33, -83, 30, -91, 6, -2, -120, 38, + -96, 67, 87, -49, -58, -48, 29, -100, -71, -85, -14, 62, -24, 118, -38}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_input_weights.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_input_weights.h new file mode 100644 index 00000000..cb5ee062 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/output_gate_input_weights.h @@ -0,0 +1,8 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#pragma once +#include + +const int8_t lstm_stateful_batch_major_multibatch_s16_output_gate_input_weights[42] = { + -74, 30, -26, -97, 44, -26, -1, -54, 15, 91, 99, -32, -80, -28, -127, -22, 105, 55, 110, 113, 126, + 4, 8, 107, 113, 82, -5, 21, 15, 46, 64, 13, -94, -23, 4, -102, 9, -88, -30, 69, -79, 71}; diff --git a/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/test_data.h b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/test_data.h new file mode 100644 index 00000000..4c489a85 --- /dev/null +++ b/Tests/UnitTest/TestCases/TestData/lstm_stateful_batch_major_multibatch_s16/test_data.h @@ -0,0 +1,17 @@ +// Generated by generate_test_data.py using flatc version 25.9.23 +// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587. +#include "cell_gate_bias.h" +#include "cell_gate_hidden_weights.h" +#include "cell_gate_input_weights.h" +#include "config_data.h" +#include "forget_gate_bias.h" +#include "forget_gate_hidden_weights.h" +#include "forget_gate_input_weights.h" +#include "input_gate_bias.h" +#include "input_gate_hidden_weights.h" +#include "input_gate_input_weights.h" +#include "input_tensor.h" +#include "output.h" +#include "output_gate_bias.h" +#include "output_gate_hidden_weights.h" +#include "output_gate_input_weights.h" diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/Unity/unity_test_arm_lstm_unidirectional_s16.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/Unity/unity_test_arm_lstm_unidirectional_s16.c index b9ef170e..462677eb 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/Unity/unity_test_arm_lstm_unidirectional_s16.c +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/Unity/unity_test_arm_lstm_unidirectional_s16.c @@ -46,3 +46,4 @@ void tearDown(void) {} void test_lstm_1_s16(void) { lstm_1_s16(); } void test_lstm_2_s16(void) { lstm_2_s16(); } void test_lstm_one_time_step_s16(void) { lstm_one_time_step_s16(); } +void test_lstm_stateful_batch_major_multibatch_s16(void) { lstm_stateful_batch_major_multibatch_s16(); } diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/test_arm_lstm_unidirectional_s16.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/test_arm_lstm_unidirectional_s16.c index 462ed55e..2bda5148 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/test_arm_lstm_unidirectional_s16.c +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s16/test_arm_lstm_unidirectional_s16.c @@ -19,6 +19,7 @@ #include "../TestData/lstm_1_s16/test_data.h" #include "../TestData/lstm_2_s16/test_data.h" #include "../TestData/lstm_one_time_step_s16/test_data.h" +#include "../TestData/lstm_stateful_batch_major_multibatch_s16/test_data.h" #include "../Utils/validate.h" #include #include @@ -173,6 +174,7 @@ void lstm_1_s16(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s16(lstm_1_s16_input_tensor, output, ¶ms, &buffers); @@ -319,6 +321,7 @@ void lstm_2_s16(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s16(lstm_2_s16_input_tensor, output, ¶ms, &buffers); @@ -467,10 +470,165 @@ void lstm_one_time_step_s16(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s16(lstm_one_time_step_s16_input_tensor, output, ¶ms, &buffers); + TEST_ASSERT_EQUAL(expected, result); + TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size)); +} + +void lstm_stateful_batch_major_multibatch_s16(void) +{ + int16_t output[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_STEPS * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE] = {0}; + const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; + const int16_t *output_ref = &lstm_stateful_batch_major_multibatch_s16_output[0]; + const int32_t output_ref_size = LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_STEPS * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE; + + int64_t input_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t forget_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t cell_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t output_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + + int64_t input_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t forget_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t cell_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + int64_t output_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE]; + + arm_vector_sum_s8_s64(&input_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_input_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT, + &lstm_stateful_batch_major_multibatch_s16_input_gate_bias[0]); + arm_vector_sum_s8_s64(&forget_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_forget_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT, + &lstm_stateful_batch_major_multibatch_s16_forget_gate_bias[0]); + arm_vector_sum_s8_s64(&cell_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_cell_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT, + &lstm_stateful_batch_major_multibatch_s16_cell_gate_bias[0]); + arm_vector_sum_s8_s64(&output_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_output_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT, + &lstm_stateful_batch_major_multibatch_s16_output_gate_bias[0]); + + arm_vector_sum_s8_s64(&input_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_input_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT, + NULL); + arm_vector_sum_s8_s64(&forget_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_forget_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT, + NULL); + arm_vector_sum_s8_s64(&cell_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_cell_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT, + NULL); + arm_vector_sum_s8_s64(&output_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_s16_output_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT, + NULL); + + // INPUT GATE + const cmsis_nn_lstm_gate gate_input = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_input_gate_input_weights[0], + &input_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_input_gate_hidden_weights[0], + &input_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_s16_input_gate_bias[0], + ARM_SIGMOID}; + + // FORGET GATE + const cmsis_nn_lstm_gate gate_forget = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_forget_gate_input_weights[0], + &forget_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_forget_gate_hidden_weights[0], + &forget_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_s16_forget_gate_bias[0], + ARM_SIGMOID}; + + // CELL GATE + const cmsis_nn_lstm_gate gate_cell = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_cell_gate_input_weights[0], + &cell_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_cell_gate_hidden_weights[0], + &cell_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_s16_cell_gate_bias[0], + ARM_TANH}; + + // OUTPUT GATE + const cmsis_nn_lstm_gate gate_output = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_output_gate_input_weights[0], + &output_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_s16_output_gate_hidden_weights[0], + &output_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_s16_output_gate_bias[0], + ARM_SIGMOID}; + + // LSTM DATA + const cmsis_nn_lstm_params params = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_MAJOR, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_BATCH_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_TIME_STEPS, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_ZERO_POINT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_TO_CELL_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_FORGET_TO_CELL_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_TO_CELL_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_INPUT_TO_CELL_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_CLIP, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_CELL_SCALE_POWER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_OUTPUT_ZERO_POINT, + gate_forget, + gate_input, + gate_cell, + gate_output}; + + // Allocate zero-initialized hidden state buffer to test the non-null path! + int16_t hidden_state[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_S16_HIDDEN_SIZE] = {0}; + + memset(buffer3, 0, sizeof(buffer3)); + + cmsis_nn_lstm_context buffers; + buffers.temp1 = buffer1; + buffers.temp2 = buffer2; + buffers.cell_state = buffer3; + buffers.hidden_state = hidden_state; + + arm_cmsis_nn_status result = + arm_lstm_unidirectional_s16(lstm_stateful_batch_major_multibatch_s16_input_tensor, output, ¶ms, &buffers); + TEST_ASSERT_EQUAL(expected, result); TEST_ASSERT_TRUE(validate_s16(output, output_ref, output_ref_size)); } \ No newline at end of file diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c index 157c3b96..18bca23f 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/Unity/unity_test_arm_lstm_unidirectional_s8.c @@ -47,3 +47,4 @@ void tearDown(void) {} void test_lstm_1_arm_lstm_unidirectional_s8(void) { lstm_1(); } void test_lstm_2_arm_lstm_unidirectional_s8(void) { lstm_2(); } void test_lstm_one_time_step_arm_lstm_unidirectional_s8(void) { lstm_one_time_step(); } +void test_lstm_stateful_batch_major_multibatch_arm_lstm_unidirectional_s8(void) { lstm_stateful_batch_major_multibatch(); } diff --git a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c index 821279af..1dbd728a 100644 --- a/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c +++ b/Tests/UnitTest/TestCases/test_arm_lstm_unidirectional_s8/test_arm_lstm_unidirectional_s8.c @@ -25,6 +25,7 @@ #include "../TestData/lstm_1/test_data.h" #include "../TestData/lstm_2/test_data.h" #include "../TestData/lstm_one_time_step/test_data.h" +#include "../TestData/lstm_stateful_batch_major_multibatch/test_data.h" #include "../Utils/validate.h" // update the buffer size if adding a unit test with larger buffer. @@ -182,6 +183,7 @@ void lstm_1(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_1_input_tensor, output, ¶ms, &buffers); @@ -336,6 +338,7 @@ void lstm_2(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_2_input_tensor, output, ¶ms, &buffers); @@ -491,9 +494,172 @@ void lstm_one_time_step(void) buffers.temp1 = buffer1; buffers.temp2 = buffer2; buffers.cell_state = buffer3; + buffers.hidden_state = NULL; arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_one_time_step_input_tensor, output, ¶ms, &buffers); TEST_ASSERT_EQUAL(expected, result); TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size)); } + +void lstm_stateful_batch_major_multibatch(void) +{ + int8_t output[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE] = {0}; + const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS; + const int8_t *output_ref = &lstm_stateful_batch_major_multibatch_output[0]; + const int32_t output_ref_size = LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE; + + int32_t input_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t forget_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t cell_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t output_data_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + + int32_t input_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t forget_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t cell_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + int32_t output_hidden_kernel_sum[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + + arm_vector_sum_s8(&input_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_input_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT, + 0, + &lstm_stateful_batch_major_multibatch_input_gate_bias[0]); + arm_vector_sum_s8(&forget_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_forget_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT, + 0, + &lstm_stateful_batch_major_multibatch_forget_gate_bias[0]); + arm_vector_sum_s8(&cell_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_cell_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT, + 0, + &lstm_stateful_batch_major_multibatch_cell_gate_bias[0]); + arm_vector_sum_s8(&output_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_output_gate_input_weights[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT, + 0, + &lstm_stateful_batch_major_multibatch_output_gate_bias[0]); + + arm_vector_sum_s8(&input_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_input_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, + 0, + NULL); + arm_vector_sum_s8(&forget_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_forget_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, + 0, + NULL); + arm_vector_sum_s8(&cell_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_cell_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, + 0, + NULL); + arm_vector_sum_s8(&output_hidden_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + &lstm_stateful_batch_major_multibatch_output_gate_hidden_weights[0], + -LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, + 0, + NULL); + + // INPUT GATE + const cmsis_nn_lstm_gate gate_input = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_input_gate_input_weights[0], + &input_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_input_gate_hidden_weights[0], + &input_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_input_gate_bias[0], + ARM_SIGMOID}; + + // FORGET GATE + const cmsis_nn_lstm_gate gate_forget = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_forget_gate_input_weights[0], + &forget_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_forget_gate_hidden_weights[0], + &forget_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_forget_gate_bias[0], + ARM_SIGMOID}; + + // CELL GATE + const cmsis_nn_lstm_gate gate_cell = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_cell_gate_input_weights[0], + &cell_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_cell_gate_hidden_weights[0], + &cell_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_cell_gate_bias[0], + ARM_TANH}; + + // OUTPUT GATE + const cmsis_nn_lstm_gate gate_output = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_SHIFT, + &lstm_stateful_batch_major_multibatch_output_gate_input_weights[0], + &output_data_kernel_sum[0], + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_SHIFT, + &lstm_stateful_batch_major_multibatch_output_gate_hidden_weights[0], + &output_hidden_kernel_sum[0], + &lstm_stateful_batch_major_multibatch_output_gate_bias[0], + ARM_SIGMOID}; + + // LSTM DATA + const cmsis_nn_lstm_params params = {LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_MAJOR, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_CLIP, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_SCALE_POWER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_MULTIPLIER, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_SHIFT, + LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, + gate_forget, + gate_input, + gate_cell, + gate_output}; + + // Allocate hidden state buffer and initialize to zero point (representing real 0.0) + int8_t hidden_state[LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE * LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE]; + memset(hidden_state, LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT, sizeof(hidden_state)); + + memset(buffer3, 0, sizeof(buffer3)); + + cmsis_nn_lstm_context buffers; + buffers.temp1 = buffer1; + buffers.temp2 = buffer2; + buffers.cell_state = buffer3; + buffers.hidden_state = hidden_state; + + arm_cmsis_nn_status result = arm_lstm_unidirectional_s8(lstm_stateful_batch_major_multibatch_input_tensor, output, ¶ms, &buffers); + + TEST_ASSERT_EQUAL(expected, result); + TEST_ASSERT_TRUE(validate(output, output_ref, output_ref_size)); +} diff --git a/Tests/UnitTest/build_and_run_tests.sh b/Tests/UnitTest/build_and_run_tests.sh index 623970a9..8a1c0e97 100755 --- a/Tests/UnitTest/build_and_run_tests.sh +++ b/Tests/UnitTest/build_and_run_tests.sh @@ -76,6 +76,8 @@ do esac done +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd -P) + Setup_Environment() { set -e echo "++ Downloading Corstone300" @@ -125,7 +127,7 @@ Setup_Environment() { if [[ -d ${WORKING_DIR}/ethos-u-core-platform ]]; then echo "Ethos-U core platform already installed. If you wish to install a new version, please delete the old folder." else - git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform + git clone https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-core-platform fi echo "++ Setting up python environment" @@ -134,7 +136,7 @@ Setup_Environment() { else python3 -m venv cmsis_nn_venv source cmsis_nn_venv/bin/activate - pip3 install -r ../requirements.txt + pip3 install -r ${SCRIPT_DIR}/requirements.txt deactivate fi } @@ -192,6 +194,7 @@ fi mkdir -p downloads pushd downloads +cd $(pwd -P) WORKING_DIR=$(pwd) if [[ ${SETUP_ENVIRONMENT} -eq 1 ]]; then