diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc index a6affaa11bb..105836cb92b 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc @@ -30,16 +30,10 @@ namespace tflite { namespace { -cmsis_nn_dims FillVariableShape(int32_t rank, int32_t* tensor_dims) { - if (rank == 4) { - return {tensor_dims[0], tensor_dims[1], tensor_dims[2], tensor_dims[3]}; - } else if (rank == 3) { - return {1, tensor_dims[0], tensor_dims[1], tensor_dims[2]}; - } else if (rank == 2) { - return {1, 1, tensor_dims[0], tensor_dims[1]}; - } else { - return {1, 1, 1, 1}; - } +cmsis_nn_dims FillVariableShape(const RuntimeShape& shape) { + RuntimeShape extended_shape = RuntimeShape::ExtendedShape(4, shape); + return {extended_shape.Dims(0), extended_shape.Dims(1), + extended_shape.Dims(2), extended_shape.Dims(3)}; } TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) { @@ -55,12 +49,9 @@ TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) { RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2); RuntimeShape output_shape = tflite::micro::GetTensorShape(output); - cmsis_nn_dims input_1_dims = FillVariableShape( - input_1_shape.DimensionsCount(), input_1_shape.DimsData()); - cmsis_nn_dims input_2_dims = FillVariableShape( - input_2_shape.DimensionsCount(), input_2_shape.DimsData()); - cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(), - output_shape.DimsData()); + cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape); + cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape); + cmsis_nn_dims output_dims = FillVariableShape(output_shape); switch (op_context.output->type) { case kTfLiteInt8: @@ -107,12 +98,9 @@ TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) { RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2); RuntimeShape output_shape = tflite::micro::GetTensorShape(output); - cmsis_nn_dims input_1_dims = FillVariableShape( - input_1_shape.DimensionsCount(), input_1_shape.DimsData()); - cmsis_nn_dims input_2_dims = FillVariableShape( - input_2_shape.DimensionsCount(), input_2_shape.DimsData()); - cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(), - output_shape.DimsData()); + cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape); + cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape); + cmsis_nn_dims output_dims = FillVariableShape(output_shape); switch (op_context.output->type) { case kTfLiteInt8: @@ -147,12 +135,9 @@ TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) { RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2); RuntimeShape output_shape = tflite::micro::GetTensorShape(output); - cmsis_nn_dims input_1_dims = FillVariableShape( - input_1_shape.DimensionsCount(), input_1_shape.DimsData()); - cmsis_nn_dims input_2_dims = FillVariableShape( - input_2_shape.DimensionsCount(), input_2_shape.DimsData()); - cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(), - output_shape.DimsData()); + cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape); + cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape); + cmsis_nn_dims output_dims = FillVariableShape(output_shape); switch (op_context.output->type) { case kTfLiteInt8: @@ -199,12 +184,9 @@ TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) { RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2); RuntimeShape output_shape = tflite::micro::GetTensorShape(output); - cmsis_nn_dims input_1_dims = FillVariableShape( - input_1_shape.DimensionsCount(), input_1_shape.DimsData()); - cmsis_nn_dims input_2_dims = FillVariableShape( - input_2_shape.DimensionsCount(), input_2_shape.DimsData()); - cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(), - output_shape.DimsData()); + cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape); + cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape); + cmsis_nn_dims output_dims = FillVariableShape(output_shape); switch (op_context.output->type) { case kTfLiteInt8: diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/pad.cc b/tensorflow/lite/micro/kernels/cmsis_nn/pad.cc index bacba0b325b..81ffa4e07c2 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/pad.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/pad.cc @@ -49,13 +49,21 @@ TfLiteStatus PadEvalInt8(TfLiteContext* context, TfLiteNode* node) { int8_t* output_ptr = tflite::micro::GetTensorData(output); const RuntimeShape d = tflite::micro::GetTensorShape(input); - const cmsis_nn_dims input_size = {d.Dims(0), d.Dims(1), d.Dims(2), d.Dims(3)}; + const int rank = d.DimensionsCount(); + + cmsis_nn_dims input_size = { + rank >= 4 ? d.Dims(rank - 4) : 1, rank >= 3 ? d.Dims(rank - 3) : 1, + rank >= 2 ? d.Dims(rank - 2) : 1, rank >= 1 ? d.Dims(rank - 1) : 1}; const PadParams p = data->params; - const cmsis_nn_dims pre_pad = {p.left_padding[0], p.left_padding[1], - p.left_padding[2], p.left_padding[3]}; - const cmsis_nn_dims post_pad = {p.right_padding[0], p.right_padding[1], - p.right_padding[2], p.right_padding[3]}; + cmsis_nn_dims pre_pad = {rank >= 4 ? p.left_padding[rank - 4] : 0, + rank >= 3 ? p.left_padding[rank - 3] : 0, + rank >= 2 ? p.left_padding[rank - 2] : 0, + rank >= 1 ? p.left_padding[rank - 1] : 0}; + cmsis_nn_dims post_pad = {rank >= 4 ? p.right_padding[rank - 4] : 0, + rank >= 3 ? p.right_padding[rank - 3] : 0, + rank >= 2 ? p.right_padding[rank - 2] : 0, + rank >= 1 ? p.right_padding[rank - 1] : 0}; arm_pad_s8(input_ptr, output_ptr, pad_value, &input_size, &pre_pad, &post_pad);