Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions tensorflow/lite/micro/kernels/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,46 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data,
return kTfLiteOk;
}

void ReshapeToFlattenBatchDimsIfPossible(bool adj_x, RuntimeShape* lhs_shape,
RuntimeShape* rhs_shape) {
// Compress BatchMatMul when third from last RHS dimension is one.
int32_t rhs_dims_count = rhs_shape->DimensionsCount();
int32_t lhs_dims_count = lhs_shape->DimensionsCount();

// Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
// [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
// lhs: [..., Q * R, S].
//
// We can only flatten the dimensions if the physical layout in memory
// allows us to treat [Batch, Row, Col] as [Batch * Row, Col].
// This requires the 'Row' dimension to be contiguous with the 'Batch'
// dimension.
//
// If adj_x is true, the logical operation is on transposed matrices.
// The physical layout is [Batch, Row, Col] but logically we access [Batch,
// Col, Row]. Flattening Batch and Row dimensions (physically) results in
// [Batch*Row, Col]. This does not match the logical expectation of [Batch,
// Col, Row] because the columns of the second batch are not contiguous with
// the columns of the first batch in the logical transposed view. Therefore,
// we disable this optimization when adj_x is true.
Comment on lines +353 to +364
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. This means the LiteRT code also needs to be corrected.

if (!adj_x && rhs_dims_count > 2 && lhs_dims_count > 2) {
int rhs_one = rhs_shape->DimsData()[rhs_dims_count - 3];
if (rhs_one == 1) {
int32_t* lhs_dims = lhs_shape->DimsData();
int32_t* rhs_dims = rhs_shape->DimsData();
RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims);
tmp_l.SetDim(lhs_dims_count - 3,
lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]);
tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]);
lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
RuntimeShape tmp_r(rhs_dims_count - 1, rhs_shape->DimsData());
tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]);
tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]);
rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
}
}
}

// Perform a batch matrix multiply on
// LHS <..., A, B> X RHS<..., B, C>
// where the leading dimensions of LHS and RHS obey broadcasting rules
Expand All @@ -363,30 +403,7 @@ TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
bool adj_y = op_context.params->adj_y;
bool adj_x = op_context.params->adj_x;

// Compress BatchMatMul when third from last RHS dimension is one.
int32_t rhs_dims_count = orig_rhs_shape.DimensionsCount();
int32_t lhs_dims_count = orig_lhs_shape.DimensionsCount();
// Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
// [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
// lhs: [..., Q * R, S].
if (rhs_dims_count > 2 && lhs_dims_count > 2) {
int rhs_one = orig_rhs_shape.DimsData()[rhs_dims_count - 3];
if (rhs_one == 1) {
int32_t* lhs_dims = orig_lhs_shape.DimsData();
int32_t* rhs_dims = orig_rhs_shape.DimsData();
RuntimeShape tmp_l(lhs_dims_count - 1, lhs_dims);
tmp_l.SetDim(lhs_dims_count - 3,
lhs_dims[lhs_dims_count - 3] * lhs_dims[lhs_dims_count - 2]);
tmp_l.SetDim(lhs_dims_count - 2, lhs_dims[lhs_dims_count - 1]);
orig_lhs_shape.ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
RuntimeShape tmp_r(rhs_dims_count - 1, orig_rhs_shape.DimsData());
tmp_r.SetDim(rhs_dims_count - 3, rhs_dims[rhs_dims_count - 2]);
tmp_r.SetDim(rhs_dims_count - 2, rhs_dims[rhs_dims_count - 1]);
orig_rhs_shape.ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
rhs_dims_count = orig_rhs_shape.DimensionsCount();
lhs_dims_count = orig_lhs_shape.DimensionsCount();
}
}
ReshapeToFlattenBatchDimsIfPossible(adj_x, &orig_lhs_shape, &orig_rhs_shape);

TfLiteEvalTensor* rhs_tensor = adj_y ? const_cast<TfLiteEvalTensor*>(rhs)
: op_data->rhs_transposed_tensor;
Expand Down
60 changes: 60 additions & 0 deletions tensorflow/lite/micro/kernels/batch_matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,66 @@ TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_Broadcast) {
output_data);
}

TEST(BatchMatmulTest,
     BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_LHSAdjoint) {
  // Two LHS batches of 3x2 matrices, read transposed via adj_x; the RHS
  // batch dimension is 1 and is broadcast across both LHS batches.
  constexpr int kLhsInputDims[] = {3, 2, 3, 2};
  constexpr int kRhsInputDims[] = {3, 1, 3, 4};
  const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims,
                                                        kRhsInputDims};

  // With adj_x, each batch is multiplied as its transpose:
  // [[1,2,3],[4,5,6]] and [[7,8,9],[10,11,12]].
  constexpr float kLhsInput[] = {1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12};

  // Consecutive values 7..18 forming the single broadcast 3x4 RHS matrix.
  constexpr float kRhsInput[] = {7,  8,  9,  10, 11, 12,
                                 13, 14, 15, 16, 17, 18};

  constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.,
                               272., 296., 320., 344., 371., 404., 437., 470.};
  constexpr int kOutputDims[] = {3, 2, 2, 4};
  constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
  float output_data[kOutputCount];

  constexpr TfLiteBatchMatMulParams params = {
      true,   // adj_x
      false,  // adj_y
      false   // asymmetric_quantize_inputs
  };

  tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput,
                                        kRhsInput, kOutputDims, kExpect,
                                        output_data);
}

TEST(BatchMatmulTest,
     BatchMatMulOpTestFloat32Test_BatchSizeTwo_Broadcast_RHSAdjoint) {
  // Two LHS batches of 2x3 matrices; the RHS batch dimension is 1 and is
  // broadcast, and the RHS is stored as 4x3 but read transposed via adj_y.
  constexpr int kLhsInputDims[] = {3, 2, 2, 3};
  constexpr int kRhsInputDims[] = {3, 1, 4, 3};
  const int* kInputDims[tflite::testing::kNumInputs] = {kLhsInputDims,
                                                        kRhsInputDims};

  // Consecutive values 1..12 spanning both LHS batches.
  constexpr float kLhsInput[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

  // With adj_y, this 4x3 matrix is multiplied as its 3x4 transpose.
  constexpr float kRhsInput[] = {7, 11, 15, 8,  12, 16,
                                 9, 13, 17, 10, 14, 18};

  constexpr float kExpect[] = {74., 80., 86., 92., 173., 188., 203., 218.,
                               272., 296., 320., 344., 371., 404., 437., 470.};
  constexpr int kOutputDims[] = {3, 2, 2, 4};
  constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
  float output_data[kOutputCount];

  constexpr TfLiteBatchMatMulParams params = {
      false,  // adj_x
      true,   // adj_y
      false   // asymmetric_quantize_inputs
  };

  tflite::testing::TestBatchMatMulFloat(params, kInputDims, kLhsInput,
                                        kRhsInput, kOutputDims, kExpect,
                                        output_data);
}

TEST(BatchMatmulTest, BatchMatMulOpTestFloat32Test_BroadcastLHSAdjoint) {
constexpr int kLhsInputDims[] = {3, 2, 3, 2};
constexpr int kRhsInputDims[] = {2, 3, 4};
Expand Down
Loading