@@ -608,56 +608,28 @@ transpose16x16b32(T &w1, T &w2, T &w3, T &w4, T &w5, T &w6, T &w7, T &w8, T &w9,
608608template <typename T>
609609__device__ __forceinline__ void transpose4x4b32 (T &w1, T &w2, T &w3, T &w4,
610610 T v1, T v2, T v3, T v4) {
611- const uint64_t mask1a = 0x5555555555555555ULL ;
612- const uint64_t mask1b = 0xaaaaaaaaaaaaaaaaULL ;
613- const uint64_t mask2a = 0x3333333333333333ULL ;
614- const uint64_t mask2b = 0xccccccccccccccccULL ;
615-
616- T u1, u2, u3, u4;
617-
618- // a bit suboptimal: we use four more s_movs than specified, because every
619- // asm statement reloads vcc (otherwise a register could have been overwritten)
620-
621- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (0 , 0 , 2 , 2 , " %[u1]" , " %[v2]" ,
622- " %[v1]" )
623- : [u1] " =v" (u1)
624- : [mask] " s" (mask1a), [v1] " v" (v1), [v2] " v" (v2)
625- : " vcc" );
626- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (0 , 0 , 2 , 2 , " %[u3]" , " %[v4]" ,
627- " %[v3]" )
628- : [u3] " =v" (u3)
629- : [mask] " s" (mask1a), [v3] " v" (v3), [v4] " v" (v4)
630- : " vcc" );
631- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (1 , 1 , 3 , 3 , " %[u2]" , " %[v1]" ,
632- " %[v2]" )
633- : [u2] " =v" (u2)
634- : [mask] " s" (mask1b), [v1] " v" (v1), [v2] " v" (v2)
635- : " vcc" );
636- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (1 , 1 , 3 , 3 , " %[u4]" , " %[v3]" ,
637- " %[v4]" )
638- : [u4] " =v" (u4)
639- : [mask] " s" (mask1b), [v3] " v" (v3), [v4] " v" (v4)
640- : " vcc" );
641- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (0 , 1 , 0 , 1 , " %[w1]" , " %[u3]" ,
642- " %[u1]" )
643- : [w1] " =v" (w1)
644- : [mask] " s" (mask2a), [u1] " v" (u1), [u3] " v" (u3)
645- : " vcc" );
646- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (0 , 1 , 0 , 1 , " %[w2]" , " %[u4]" ,
647- " %[u2]" )
648- : [w2] " =v" (w2)
649- : [mask] " s" (mask2a), [u2] " v" (u2), [u4] " v" (u4)
650- : " vcc" );
651- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (2 , 3 , 2 , 3 , " %[w3]" , " %[u1]" ,
652- " %[u3]" )
653- : [w3] " =v" (w3)
654- : [mask] " s" (mask2b), [u1] " v" (u1), [u3] " v" (u3)
655- : " vcc" );
656- __asm (" s_mov_b64 vcc, %[mask] \n\t " CM4STR (2 , 3 , 2 , 3 , " %[w4]" , " %[u2]" ,
657- " %[u4]" )
658- : [w4] " =v" (w4)
659- : [mask] " s" (mask2b), [u2] " v" (u2), [u4] " v" (u4)
660- : " vcc" );
611+ // Writes w1..w4 from v1..v4 in two per-lane select stages: first by lane parity, then by which half of each group of four lanes the lane is in. REMARK: we could combine DPP with cndmask (I tried it with inline assembly,
612+ // but could not get it to work without errors, alas).
613+
614+ const auto vv2 = dpp<0xa0 , 0xf , 0xf , true >(v2); // NOTE(review): 0xa0 presumably encodes quad_perm [0,0,2,2], as in the removed CM4STR (0, 0, 2, 2, ...) -- confirm against dpp<>
615+ const auto vv4 = dpp<0xa0 , 0xf , 0xf , true >(v4); // same selector as vv2, applied to v4
616+ const auto vv1 = dpp<0xf5 , 0xf , 0xf , true >(v1); // NOTE(review): 0xf5 presumably encodes quad_perm [1,1,3,3] -- confirm against dpp<>
617+ const auto vv3 = dpp<0xf5 , 0xf , 0xf , true >(v3); // same selector as vv1, applied to v3
618+
619+ const auto u1 = __lane_id () % 2 == 0 ? v1 : vv2; // even lanes keep v1, odd lanes take vv2
620+ const auto u2 = __lane_id () % 2 == 1 ? v2 : vv1; // odd lanes keep v2, even lanes take vv1
621+ const auto u3 = __lane_id () % 2 == 0 ? v3 : vv4; // even lanes keep v3, odd lanes take vv4
622+ const auto u4 = __lane_id () % 2 == 1 ? v4 : vv3; // odd lanes keep v4, even lanes take vv3
623+
624+ const auto uu1 = dpp<0xee , 0xf , 0xf , true >(u1); // NOTE(review): 0xee presumably encodes quad_perm [2,3,2,3] -- confirm against dpp<>
625+ const auto uu2 = dpp<0xee , 0xf , 0xf , true >(u2); // same selector as uu1, applied to u2
626+ const auto uu3 = dpp<0x44 , 0xf , 0xf , true >(u3); // NOTE(review): 0x44 presumably encodes quad_perm [0,1,0,1] -- confirm against dpp<>
627+ const auto uu4 = dpp<0x44 , 0xf , 0xf , true >(u4); // same selector as uu3, applied to u4
628+
629+ w1 = __lane_id () % 4 < 2 ? u1 : uu3; // lanes with lane % 4 in {0,1} keep u1, the others take uu3
630+ w2 = __lane_id () % 4 < 2 ? u2 : uu4; // lanes with lane % 4 in {0,1} keep u2, the others take uu4
631+ w3 = __lane_id () % 4 >= 2 ? u3 : uu1; // lanes with lane % 4 in {2,3} keep u3, the others take uu1
632+ w4 = __lane_id () % 4 >= 2 ? u4 : uu2; // lanes with lane % 4 in {2,3} keep u4, the others take uu2
661633}
662634
663635/*
0 commit comments