Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <limits>

#include "Include/arm_nnfunctions.h"
#include "Include/arm_nnsupportfunctions.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
Expand Down Expand Up @@ -270,7 +271,7 @@ TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref,
}

TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
const OpData& op_data, const LSTMKernelContents& kernel_content,
const OpData& op_data, LSTMKernelContents& kernel_content,
const LSTMBuffers<int16_t>& buffers) {
TFLITE_DCHECK(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
Expand All @@ -282,21 +283,74 @@ TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
int8_t* output =
tflite::micro::GetTensorData<int8_t>(kernel_content.output_tensor);
int8_t* hidden_state =
tflite::micro::GetTensorData<int8_t>(kernel_content.HiddenStateTensor());
int16_t* cell_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor());

// Create lstm buffer struct
cmsis_nn_lstm_context cmsis_buffers;
cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);

arm_lstm_unidirectional_s8(input, output, &op_data.params_cmsis_nn,
&cmsis_buffers);
cmsis_buffers.cell_state = cell_state;

const auto& params = op_data.params_cmsis_nn;

#ifdef CMSIS_NN_STATEFUL_LSTM
cmsis_buffers.hidden_state = hidden_state;
arm_cmsis_nn_status status =
arm_lstm_unidirectional_s8(input, output, &params, &cmsis_buffers);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
#else
if (params.time_major) {
int8_t* step_hidden_in = hidden_state;
for (int t = 0; t < params.time_steps; t++) {
const int8_t* data_in =
input + (t * params.batch_size * params.input_size);
int8_t* hidden_out =
output + (t * params.batch_size * params.hidden_size);

arm_cmsis_nn_status status = arm_nn_lstm_step_s8(
data_in, step_hidden_in, hidden_out, &params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
step_hidden_in = hidden_out;
}
if (params.time_steps > 0) {
std::copy_n(step_hidden_in, params.batch_size * params.hidden_size,
hidden_state);
}
} else {
cmsis_nn_lstm_params step_params = params;
step_params.batch_size = 1;
for (int b = 0; b < params.batch_size; b++) {
int8_t* step_hidden_in = hidden_state + b * params.hidden_size;
cmsis_buffers.cell_state = cell_state + b * params.hidden_size;

for (int t = 0; t < params.time_steps; t++) {
const int8_t* data_in =
input + (b * params.time_steps + t) * params.input_size;
int8_t* hidden_out =
output + (b * params.time_steps + t) * params.hidden_size;

arm_cmsis_nn_status status =
arm_nn_lstm_step_s8(data_in, step_hidden_in, hidden_out,
&step_params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
step_hidden_in = hidden_out;
}
if (params.time_steps > 0) {
std::copy_n(step_hidden_in, params.hidden_size,
hidden_state + b * params.hidden_size);
}
}
}
#endif

return kTfLiteOk;
}

TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
const OpData& op_data, const LSTMKernelContents& kernel_content,
const OpData& op_data, LSTMKernelContents& kernel_content,
const LSTMBuffers<int16_t>& buffers) {
TFLITE_DCHECK(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
Expand All @@ -308,15 +362,63 @@ TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
int16_t* output =
tflite::micro::GetTensorData<int16_t>(kernel_content.output_tensor);
int16_t* hidden_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.HiddenStateTensor());
int16_t* cell_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor());

// Create lstm buffer struct
cmsis_nn_lstm_context cmsis_buffers;
cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);

arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn,
&cmsis_buffers);
cmsis_buffers.cell_state = cell_state;

const auto& params = op_data.params_cmsis_nn;

#ifdef CMSIS_NN_STATEFUL_LSTM
cmsis_buffers.hidden_state = hidden_state;
arm_cmsis_nn_status status =
arm_lstm_unidirectional_s16(input, output, &params, &cmsis_buffers);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
#else
if (params.time_major) {
for (int t = 0; t < params.time_steps; t++) {
const int16_t* data_in =
input + (t * params.batch_size * params.input_size);
int16_t* hidden_out =
output + (t * params.batch_size * params.hidden_size);

arm_cmsis_nn_status status = arm_nn_lstm_step_s16(
data_in, hidden_state, hidden_out, &params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;

// Update hidden state for next step
std::copy_n(hidden_out, params.batch_size * params.hidden_size,
hidden_state);
}
} else {
cmsis_nn_lstm_params step_params = params;
step_params.batch_size = 1;
for (int b = 0; b < params.batch_size; b++) {
for (int t = 0; t < params.time_steps; t++) {
const int16_t* data_in =
input + (b * params.time_steps + t) * params.input_size;
int16_t* hidden_out =
output + (b * params.time_steps + t) * params.hidden_size;
int16_t* current_hidden = hidden_state + b * params.hidden_size;
cmsis_buffers.cell_state = cell_state + b * params.hidden_size;

arm_cmsis_nn_status status =
arm_nn_lstm_step_s16(data_in, current_hidden, hidden_out,
&step_params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;

// Update hidden state for next step
std::copy_n(hidden_out, params.hidden_size, current_hidden);
}
}
}
#endif

return kTfLiteOk;
}
Expand Down
Loading