Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Include/arm_nn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
* Description: Public header file to contain the CMSIS-NN structs for the
* TensorFlowLite micro compliant functions
*
* $Date: 21 Oct 2024
* $Revision: V.3.5.0
* $Date: 21 May 2026
* $Revision: V.3.5.1
*
* Target : Arm(R) M-Profile Architecture
* -------------------------------------------------------------------- */
Expand Down Expand Up @@ -272,6 +272,7 @@ typedef struct
void *temp1;
void *temp2;
void *cell_state;
void *hidden_state;
} cmsis_nn_lstm_context;

/**
Expand Down
62 changes: 48 additions & 14 deletions Source/LSTMFunctions/arm_lstm_unidirectional_s16.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_lstm_unidirectional_s16.c
* Description: S16 LSTM function with S16 gate output
*
* $Date: 26 March 2024
* $Revision: V.1.0.0
* $Date: 21 May 2026
* $Revision: V.1.0.1
*
* Target Processor: Cortex-M processors
*
Expand Down Expand Up @@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input,
cmsis_nn_lstm_context *buffers)
{

int16_t *hidden_in = NULL;
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
int16_t *hidden_in = (int16_t *)buffers->hidden_state;

if (buffers->hidden_state == NULL)
{
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
}

if (params->time_major)
{
// First dimension is time, input/output for each time step is stored continously in memory
Expand All @@ -69,22 +74,51 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s16(const int16_t *input,
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}

if (buffers->hidden_state != NULL && params->time_steps > 0)
{
memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int16_t));
}
}
else
{
// First dimension is time, add batch_offset to jump in memory for each batch
for (int t = 0; t < params->time_steps; t++)
// Batch major: [batch, time, size]
// arm_nn_lstm_step_s16 expects data_in and hidden_in to have the same batch_offset.
// Since the initial hidden_state is contiguous, we must process one batch at a time.
cmsis_nn_lstm_params step_params = *params;
step_params.batch_size = 1;

for (int b = 0; b < params->batch_size; b++)
{
const int16_t *data_in = input + (t * params->input_size);
int16_t *hidden_out = output + (t * params->hidden_size);
arm_cmsis_nn_status status =
arm_nn_lstm_step_s16(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
if (status != ARM_CMSIS_NN_SUCCESS)
int16_t *step_hidden_in = (buffers->hidden_state != NULL)
? ((int16_t *)buffers->hidden_state + b * params->hidden_size)
: NULL;

cmsis_nn_lstm_context step_buffers = *buffers;
step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size;

for (int t = 0; t < params->time_steps; t++)
{
return status;
const int16_t *data_in = input + (b * params->time_steps + t) * params->input_size;
int16_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size;

arm_cmsis_nn_status status = arm_nn_lstm_step_s16(
data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1);

if (status != ARM_CMSIS_NN_SUCCESS)
{
return status;
}

step_hidden_in = hidden_out;
}

if (buffers->hidden_state != NULL && params->time_steps > 0)
{
memcpy((int16_t *)buffers->hidden_state + b * params->hidden_size,
step_hidden_in,
params->hidden_size * sizeof(int16_t));
}
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}
}
return ARM_CMSIS_NN_SUCCESS;
Expand Down
62 changes: 48 additions & 14 deletions Source/LSTMFunctions/arm_lstm_unidirectional_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_lstm_unidirectional_s8.c
* Description: S8 LSTM function with S16 gate output
*
* $Date: 08 February 2024
* $Revision: V.1.1.0
* $Date: 21 May 2026
* $Revision: V.1.1.1
*
* Target Processor: Cortex-M processors
*
Expand Down Expand Up @@ -52,8 +52,13 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
cmsis_nn_lstm_context *buffers)
{

int8_t *hidden_in = NULL;
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
int8_t *hidden_in = (int8_t *)buffers->hidden_state;

if (buffers->hidden_state == NULL)
{
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
}

if (params->time_major)
{
// First dimension is time, input/output for each time step is stored continously in memory
Expand All @@ -69,22 +74,51 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}

if (buffers->hidden_state != NULL && params->time_steps > 0)
{
memcpy(buffers->hidden_state, hidden_in, params->batch_size * params->hidden_size * sizeof(int8_t));
}
}
else
{
// First dimension is time, add batch_offset to jump in memory for each batch
for (int t = 0; t < params->time_steps; t++)
// Batch major: [batch, time, size]
// arm_nn_lstm_step_s8 expects data_in and hidden_in to have the same batch_offset.
// Since the initial hidden_state is contiguous, we must process one batch at a time.
cmsis_nn_lstm_params step_params = *params;
step_params.batch_size = 1;

for (int b = 0; b < params->batch_size; b++)
{
const int8_t *data_in = input + (t * params->input_size);
int8_t *hidden_out = output + (t * params->hidden_size);
arm_cmsis_nn_status status =
arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
if (status != ARM_CMSIS_NN_SUCCESS)
int8_t *step_hidden_in = (buffers->hidden_state != NULL)
? ((int8_t *)buffers->hidden_state + b * params->hidden_size)
: NULL;

cmsis_nn_lstm_context step_buffers = *buffers;
step_buffers.cell_state = (int16_t *)buffers->cell_state + b * params->hidden_size;

for (int t = 0; t < params->time_steps; t++)
{
return status;
const int8_t *data_in = input + (b * params->time_steps + t) * params->input_size;
int8_t *hidden_out = output + (b * params->time_steps + t) * params->hidden_size;

arm_cmsis_nn_status status = arm_nn_lstm_step_s8(
data_in, step_hidden_in, hidden_out, &step_params, &step_buffers, 1);

if (status != ARM_CMSIS_NN_SUCCESS)
{
return status;
}

step_hidden_in = hidden_out;
}

if (buffers->hidden_state != NULL && params->time_steps > 0)
{
memcpy((int8_t *)buffers->hidden_state + b * params->hidden_size,
step_hidden_in,
params->hidden_size * sizeof(int8_t));
}
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}
}
return ARM_CMSIS_NN_SUCCESS;
Expand Down
15 changes: 15 additions & 0 deletions Tests/UnitTest/RefactoredTestGen/test_plan.json
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,14 @@
"input_size" : 22,
"hidden_size" : 3,
"json_template": "lstm_s16.json"
},
{"name" : "lstm_stateful_batch_major_multibatch_s16",
"time_major" : false,
"batch_size" : 2,
"time_steps" : 2,
"input_size" : 6,
"hidden_size" : 7,
"json_template": "lstm_s16.json"
}
]
},
Expand Down Expand Up @@ -574,6 +582,13 @@
"time_steps" : 1,
"input_size" : 22,
"hidden_size" : 3
},
{"name" : "lstm_stateful_batch_major_multibatch",
"time_major" : false,
"batch_size" : 2,
"time_steps" : 2,
"input_size" : 6,
"hidden_size" : 7
}
]
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int32_t lstm_stateful_batch_major_multibatch_cell_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_cell_gate_hidden_weights[49] = {
-20, -88, 87, 87, 109, -54, -12, 21, 2, -112, 44, -79, -97, -15, 123, 105, -122,
-29, -83, 36, 58, 33, 59, -115, 127, -106, 101, -57, -97, -64, -39, 71, 4, -114,
-94, 74, 34, -12, -118, -64, 104, 102, -36, -114, 117, 95, -1, 67, 81};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_cell_gate_input_weights[42] = {
92, 100, -29, 127, 76, -74, 115, -81, 63, 4, 69, -81, 8, -25, 42, 99, 44, 101, 12, -25, 99,
-70, -88, -41, 107, 65, 67, -31, 87, -54, -104, 95, 35, 21, 125, -87, 27, 78, -113, -114, 61, -101};
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once

#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_MAJOR false
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_BATCH_SIZE 2
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_TIME_STEPS 2
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_SIZE 6
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_HIDDEN_SIZE 7
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_SCALE_POWER -15
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_ZERO_POINT 128
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_ZERO_POINT 4
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_CLIP 32767
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_MULTIPLIER 1073741824
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_TO_CELL_SHIFT -14
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_MULTIPLIER 1073741824
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_TO_CELL_SHIFT -14
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_MULTIPLIER 1993694592
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_SHIFT -21
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_MULTIPLIER 1143723136
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_HIDDEN_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_MULTIPLIER 1164696448
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_HIDDEN_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_MULTIPLIER 1134438656
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_HIDDEN_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_MULTIPLIER 2044599040
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_HIDDEN_SHIFT -4
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_MULTIPLIER 2130456576
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_OUTPUT_GATE_INPUT_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_MULTIPLIER 2096726656
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_CELL_GATE_INPUT_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_MULTIPLIER 2025295488
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_FORGET_GATE_INPUT_SHIFT -3
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_MULTIPLIER 2146124928
#define LSTM_STATEFUL_BATCH_MAJOR_MULTIBATCH_INPUT_GATE_INPUT_SHIFT -3
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int32_t lstm_stateful_batch_major_multibatch_forget_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_forget_gate_hidden_weights[49] = {
8, -96, 106, 66, -101, 22, -70, 86, -37, -1, -127, 52, 9, 79, 111, -94, 126,
-21, -25, -79, 42, -57, -42, -3, 126, 51, -49, -28, -10, 50, -104, -48, -11, -78,
36, 121, 6, -56, -24, -75, -104, -103, -119, -63, -69, -51, 11, -43, 19};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_forget_gate_input_weights[42] = {
58, -87, 92, 94, -41, -61, 31, -86, 72, 3, 116, 68, -81, -55, -44, -122, -105, -87, -23, 94, -113,
-68, 23, 45, -96, 13, 103, -116, 87, 88, -2, 6, 58, -126, 33, 28, -113, 68, -30, -84, 127, -33};
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int32_t lstm_stateful_batch_major_multibatch_input_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_input_gate_hidden_weights[49] = {
-120, -121, -18, 3, 74, 27, -107, 84, 43, 127, 91, 116, 77, -17, -8, 32, 89,
111, 85, -8, 9, -115, -36, 71, 70, -10, 70, -106, -9, -71, -83, -79, -73, -88,
39, -124, -1, -2, 108, 117, -27, 61, -90, -32, -105, 97, 41, -98, 57};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_input_gate_input_weights[42] = {
-21, -53, 56, 127, -88, 66, -111, 94, -19, 90, 124, -67, 15, 3, 85, -34, 37, 65, 93, -86, -72,
-30, -121, -122, 5, -54, -83, -111, -6, 49, 99, -67, 0, 45, 85, -16, 54, 120, 36, 79, -119, -4};
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_input_tensor[24] = {
-62, -21, 64, 39, 55, -93, -122, -119, 14, -86, -125, 79, -67, -83, 65, -111, 59, -111, 66, -115, -86, 3, 118, 75};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_output[28] = {72, 68, 89, -16, 29, 94, -68, 7, 15, 102,
7, 2, 30, -105, 39, 90, 65, 5, 52, 115,
-27, 54, 105, 98, -48, 32, 78, -23};
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int32_t lstm_stateful_batch_major_multibatch_output_gate_bias[7] = {0, 0, 0, 0, 0, 0, 0};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_output_gate_hidden_weights[49] = {
29, -109, 63, -104, -1, -60, 94, -95, -65, 46, 52, -10, 50, -59, -29, 30, -106,
108, 7, 88, -76, 5, 85, -121, 127, 35, 63, -97, -12, 84, -112, 109, -89, 19,
-41, 74, 83, 56, 53, 93, -127, -25, -104, 37, -75, 7, -10, 72, -111};
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Generated by generate_test_data.py using tensorflow version 2.18.0 (Keras version 3.14.1).
// Interpreter from tflite_micro runtime version 0.dev20260203175027-g88f8587.
#pragma once
#include <stdint.h>

const int8_t lstm_stateful_batch_major_multibatch_output_gate_input_weights[42] = {
-54, 67, -56, -37, -72, -6, -36, 1, 30, -117, -51, 94, -11, -95, -16, 52, -37, -99, -1, 66, -31,
73, 52, 11, -93, 34, 9, -84, -77, -1, 94, -34, -61, -122, 115, -107, 2, -35, 127, -76, -75, -79};
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "cell_gate_bias.h"
#include "cell_gate_hidden_weights.h"
#include "cell_gate_input_weights.h"
#include "config_data.h"
#include "forget_gate_bias.h"
#include "forget_gate_hidden_weights.h"
#include "forget_gate_input_weights.h"
#include "input_gate_bias.h"
#include "input_gate_hidden_weights.h"
#include "input_gate_input_weights.h"
#include "input_tensor.h"
#include "output.h"
#include "output_gate_bias.h"
#include "output_gate_hidden_weights.h"
#include "output_gate_input_weights.h"
Loading