@@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
 // For the A/C matrices this means I major == row major, J major == column major.
 // For the B matrix this means I major == column major, J major == row major.
 // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell; matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4; int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL     = 40, // Matrix A&B for RDNA3.
     };
     // Implemented mma combinations are:
     // - (I_MAJOR, I_MAJOR)          -> I_MAJOR
     // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
     // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
 
+    constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR          ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
+               dl == DATA_LAYOUT_I_MAJOR_DUAL;
+    }
+
+    constexpr data_layout get_input_data_layout() {
+#if defined(RDNA3)
+        return DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    }
+
     template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};
 
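For orientation, a minimal sketch of how these helpers are meant to be used when declaring input fragments; the tile types come from this file, but the usage itself is hypothetical and not part of the patch:

    // Pick the input layout for the current target: DUAL on RDNA3,
    // plain i-major everywhere else.
    constexpr data_layout dl_in = get_input_data_layout();
    tile<16, 8, half2, dl_in> A; // matrix A fragment
    static_assert(is_i_major(dl_in), "input fragments are always an i-major variant");
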
@@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return threadIdx.x % 32;
             } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
+                return threadIdx.x % 16;
             } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+                return threadIdx.x % 32;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return 2 * (threadIdx.x / 32) + l;
             } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                return 4 * (threadIdx.x / 16) + l;
             } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
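To see that the exchanged mapping is still a bijection over the tile, a standalone host-side check (hypothetical test code, not part of the patch; assumes CDNA's 64-lane wavefront, so ne = 16*16/64 = 4 for the accumulator):

    #include <cassert>

    int main() {
        bool seen[16][16] = {};
        for (int lane = 0; lane < 64; ++lane) { // 64 lanes per CDNA wavefront
            for (int l = 0; l < 4; ++l) {       // ne = 4
                const int i = lane % 16;            // new get_i
                const int j = 4 * (lane / 16) + l;  // new get_j
                assert(!seen[i][j]);                // each element covered exactly once
                seen[i][j] = true;
            }
        }
        return 0;
    }
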
@@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
         static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 16) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J ==  4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -201,7 +208,17 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
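The same kind of standalone check for the 32-lane WMMA matrix-C mappings above: RDNA3 interleaves columns as 2*l + half-wave, while the RDNA4 path gives each 16-lane half a contiguous run of ne columns (hypothetical test code, not part of the patch):

    #include <cassert>

    int main() {
        constexpr int ne = 16 * 16 / 32; // 8 elements per lane
        for (int rdna3 = 0; rdna3 < 2; ++rdna3) {
            bool seen[16][16] = {};
            for (int lane = 0; lane < 32; ++lane) {
                for (int l = 0; l < ne; ++l) {
                    const int i = lane % 16;
                    const int j = rdna3 ? 2 * l + lane / 16    // RDNA3 matrix C
                                        : ne * (lane / 16) + l; // RDNA4 matrix C
                    assert(!seen[i][j]); // each column of row i covered once
                    seen[i][j] = true;
                }
            }
        }
        return 0;
    }
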
@@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
@@ -317,14 +329,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
-            return false;
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
         }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
@@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
+        }
+    };
+
     template <int I_, int J_>
     struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
         static constexpr int I = I_;
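The J_MAJOR specialization stores the same per-lane data as its I_MAJOR counterpart and only exchanges the two index helpers, so a j-major fragment is simply a transposed view. Illustrative sketch, assuming the tile definitions above:

    // For any lane and element index l:
    //   tile<..., J_MAJOR>::get_i(l) == tile<..., I_MAJOR>::get_j(l)
    //   tile<..., J_MAJOR>::get_j(l) == tile<..., I_MAJOR>::get_i(l)
    // i.e. element (i, j) of a j-major C tile sits where (j, i) sits i-major.
    static_assert(tile<16, 16, int, DATA_LAYOUT_J_MAJOR>::ne ==
                  tile<16, 16, int, DATA_LAYOUT_I_MAJOR>::ne, "same storage per lane");
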
@@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+
+        static constexpr int ne = I * J / 32 * 2;
+
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J ==  4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (supported()) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
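A standalone check of the DUAL layout's defining property (hypothetical test code, assuming a 32-lane RDNA3 wave): get_j(l) == l means each lane holds a full row, and because get_i() folds lanes modulo 16, both 16-lane halves hold the same rows, so every element is stored exactly twice:

    #include <cassert>

    int main() {
        constexpr int I = 16, J = 8;
        constexpr int ne = I * J / 32 * 2; // 8: twice the unduplicated count
        int copies[I][J] = {};
        for (int lane = 0; lane < 32; ++lane) {
            for (int l = 0; l < ne; ++l) {
                copies[lane % 16][l] += 1; // get_i(l) == lane % 16, get_j(l) == l
            }
        }
        for (int i = 0; i < I; ++i) {
            for (int j = 0; j < J; ++j) {
                assert(copies[i][j] == 2); // held once per 16-lane half
            }
        }
        return 0;
    }
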
@@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
                 t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
             }
         } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#if defined(RDNA4)
-            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x,          xs0 + t.get_i(0) * stride + t.get_j(0));
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-            NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v<T, int>) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
+        // All WMMA layouts keep each lane's data contiguous when i-major.
+        if constexpr (is_i_major(dl)) {
+            // The data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes().
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
             } else {
-                NO_DEVICE_CODE;
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
             }
         } else {
-            NO_DEVICE_CODE;
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
         }
 #else
 #pragma unroll
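To make the chunked path concrete, a worked example under stated assumptions (RDNA3, a 16x8 half2 input fragment in the DUAL layout, ggml_cuda_get_max_cpy_bytes() == 16; the actual limit is whatever that helper returns):

    // ne                 = 16*8/32 * 2 = 8 half2 elements, 4 bytes each
    // sizeof(t.x)        = 8 * 4       = 32 bytes, which exceeds 16
    // aligned_copy_count = 32 / 16     = 2
    // => two 16-byte copies: t.x + 0 from column t.get_j(0), and t.x + 4 from
    //    column t.get_j(4), i.e. the same pair of half-fragment copies the
    //    removed RDNA3-specific branch spelled out by hand.
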
@@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
-    static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
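At call sites the new layout parameters are deduced from the argument tiles, so existing i-major callers compile unchanged. A hypothetical RDNA3-style call mixing the new layouts (sketch, not taken from the patch):

    tile<16, 16, int, DATA_LAYOUT_J_MAJOR>     C; // accumulator, j-major on RDNA
    tile<16,  4, int, get_input_data_layout()> A; // I_MAJOR_DUAL on RDNA3
    tile<16,  4, int, get_input_data_layout()> B;
    // deduces mma<dl_d = DATA_LAYOUT_J_MAJOR, dl_ab = get_input_data_layout()>
    mma(C, A, B);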