Skip to content

Commit bfb7c4c

Browse files
committed
Replace transpose
1 parent 92a2ddd commit bfb7c4c

1 file changed

Lines changed: 22 additions & 50 deletions

File tree

  • tensorforge/include/tensorforge_device

tensorforge/include/tensorforge_device/hip.h

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -608,56 +608,28 @@ transpose16x16b32(T &w1, T &w2, T &w3, T &w4, T &w5, T &w6, T &w7, T &w8, T &w9,
608608
template <typename T>
609609
__device__ __forceinline__ void transpose4x4b32(T &w1, T &w2, T &w3, T &w4,
610610
T v1, T v2, T v3, T v4) {
611-
const uint64_t mask1a = 0x5555555555555555ULL;
612-
const uint64_t mask1b = 0xaaaaaaaaaaaaaaaaULL;
613-
const uint64_t mask2a = 0x3333333333333333ULL;
614-
const uint64_t mask2b = 0xccccccccccccccccULL;
615-
616-
T u1, u2, u3, u4;
617-
618-
// a bit suboptimal: we four more s_movs than specified
619-
// (otherwise there was the danger of a register override)
620-
621-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(0, 0, 2, 2, "%[u1]", "%[v2]",
622-
"%[v1]")
623-
: [u1] "=v"(u1)
624-
: [mask] "s"(mask1a), [v1] "v"(v1), [v2] "v"(v2)
625-
: "vcc");
626-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(0, 0, 2, 2, "%[u3]", "%[v4]",
627-
"%[v3]")
628-
: [u3] "=v"(u3)
629-
: [mask] "s"(mask1a), [v3] "v"(v3), [v4] "v"(v4)
630-
: "vcc");
631-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(1, 1, 3, 3, "%[u2]", "%[v1]",
632-
"%[v2]")
633-
: [u2] "=v"(u2)
634-
: [mask] "s"(mask1b), [v1] "v"(v1), [v2] "v"(v2)
635-
: "vcc");
636-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(1, 1, 3, 3, "%[u4]", "%[v3]",
637-
"%[v4]")
638-
: [u4] "=v"(u4)
639-
: [mask] "s"(mask1b), [v3] "v"(v3), [v4] "v"(v4)
640-
: "vcc");
641-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(0, 1, 0, 1, "%[w1]", "%[u3]",
642-
"%[u1]")
643-
: [w1] "=v"(w1)
644-
: [mask] "s"(mask2a), [u1] "v"(u1), [u3] "v"(u3)
645-
: "vcc");
646-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(0, 1, 0, 1, "%[w2]", "%[u4]",
647-
"%[u2]")
648-
: [w2] "=v"(w2)
649-
: [mask] "s"(mask2a), [u2] "v"(u2), [u4] "v"(u4)
650-
: "vcc");
651-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(2, 3, 2, 3, "%[w3]", "%[u1]",
652-
"%[u3]")
653-
: [w3] "=v"(w3)
654-
: [mask] "s"(mask2b), [u1] "v"(u1), [u3] "v"(u3)
655-
: "vcc");
656-
__asm("s_mov_b64 vcc, %[mask] \n\t" CM4STR(2, 3, 2, 3, "%[w4]", "%[u2]",
657-
"%[u4]")
658-
: [w4] "=v"(w4)
659-
: [mask] "s"(mask2b), [u2] "v"(u2), [u4] "v"(u4)
660-
: "vcc");
611+
// REMARK: we could combine DPP with cndmask (I tried it with inline assembly;
612+
// but it didn't work w/o errors alas)
613+
614+
const auto vv2 = dpp<0xa0, 0xf, 0xf, true>(v2);
615+
const auto vv4 = dpp<0xa0, 0xf, 0xf, true>(v4);
616+
const auto vv1 = dpp<0xf5, 0xf, 0xf, true>(v1);
617+
const auto vv3 = dpp<0xf5, 0xf, 0xf, true>(v3);
618+
619+
const auto u1 = __lane_id() % 2 == 0 ? v1 : vv2;
620+
const auto u2 = __lane_id() % 2 == 1 ? v2 : vv1;
621+
const auto u3 = __lane_id() % 2 == 0 ? v3 : vv4;
622+
const auto u4 = __lane_id() % 2 == 1 ? v4 : vv3;
623+
624+
const auto uu1 = dpp<0xee, 0xf, 0xf, true>(u1);
625+
const auto uu2 = dpp<0xee, 0xf, 0xf, true>(u2);
626+
const auto uu3 = dpp<0x44, 0xf, 0xf, true>(u3);
627+
const auto uu4 = dpp<0x44, 0xf, 0xf, true>(u4);
628+
629+
w1 = __lane_id() % 4 < 2 ? u1 : uu3;
630+
w2 = __lane_id() % 4 < 2 ? u2 : uu4;
631+
w3 = __lane_id() % 4 >= 2 ? u3 : uu1;
632+
w4 = __lane_id() % 4 >= 2 ? u4 : uu2;
661633
}
662634

663635
/*

0 commit comments

Comments
 (0)