@torch_op("aten::_grouped_mm", trace_only=True)
def aten_grouped_mm(
    self: TFloat,
    mat2: TFloat,
    offs: Optional[TInt] = None,
    bias: Optional[TFloat] = None,
    out_dtype: Optional[int] = None,
) -> TFloat:
    """_grouped_mm(Tensor self, Tensor mat2, *, Tensor? offs=None, Tensor? bias=None, int? out_dtype=None) -> Tensor

    Converter for ``aten::_grouped_mm.default``. Only the dense/batch mode
    (``offs is None``) is handled: the groups are implicit in the leading
    batch dimension, so a single batched MatMul computes all groups at once.

    NOTE: ``trace_only=True`` is required — the Python-level ``offs is None``
    branching on an optional input and the ``raise`` below cannot be compiled
    by the onnxscript function translator; they must run at trace time.
    """
    if offs is not None:
        # Grouped ("jagged") mode with explicit group offsets is not
        # implemented yet; fail loudly rather than emit a wrong graph.
        raise NotImplementedError("aten::_grouped_mm with 'offs' is not supported.")

    # Dense/batch mode: self is (G, M, K), mat2 is (G, K, N); the batched
    # MatMul produces (G, M, N) — one matmul per group along the batch axis.
    result = op.MatMul(self, mat2)

    if bias is not None:
        # NOTE(review): this relies on ONNX Add broadcasting. A bias of shape
        # (G, N) does NOT broadcast against (G, M, N) (trailing-dim alignment
        # would pair G with M); it would need an Unsqueeze to (G, 1, N) first.
        # Confirm the bias shape contract of aten::_grouped_mm before relying
        # on this path — TODO verify.
        result = op.Add(result, bias)

    if out_dtype is not None:
        # NOTE(review): Cast's `to=` attribute expects an ONNX TensorProto
        # dtype enum value; this assumes the dispatcher has already mapped the
        # torch dtype to the ONNX one before calling — confirm, since passing
        # a raw torch dtype int here would cast to the wrong type.
        result = op.Cast(result, to=out_dtype)

    return result