From 1224fc2921679ada4b97fc6f0beaf3408e2c9034 Mon Sep 17 00:00:00 2001 From: zhangzefeng Date: Fri, 31 Mar 2023 11:48:00 +0800 Subject: [PATCH] fix bug of matmul op's test --- python/conformance/diopi_functions.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/conformance/diopi_functions.py b/python/conformance/diopi_functions.py index f59487c..cfba621 100644 --- a/python/conformance/diopi_functions.py +++ b/python/conformance/diopi_functions.py @@ -407,19 +407,28 @@ def matmul(input, other) -> Tensor: out = Tensor((), input.get_dtype()) # (batched) matrix x vector elif len(sizeO) == 1: - sizeI[-1] = 1 + sizeI.pop() out = Tensor(sizeI, input.get_dtype()) # pretended matrix x (batched) matrix elif len(sizeI) == 1: - sizeO[-2] = 1 + sizeO.pop(-2) out = Tensor(sizeO, input.get_dtype()) # (batched) matrix x (batched) matrix else: - sizeI[-1] = sizeO[-1] - if len(sizeI) > 3 and len(sizeO) > 2: - assert sizeI[-3] == sizeO[-3] or sizeI[-3] == 1 or sizeO[-3] == 1,\ - 'input and other should be broadcastable' - sizeI[-3] = sizeI[-3] if sizeI[-3] == 1 else sizeO[-3] + if len(sizeI) < len(sizeO): + for i in range(len(sizeO) - len(sizeI)): + sizeI.insert(0, 1) + elif len(sizeI) > len(sizeO): + for i in range(len(sizeI) - len(sizeO)): + sizeO.insert(0, 1) + + assert sizeI[-1] == sizeO[-2], 'can not execute matmul because shape not match' + for i in range(len(sizeI) - 2): + assert sizeI[i] == sizeO[i] or sizeI[i] == 1 or sizeO[i] == 1, 'input and other should be broadcastable' + if sizeI[i] == 1: + sizeI[i] = sizeO[i] + sizeI.pop() + sizeI.append(sizeO[-1]) out = Tensor(sizeI, input.get_dtype()) func = check_function("diopiMatmul")