Skip to content

Commit b2c604b

Browse files
committed
Added BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_LHSAdjoint
Added BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_RHSAdjoint
1 parent a6cd602 commit b2c604b

2 files changed

Lines changed: 101 additions & 24 deletions

File tree

tensorflow/lite/micro/kernels/batch_matmul.cc

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,46 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data,
340340
return kTfLiteOk;
341341
}
342342

// Rewrites `lhs_shape` and `rhs_shape` in place so that a broadcast batched
// matmul can be evaluated as a single larger matmul: when the RHS broadcasts
// over the LHS batch dimension (third-from-last RHS dim == 1), the LHS batch
// and row dims are fused and the degenerate RHS batch dim is dropped.
// Only the shape metadata changes; the tensor data itself is untouched, which
// is valid because the flattened view walks the same memory in the same order.
void ReshapeToFlattenBatchDimsIfPossible(bool adj_x, RuntimeShape* lhs_shape,
                                         RuntimeShape* rhs_shape) {
  // Compress BatchMatMul when third from last RHS dimension is one.
  int32_t rhs_dims_count = rhs_shape->DimensionsCount();
  int32_t lhs_dims_count = lhs_shape->DimensionsCount();

  // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
  // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
  // lhs: [..., Q * R, S].
  //
  // We can only flatten the dimensions if the physical layout in memory
  // allows us to treat [Batch, Row, Col] as [Batch * Row, Col].
  // This requires the 'Row' dimension to be contiguous with the 'Batch'
  // dimension.
  //
  // If adj_x is true, the logical operation is on transposed matrices.
  // The physical layout is [Batch, Row, Col] but logically we access [Batch,
  // Col, Row]. Flattening Batch and Row dimensions (physically) results in
  // [Batch*Row, Col]. This does not match the logical expectation of [Batch,
  // Col, Row] because the columns of the second batch are not contiguous with
  // the columns of the first batch in the logical transposed view. Therefore,
  // we disable this optimization when adj_x is true.
  if (!adj_x && rhs_dims_count > 2 && lhs_dims_count > 2) {
    int rhs_one = rhs_shape->DimsData()[rhs_dims_count - 3];
    if (rhs_one == 1) {
      // NOTE: `lhs_dims`/`rhs_dims` alias the shapes' own dim buffers, so the
      // new shapes are assembled in the temporaries `tmp_l`/`tmp_r` (which
      // copy the dims) before ReplaceWith overwrites either shape. Keep the
      // reads of lhs_dims/rhs_dims ahead of the corresponding ReplaceWith.
      int32_t* lhs_dims = lhs_shape->DimsData();
      int32_t* rhs_dims = rhs_shape->DimsData();
      // tmp_l starts as [..., Q, R] (last dim dropped), then its two trailing
      // dims are rewritten to give [..., Q * R, S].
      RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims);
      tmp_l.SetDim(lhs_dims_count - 3,
                   lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]);
      tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]);
      lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
      // tmp_r drops the degenerate batch dim: [..., 1, X, Y] -> [..., X, Y].
      RuntimeShape tmp_r(rhs_dims_count - 1, rhs_shape->DimsData());
      tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]);
      tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]);
      rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
    }
  }
}
343383
// Perform a batch matrix multiply on
344384
// LHS <..., A, B> X RHS<..., B, C>
345385
// where the leading dimensions of LHS and RHS obey broadcasting rules
@@ -363,30 +403,7 @@ TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
363403
bool adj_y = op_context.params->adj_y;
364404
bool adj_x = op_context.params->adj_x;
365405

366-
// Compress BatchMatMul when third from last RHS dimension is one.
367-
int32_t rhs_dims_count = orig_rhs_shape.DimensionsCount();
368-
int32_t lhs_dims_count = orig_lhs_shape.DimensionsCount();
369-
// Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
370-
// [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
371-
// lhs: [..., Q * R, S].
372-
if (rhs_dims_count > 2 && lhs_dims_count > 2) {
373-
int rhs_one = orig_rhs_shape.DimsData()[rhs_dims_count - 3];
374-
if (rhs_one == 1) {
375-
int32_t* lhs_dims = orig_lhs_shape.DimsData();
376-
int32_t* rhs_dims = orig_rhs_shape.DimsData();
377-
RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims);
378-
tmp_l.SetDim(lhs_dims_count - 3,
379-
lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]);
380-
tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]);
381-
orig_lhs_shape.ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
382-
RuntimeShape tmp_r(rhs_dims_count - 1, orig_rhs_shape.DimsData());
383-
tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]);
384-
tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]);
385-
orig_rhs_shape.ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
386-
rhs_dims_count = orig_rhs_shape.DimensionsCount();
387-
lhs_dims_count = orig_lhs_shape.DimensionsCount();
388-
}
389-
}
406+
ReshapeToFlattenBatchDimsIfPossible(adj_x, &orig_lhs_shape, &orig_rhs_shape);
390407

391408
TfLiteEvalTensor* rhs_tensor = adj_y ? const_cast<TfLiteEvalTensor*>(rhs)
392409
: op_data->rhs_transposed_tensor;

tensorflow/lite/micro/kernels/batch_matmul_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,66 @@ TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_Broadcast) {
391391
output_data);
392392
}
393393

// Batch-of-two LHS against a broadcast RHS with adj_x set: exercises the
// path where the RHS batch dim is 1 but the flattening optimization must be
// skipped because the LHS is logically transposed.
TEST(BatchMatmulTest,
     BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_LHSAdjoint) {
  // Dims arrays are rank-prefixed (first entry is the rank) per the test
  // helper's convention: LHS is [2, 3, 2], RHS is [1, 3, 4] (broadcast).
  constexpr int kLhsInputDims[] = {3, 2, 3, 2};
  constexpr int kRhsInputDims[] = {3, 1, 3, 4};
  const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims,
                                                        kRhsInputDims};

  // Each 3x2 batch holds the transpose of a 2x3 matrix (1..6 and 7..12), so
  // with adj_x the logical operands are the plain 2x3 matrices.
  constexpr float kLhsInput[] = {1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12};

  // RHS is the sequence 7..18, filled via iota.
  constexpr size_t kRhsInputSize = 12;
  float rhs_input[kRhsInputSize];
  std::iota(std::begin(rhs_input), std::end(rhs_input), 7);

  constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.,
                               272., 296., 320., 344., 371., 404., 437., 470.};
  constexpr int kOutputDims[] = {3, 2, 2, 4};
  constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
  float output_data[kOutputCount];
  constexpr TfLiteBatchMatMulParams params = {
      true,   // adj_x
      false,  // adj_y
      false   // asymmetric_quantize_inputs
  };

  tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput,
                                        rhs_input, kOutputDims, kExpect,
                                        output_data);
}
423+
TEST(BatchMatmulTest,
424+
BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_RHSAdjoint) {
425+
constexpr int kLhsInputDims[] = {3, 2, 2, 3};
426+
constexpr int kRhsInputDims[] = {3, 1, 4, 3};
427+
const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims,
428+
kRhsInputDims};
429+
430+
constexpr size_t kLhsInputSize = 12;
431+
float lhs_input[kLhsInputSize];
432+
std::iota(std::begin(lhs_input), std::end(lhs_input), 1);
433+
434+
constexpr float kRhsInput[] = {7, 11, 15, 8, 12, 16, 9, 13, 17, 10, 14, 18};
435+
436+
constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.,
437+
272., 296., 320., 344., 371., 404., 437., 470.};
438+
439+
constexpr int kOutputDims[] = {3, 2, 2, 4};
440+
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
441+
442+
float output_data[kOutputCount];
443+
constexpr TfLiteBatchMatMulParams params = {
444+
false, // adj_x
445+
true, // adj_y
446+
false // asymmetric_quantize_inputs
447+
};
448+
449+
tflite::testing::TestBatchMatMulFloat(params, kInputDims, lhs_input,
450+
kRhsInput, kOutputDims, kExpect,
451+
output_data);
452+
}
453+
394454
TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_BroadcastLHSAdjoint) {
395455
constexpr int kLhsInputDims[] = {3, 2, 3, 2};
396456
constexpr int kRhsInputDims[] = {2, 3, 4};

0 commit comments

Comments
 (0)