diff --git a/tensorflow/lite/micro/kernels/batch_matmul.cc b/tensorflow/lite/micro/kernels/batch_matmul.cc index bbb1c0b0a7e..59eca5de6a4 100644 --- a/tensorflow/lite/micro/kernels/batch_matmul.cc +++ b/tensorflow/lite/micro/kernels/batch_matmul.cc @@ -340,6 +340,46 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data, return kTfLiteOk; } +void ReshapeToFlattenBatchDimsIfPossible(bool adj_x, RuntimeShape* lhs_shape, + RuntimeShape* rhs_shape) { + // Compress BatchMatMul when third from last RHS dimension is one. + int32_t rhs_dims_count = rhs_shape->DimensionsCount(); + int32_t lhs_dims_count = lhs_shape->DimensionsCount(); + + // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is + // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and + // lhs: [..., Q * R, S]. + // + // We can only flatten the dimensions if the physical layout in memory + // allows us to treat [Batch, Row, Col] as [Batch * Row, Col]. + // This requires the 'Row' dimension to be contiguous with the 'Batch' + // dimension. + // + // If adj_x is true, the logical operation is on transposed matrices. + // The physical layout is [Batch, Row, Col] but logically we access [Batch, + // Col, Row]. Flattening Batch and Row dimensions (physically) results in + // [Batch*Row, Col]. This does not match the logical expectation of [Batch, + // Col, Row] because the columns of the second batch are not contiguous with + // the columns of the first batch in the logical transposed view. Therefore, + // we disable this optimization when adj_x is true. + if (!adj_x && rhs_dims_count > 2 && lhs_dims_count > 2) { + int rhs_one = rhs_shape->DimsData()[rhs_dims_count - 3]; + if (rhs_one == 1) { + int32_t* lhs_dims = lhs_shape->DimsData(); + int32_t* rhs_dims = rhs_shape->DimsData(); + RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims); + tmp_l.SetDim(lhs_dims_count - 3, + lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]); + tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]); + lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); + RuntimeShape tmp_r(rhs_dims_count - 1, rhs_shape->DimsData()); + tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]); + tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]); + rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData()); + } + } +} + // Perform a batch matrix multiply on // LHS <..., A, B> X RHS<..., B, C> // where the leading dimensions of LHS and RHS obey broadcasting rules @@ -363,30 +403,7 @@ TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) { bool adj_y = op_context.params->adj_y; bool adj_x = op_context.params->adj_x; - // Compress BatchMatMul when third from last RHS dimension is one. - int32_t rhs_dims_count = orig_rhs_shape.DimensionsCount(); - int32_t lhs_dims_count = orig_lhs_shape.DimensionsCount(); - // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is - // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and - // lhs: [..., Q * R, S]. - if (rhs_dims_count > 2 && lhs_dims_count > 2) { - int rhs_one = orig_rhs_shape.DimsData()[rhs_dims_count - 3]; - if (rhs_one == 1) { - int32_t* lhs_dims = orig_lhs_shape.DimsData(); - int32_t* rhs_dims = orig_rhs_shape.DimsData(); - RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims); - tmp_l.SetDim(lhs_dims_count - 3, - lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]); - tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]); - orig_lhs_shape.ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); - RuntimeShape tmp_r(rhs_dims_count - 1, orig_rhs_shape.DimsData()); - tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]); - tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]); - orig_rhs_shape.ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData()); - rhs_dims_count = orig_rhs_shape.DimensionsCount(); - lhs_dims_count = orig_lhs_shape.DimensionsCount(); - } - } + ReshapeToFlattenBatchDimsIfPossible(adj_x, &orig_lhs_shape, &orig_rhs_shape); TfLiteEvalTensor* rhs_tensor = adj_y ? const_cast(rhs) : op_data->rhs_transposed_tensor; diff --git a/tensorflow/lite/micro/kernels/batch_matmul_test.cc b/tensorflow/lite/micro/kernels/batch_matmul_test.cc index 712984f6054..8806ea1d16a 100644 --- a/tensorflow/lite/micro/kernels/batch_matmul_test.cc +++ b/tensorflow/lite/micro/kernels/batch_matmul_test.cc @@ -391,6 +391,66 @@ TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_Broadcast) { output_data); } +TEST(BatchMatmulTest, + BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_LHSAdjoint) { + constexpr int kLhsInputDims[] = {3, 2, 3, 2}; + constexpr int kRhsInputDims[] = {3, 1, 3, 4}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr float kLhsInput[] = {1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12}; + + constexpr size_t kRhsInputSize = 12; + float rhs_input[kRhsInputSize]; + std::iota(std::begin(rhs_input), std::end(rhs_input), 7); + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218., + 272., 296., 320., 344., 371., 404., 437., 470.}; + constexpr int kOutputDims[] = {3, 2, 2, 4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + constexpr TfLiteBatchMatMulParams params = { + true, // adj_x + false, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput, + rhs_input, kOutputDims, kExpect, + output_data); +} + +TEST(BatchMatmulTest, + BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_RHSAdjoint) { + constexpr int kLhsInputDims[] = {3, 2, 2, 3}; + constexpr int kRhsInputDims[] = {3, 1, 4, 3}; + const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims, + kRhsInputDims}; + + constexpr size_t kLhsInputSize = 12; + float lhs_input[kLhsInputSize]; + std::iota(std::begin(lhs_input), std::end(lhs_input), 1); + + constexpr float kRhsInput[] = {7, 11, 15, 8, 12, 16, 9, 13, 17, 10, 14, 18}; + + constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218., + 272., 296., 320., 344., 371., 404., 437., 470.}; + + constexpr int kOutputDims[] = {3, 2, 2, 4}; + constexpr int kOutputCount = std::extent::value; + + float output_data[kOutputCount]; + constexpr TfLiteBatchMatMulParams params = { + false, // adj_x + true, // adj_y + false // asymmetric_quantize_inputs + }; + + tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input, + kRhsInput, kOutputDims, kExpect, + output_data); +} + TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_BroadcastLHSAdjoint) { constexpr int kLhsInputDims[] = {3, 2, 3, 2}; constexpr int kRhsInputDims[] = {2, 3, 4};