diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc index 9c27cc45ffd..76a7fc53223 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "Include/arm_nnfunctions.h" +#include "Include/arm_nnsupportfunctions.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/fully_connected.h" @@ -270,7 +271,7 @@ TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref, } TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm( - const OpData& op_data, const LSTMKernelContents& kernel_content, + const OpData& op_data, LSTMKernelContents& kernel_content, const LSTMBuffers& buffers) { TFLITE_DCHECK( kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >= @@ -282,21 +283,74 @@ TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm( kernel_content.GetInternalTensor(tflite::kLstmInputTensor)); int8_t* output = tflite::micro::GetTensorData(kernel_content.output_tensor); + int8_t* hidden_state = + tflite::micro::GetTensorData(kernel_content.HiddenStateTensor()); + int16_t* cell_state = + tflite::micro::GetTensorData(kernel_content.CellStateTensor()); // Create lstm buffer struct cmsis_nn_lstm_context cmsis_buffers; cmsis_buffers.temp1 = reinterpret_cast(buffers.buffer0); cmsis_buffers.temp2 = reinterpret_cast(buffers.buffer1); - cmsis_buffers.cell_state = reinterpret_cast(buffers.buffer2); - - arm_lstm_unidirectional_s8(input, output, &op_data.params_cmsis_nn, - &cmsis_buffers); + cmsis_buffers.cell_state = cell_state; + + const auto& params = op_data.params_cmsis_nn; + +#ifdef CMSIS_NN_STATEFUL_LSTM + cmsis_buffers.hidden_state = hidden_state; + arm_cmsis_nn_status status = + 
arm_lstm_unidirectional_s8(input, output, &params, &cmsis_buffers); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; +#else + if (params.time_major) { + int8_t* step_hidden_in = hidden_state; + for (int t = 0; t < params.time_steps; t++) { + const int8_t* data_in = + input + (t * params.batch_size * params.input_size); + int8_t* hidden_out = + output + (t * params.batch_size * params.hidden_size); + + arm_cmsis_nn_status status = arm_nn_lstm_step_s8( + data_in, step_hidden_in, hidden_out, &params, &cmsis_buffers, 1); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; + step_hidden_in = hidden_out; + } + if (params.time_steps > 0) { + std::copy_n(step_hidden_in, params.batch_size * params.hidden_size, + hidden_state); + } + } else { + cmsis_nn_lstm_params step_params = params; + step_params.batch_size = 1; + for (int b = 0; b < params.batch_size; b++) { + int8_t* step_hidden_in = hidden_state + b * params.hidden_size; + cmsis_buffers.cell_state = cell_state + b * params.hidden_size; + + for (int t = 0; t < params.time_steps; t++) { + const int8_t* data_in = + input + (b * params.time_steps + t) * params.input_size; + int8_t* hidden_out = + output + (b * params.time_steps + t) * params.hidden_size; + + arm_cmsis_nn_status status = + arm_nn_lstm_step_s8(data_in, step_hidden_in, hidden_out, + &step_params, &cmsis_buffers, 1); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; + step_hidden_in = hidden_out; + } + if (params.time_steps > 0) { + std::copy_n(step_hidden_in, params.hidden_size, + hidden_state + b * params.hidden_size); + } + } + } +#endif return kTfLiteOk; } TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm( - const OpData& op_data, const LSTMKernelContents& kernel_content, + const OpData& op_data, LSTMKernelContents& kernel_content, const LSTMBuffers& buffers) { TFLITE_DCHECK( kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >= @@ -308,15 +362,63 @@ TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm( 
kernel_content.GetInternalTensor(tflite::kLstmInputTensor)); int16_t* output = tflite::micro::GetTensorData(kernel_content.output_tensor); + int16_t* hidden_state = + tflite::micro::GetTensorData(kernel_content.HiddenStateTensor()); + int16_t* cell_state = + tflite::micro::GetTensorData(kernel_content.CellStateTensor()); // Create lstm buffer struct cmsis_nn_lstm_context cmsis_buffers; cmsis_buffers.temp1 = reinterpret_cast(buffers.buffer0); cmsis_buffers.temp2 = reinterpret_cast(buffers.buffer1); - cmsis_buffers.cell_state = reinterpret_cast(buffers.buffer2); - - arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn, - &cmsis_buffers); + cmsis_buffers.cell_state = cell_state; + + const auto& params = op_data.params_cmsis_nn; + +#ifdef CMSIS_NN_STATEFUL_LSTM + cmsis_buffers.hidden_state = hidden_state; + arm_cmsis_nn_status status = + arm_lstm_unidirectional_s16(input, output, &params, &cmsis_buffers); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; +#else + if (params.time_major) { + for (int t = 0; t < params.time_steps; t++) { + const int16_t* data_in = + input + (t * params.batch_size * params.input_size); + int16_t* hidden_out = + output + (t * params.batch_size * params.hidden_size); + + arm_cmsis_nn_status status = arm_nn_lstm_step_s16( + data_in, hidden_state, hidden_out, &params, &cmsis_buffers, 1); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; + + // Update hidden state for next step + std::copy_n(hidden_out, params.batch_size * params.hidden_size, + hidden_state); + } + } else { + cmsis_nn_lstm_params step_params = params; + step_params.batch_size = 1; + for (int b = 0; b < params.batch_size; b++) { + for (int t = 0; t < params.time_steps; t++) { + const int16_t* data_in = + input + (b * params.time_steps + t) * params.input_size; + int16_t* hidden_out = + output + (b * params.time_steps + t) * params.hidden_size; + int16_t* current_hidden = hidden_state + b * params.hidden_size; + cmsis_buffers.cell_state = cell_state + b * 
params.hidden_size; + + arm_cmsis_nn_status status = + arm_nn_lstm_step_s16(data_in, current_hidden, hidden_out, + &step_params, &cmsis_buffers, 1); + if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; + + // Update hidden state for next step + std::copy_n(hidden_out, params.hidden_size, current_hidden); + } + } + } +#endif return kTfLiteOk; }