@@ -340,6 +340,46 @@ TfLiteStatus EvalInt16(TfLiteContext* context, const OpDataBatchMatmul& data,
340340 return kTfLiteOk ;
341341}
342342
343+ void ReshapeToFlattenBatchDimsIfPossible (bool adj_x, RuntimeShape* lhs_shape,
344+ RuntimeShape* rhs_shape) {
345+ // Compress BatchMatMul when third from last RHS dimension is one.
346+ int32_t rhs_dims_count = rhs_shape->DimensionsCount ();
347+ int32_t lhs_dims_count = lhs_shape->DimensionsCount ();
348+
349+ // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
350+ // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
351+ // lhs: [..., Q * R, S].
352+ //
353+ // We can only flatten the dimensions if the physical layout in memory
354+ // allows us to treat [Batch, Row, Col] as [Batch * Row, Col].
355+ // This requires the 'Row' dimension to be contiguous with the 'Batch'
356+ // dimension.
357+ //
358+ // If adj_x is true, the logical operation is on transposed matrices.
359+ // The physical layout is [Batch, Row, Col] but logically we access [Batch,
360+ // Col, Row]. Flattening Batch and Row dimensions (physically) results in
361+ // [Batch*Row, Col]. This does not match the logical expectation of [Batch,
362+ // Col, Row] because the columns of the second batch are not contiguous with
363+ // the columns of the first batch in the logical transposed view. Therefore,
364+ // we disable this optimization when adj_x is true.
365+ if (!adj_x && rhs_dims_count > 2 && lhs_dims_count > 2 ) {
366+ int rhs_one = rhs_shape->DimsData ()[rhs_dims_count - 3 ];
367+ if (rhs_one == 1 ) {
368+ int32_t * lhs_dims = lhs_shape->DimsData ();
369+ int32_t * rhs_dims = rhs_shape->DimsData ();
370+ RuntimeShape tmp_l (lhs_dims_count - 1 , lhs_dims);
371+ tmp_l.SetDim (lhs_dims_count - 3 ,
372+ lhs_dims[lhs_dims_count - 3 ] * lhs_dims[lhs_dims_count - 2 ]);
373+ tmp_l.SetDim (lhs_dims_count - 2 , lhs_dims[lhs_dims_count - 1 ]);
374+ lhs_shape->ReplaceWith (tmp_l.DimensionsCount (), tmp_l.DimsData ());
375+ RuntimeShape tmp_r (rhs_dims_count - 1 , rhs_shape->DimsData ());
376+ tmp_r.SetDim (rhs_dims_count - 3 , rhs_dims[rhs_dims_count - 2 ]);
377+ tmp_r.SetDim (rhs_dims_count - 2 , rhs_dims[rhs_dims_count - 1 ]);
378+ rhs_shape->ReplaceWith (tmp_r.DimensionsCount (), tmp_r.DimsData ());
379+ }
380+ }
381+ }
382+
343383// Perform a batch matrix multiply on
344384// LHS <..., A, B> X RHS<..., B, C>
345385// where the leading dimensions of LHS and RHS obey broadcasting rules
@@ -363,30 +403,7 @@ TfLiteStatus BatchMatMulEval(TfLiteContext* context, TfLiteNode* node) {
363403 bool adj_y = op_context.params ->adj_y ;
364404 bool adj_x = op_context.params ->adj_x ;
365405
366- // Compress BatchMatMul when third from last RHS dimension is one.
367- int32_t rhs_dims_count = orig_rhs_shape.DimensionsCount ();
368- int32_t lhs_dims_count = orig_lhs_shape.DimensionsCount ();
369- // Compress ops where rhs shape is [..., 1, X, Y] and lhs shape is
370- // [..., Q, R, S] which is equivalent to rhs: [..., X, Y] and
371- // lhs: [..., Q * R, S].
372- if (rhs_dims_count > 2 && lhs_dims_count > 2 ) {
373- int rhs_one = orig_rhs_shape.DimsData ()[rhs_dims_count - 3 ];
374- if (rhs_one == 1 ) {
375- int32_t * lhs_dims = orig_lhs_shape.DimsData ();
376- int32_t * rhs_dims = orig_rhs_shape.DimsData ();
377- RuntimeShape tmp_l (lhs_dims_count - 1 , lhs_dims);
378- tmp_l.SetDim (lhs_dims_count - 3 ,
379- lhs_dims[lhs_dims_count - 3 ] * lhs_dims[lhs_dims_count - 2 ]);
380- tmp_l.SetDim (lhs_dims_count - 2 , lhs_dims[lhs_dims_count - 1 ]);
381- orig_lhs_shape.ReplaceWith (tmp_l.DimensionsCount (), tmp_l.DimsData ());
382- RuntimeShape tmp_r (rhs_dims_count - 1 , orig_rhs_shape.DimsData ());
383- tmp_r.SetDim (rhs_dims_count - 3 , rhs_dims[rhs_dims_count - 2 ]);
384- tmp_r.SetDim (rhs_dims_count - 2 , rhs_dims[rhs_dims_count - 1 ]);
385- orig_rhs_shape.ReplaceWith (tmp_r.DimensionsCount (), tmp_r.DimsData ());
386- rhs_dims_count = orig_rhs_shape.DimensionsCount ();
387- lhs_dims_count = orig_lhs_shape.DimensionsCount ();
388- }
389- }
406+ ReshapeToFlattenBatchDimsIfPossible (adj_x, &orig_lhs_shape, &orig_rhs_shape);
390407
391408 TfLiteEvalTensor* rhs_tensor = adj_y ? const_cast <TfLiteEvalTensor*>(rhs)
392409 : op_data->rhs_transposed_tensor ;
0 commit comments