Skip to content

Commit b3e89a6

Browse files
committed
fix(kernel): 为 TransposeInfo 应对更多 coner case 并增加一个单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 7db54c1 commit b3e89a6

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/04kernel/src/attributes/transpose_info.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ namespace refactor::kernel {
7373
}
7474
perm.resize(rank);
7575
}
76+
if (rank <= 1) {
77+
dims = {{1, 1}};
78+
blockSize *= blockCount;
79+
blockCount = 1;
80+
return;
81+
}
7682
// 合并末尾连续访存
7783
if (perm.back() == rank - 1) {
7884
blockSize *= shape.back();
@@ -81,7 +87,6 @@ namespace refactor::kernel {
8187
perm.pop_back();
8288
--rank;
8389
}
84-
8590
// 计算 stride
8691
struct StrideI {
8792
dim_t strideI;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "kernel/attributes/transpose_info.h"
2+
#include <gtest/gtest.h>
3+
4+
using namespace refactor;
5+
using namespace kernel;
6+
7+
TEST(kernel, TransposeInfo) {
8+
{
9+
TransposeInfo info(
10+
DataType::F32,
11+
{1, 2, 3, 2, 1},
12+
{1, 2, 3, 0, 4});
13+
EXPECT_EQ(info.blockSize, 48);
14+
EXPECT_EQ(info.blockCount, 1);
15+
EXPECT_EQ(info.dims.size(), 1);
16+
}
17+
{
18+
TransposeInfo info(
19+
DataType::F32,
20+
{1, 1, 2, 1, 1},
21+
{1, 2, 3, 0, 4});
22+
EXPECT_EQ(info.blockSize, 8);
23+
EXPECT_EQ(info.blockCount, 1);
24+
EXPECT_EQ(info.dims.size(), 1);
25+
}
26+
{
27+
TransposeInfo info(
28+
DataType::F32,
29+
{1, 2, 3, 4, 5},
30+
{2, 3, 1, 0, 4});
31+
EXPECT_EQ(info.blockSize, 20);
32+
EXPECT_EQ(info.blockCount, 24);
33+
EXPECT_EQ(info.dims.size(), 2);
34+
EXPECT_EQ(info.dims[1].strideI, 12);
35+
EXPECT_EQ(info.dims[1].strideO, 1);
36+
EXPECT_EQ(info.dims[0].strideI, 1);
37+
EXPECT_EQ(info.dims[0].strideO, 2);
38+
}
39+
}

0 commit comments

Comments
 (0)