Skip to content

Commit acec774

Browse files
zhang-hui-yulo and zhang hui
authored
HIP: Refactor mma for RDNA and CDNA (ggml-org#17990)
* mma.cuh for rdna4 * mma for rdna3 * mmq for rdna4 * mmq for rdna3 * align i-major and j-major * cdna * fix cuda error * add missing tile of mfma * fix j-major wrong ne on CDNA * fix grammar and empty spaces --------- Co-authored-by: zhang hui <you@example.com>
1 parent 5c0d188 commit acec774

3 files changed

Lines changed: 179 additions & 150 deletions

File tree

ggml/src/ggml-cuda/mma.cuh

Lines changed: 129 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
7676
// For the A/C matrices this means I major == row major, J major == column major.
7777
// For the B matrix this means I major == column major, J major == row major.
7878
// MIRRORED == Each data value is held exactly once per thread subgroup.
79-
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
80-
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
81-
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
79+
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
80+
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
81+
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
82+
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
83+
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
8284
};
8385
// Implemented mma combinations are:
8486
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
8587
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
8688
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
8789

90+
constexpr bool is_i_major(const data_layout dl) {
91+
return dl == DATA_LAYOUT_I_MAJOR ||
92+
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
93+
dl == DATA_LAYOUT_I_MAJOR_DUAL;
94+
}
95+
96+
constexpr data_layout get_input_data_layout() {
97+
#if defined(RDNA3)
98+
return DATA_LAYOUT_I_MAJOR_DUAL;
99+
#else
100+
return DATA_LAYOUT_I_MAJOR;
101+
#endif // defined(RDNA3)
102+
}
103+
88104
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
89105
struct tile {};
90106

@@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
115131
} else if constexpr (I == 32 && J == 4) {
116132
return threadIdx.x % 32;
117133
} else if constexpr (I == 16 && J == 16) {
118-
return 4 * (threadIdx.x / 16) + l;
134+
return threadIdx.x % 16;
119135
} else if constexpr (I == 32 && J == 32) {
120-
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
136+
return threadIdx.x % 32;
121137
} else {
122138
NO_DEVICE_CODE;
123139
return -1;
@@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
132148
} else if constexpr (I == 32 && J == 4) {
133149
return 2 * (threadIdx.x / 32) + l;
134150
} else if constexpr (I == 16 && J == 16) {
135-
return threadIdx.x % 16;
151+
return 4 * (threadIdx.x / 16) + l;
136152
} else if constexpr (I == 32 && J == 32) {
137-
return threadIdx.x % 32;
153+
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
138154
} else {
139155
NO_DEVICE_CODE;
140156
return -1;
@@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
171187
}
172188
}
173189
#elif defined(AMD_WMMA_AVAILABLE)
174-
#if defined(RDNA4)
175190
static constexpr int ne = I * J / 32;
176-
#elif defined(RDNA3)
177-
static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
178-
#endif // defined(RDNA4)
179191
T x[ne] = {0};
180192

181193
static constexpr __device__ bool supported() {
182194
if (I == 16 && J == 16) return true;
195+
if (I == 16 && J == 8) return true;
196+
if (I == 16 && J == 4) return true;
183197
return false;
184198
}
185199

186200
static __device__ __forceinline__ int get_i(const int l) {
187-
if constexpr (I == 16 && J == 16) {
188-
#if defined(RDNA4)
189-
return 8 * (threadIdx.x / 16) + l;
190-
#elif defined(RDNA3)
191-
return 2 * l + (threadIdx.x / 16);
192-
#else
193-
NO_DEVICE_CODE;
194-
return -1;
195-
#endif // defined(RDNA4)
201+
if constexpr (supported()) {
202+
return threadIdx.x % 16;
196203
} else {
197204
NO_DEVICE_CODE;
198205
return -1;
@@ -201,7 +208,17 @@ namespace ggml_cuda_mma {
201208

202209
static __device__ __forceinline__ int get_j(const int l) {
203210
if constexpr (I == 16 && J == 16) {
204-
return threadIdx.x % 16;
211+
// matrix C
212+
#if defined(RDNA3)
213+
return 2 * l + (threadIdx.x / 16);
214+
#else
215+
return ne * (threadIdx.x / 16) + l;
216+
#endif // defined(RDNA3)
217+
} else if constexpr (I == 16 && J == 8) {
218+
// mmq input for RDNA4
219+
return ne * (threadIdx.x / 16) + l;
220+
} else if constexpr (I == 16 && J == 4) {
221+
return ne * (threadIdx.x / 16) + l;
205222
} else {
206223
NO_DEVICE_CODE;
207224
return -1;
@@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
293310
}
294311
}
295312
#elif defined(AMD_WMMA_AVAILABLE)
296-
#if defined(RDNA3)
297-
// RDNA3 has duplicated data as input.
298-
static constexpr int ne = I * J / 32 * 2;
299-
#else
300313
static constexpr int ne = I * J / 32;
301-
#endif // defined(RDNA3)
302314
half2 x[ne] = {{0.0f, 0.0f}};
303315

304316
static constexpr __device__ bool supported() {
@@ -317,14 +329,7 @@ namespace ggml_cuda_mma {
317329

318330
static __device__ __forceinline__ int get_j(const int l) {
319331
if constexpr (I == 16 && J == 8) {
320-
#if defined(RDNA4)
321332
return 4 * (threadIdx.x / 16) + l;
322-
#elif defined(RDNA3)
323-
return l;
324-
#else
325-
NO_DEVICE_CODE;
326-
return -1;
327-
#endif // defined(RDNA4)
328333
} else {
329334
NO_DEVICE_CODE;
330335
return -1;
@@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
382387
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
383388

384389
#if defined(AMD_WMMA_AVAILABLE)
385-
#if defined(RDNA3)
386-
// RDNA3 has duplicated data as input.
387-
static constexpr int ne = I * J / 32 * 2;
388-
#else
389390
static constexpr int ne = I * J / 32;
390-
#endif // defined(RDNA3)
391391
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
392392

393393
static constexpr __device__ bool supported() {
394-
if (I == 16 && J == 8) return true;
395-
return false;
394+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
396395
}
397396

398397
static __device__ __forceinline__ int get_i(const int l) {
399-
if constexpr (I == 16 && J == 8) {
400-
return threadIdx.x % 16;
401-
} else {
402-
NO_DEVICE_CODE;
403-
return -1;
404-
}
398+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
405399
}
406400

407401
static __device__ __forceinline__ int get_j(const int l) {
408-
if constexpr (I == 16 && J == 8) {
409-
#if defined(RDNA4)
410-
return 4 * (threadIdx.x / 16) + l;
411-
#elif defined(RDNA3)
412-
return l;
413-
#else
414-
NO_DEVICE_CODE;
415-
return -1;
416-
#endif // defined(RDNA4)
417-
} else {
418-
NO_DEVICE_CODE;
419-
return -1;
420-
}
402+
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
421403
}
422404
#else
423405
static constexpr int ne = I * J / WARP_SIZE;
@@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
458440
#endif // defined(AMD_WMMA_AVAILABLE)
459441
};
460442

443+
template <int I_, int J_, typename T>
444+
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
445+
static constexpr int I = I_;
446+
static constexpr int J = J_;
447+
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
448+
449+
static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
450+
T x[ne] = {0};
451+
452+
static constexpr __device__ bool supported() {
453+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
454+
}
455+
456+
static __device__ __forceinline__ int get_i(const int l) {
457+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
458+
}
459+
460+
static __device__ __forceinline__ int get_j(const int l) {
461+
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
462+
}
463+
};
464+
461465
template <int I_, int J_>
462466
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
463467
static constexpr int I = I_;
@@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
524528
}
525529
};
526530

531+
template <int I_, int J_, typename T>
532+
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
533+
static constexpr int I = I_;
534+
static constexpr int J = J_;
535+
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
536+
537+
static constexpr int ne = I * J / 32 * 2;
538+
539+
T x[ne] = {0};
540+
541+
static constexpr __device__ bool supported() {
542+
if (I == 16 && J == 16) return true;
543+
if (I == 16 && J == 8) return true;
544+
if (I == 16 && J == 4) return true;
545+
return false;
546+
}
547+
548+
static __device__ __forceinline__ int get_i(const int l) {
549+
if constexpr (supported()) {
550+
return threadIdx.x % 16;
551+
} else {
552+
NO_DEVICE_CODE;
553+
return -1;
554+
}
555+
}
556+
557+
static __device__ __forceinline__ int get_j(const int l) {
558+
if constexpr (supported()) {
559+
return l;
560+
} else {
561+
NO_DEVICE_CODE;
562+
return -1;
563+
}
564+
}
565+
};
566+
527567
#if defined(TURING_MMA_AVAILABLE)
528568
template <int I, int J>
529569
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
569609
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
570610
}
571611
} else {
572-
int64_t * xi = (int64_t *) t.x;
573-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
574-
xi[0] = xs[0];
612+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
575613
}
576614
#elif defined(AMD_WMMA_AVAILABLE)
577-
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
578-
#if defined(RDNA4)
579-
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
580-
#elif defined(RDNA3)
581-
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
582-
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
583-
#else
584-
NO_DEVICE_CODE;
585-
#endif // defined(RDNA4)
586-
} else if constexpr (std::is_same_v<T, int>) {
587-
if constexpr (I == 16 && J == 4) {
588-
int64_t * xi = (int64_t *) t.x;
589-
#if defined(RDNA4)
590-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
591-
xi[0] = xs[0];
592-
#elif defined(RDNA3)
593-
static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
594-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
595-
xi[0] = xs[0];
596-
xi[1] = xs[1];
597-
#endif // defined(RDNA4)
598-
} else if constexpr (I == 16 && J == 8) {
599-
int64_t * xi = (int64_t *) t.x;
600-
#if defined(RDNA4)
601-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
602-
xi[0] = xs[0];
603-
604-
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
605-
xi[1] = xs1[0];
606-
#elif defined(RDNA3)
607-
static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
608-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
609-
// contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
610-
xi[0] = xs[0];
611-
xi[1] = xs[1];
612-
const int64_t * xs1 = xs + 2;
613-
xi[2] = xs1[0];
614-
xi[3] = xs1[1];
615-
#endif // defined(RDNA4)
615+
// All wmma layout has contiguous data when i-major.
616+
if constexpr (is_i_major(dl)) {
617+
// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
618+
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
619+
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
620+
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
621+
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
622+
#pragma unroll
623+
for (int i = 0; i < aligned_copy_count; ++i) {
624+
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
625+
}
616626
} else {
617-
NO_DEVICE_CODE;
627+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
618628
}
619629
} else {
620-
NO_DEVICE_CODE;
630+
#pragma unroll
631+
for (int l = 0; l < t.ne; ++l) {
632+
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
633+
}
621634
}
622635
#else
623636
#pragma unroll
@@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
660673
#endif // TURING_MMA_AVAILABLE
661674
}
662675

663-
template <typename T>
676+
template <typename T, data_layout dl>
664677
static __device__ __forceinline__ void load_ldmatrix(
665-
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
678+
tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
666679
#if defined(TURING_MMA_AVAILABLE)
667680
int * xi = (int * ) t.x;
668681
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
832845
#endif // TURING_MMA_AVAILABLE
833846
}
834847

848+
template <data_layout dl_ab, data_layout dl_d>
835849
static __device__ __forceinline__ void mma(
836-
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
850+
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
837851
#ifdef AMPERE_MMA_AVAILABLE
838852
const int * Axi = (const int *) A.x;
839853
const int * Bxi = (const int *) B.x;
@@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
887901
#endif // AMPERE_MMA_AVAILABLE
888902
}
889903

904+
template <data_layout dl_ab, data_layout dl_d>
890905
static __device__ __forceinline__ void mma(
891-
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
906+
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
892907
#ifdef TURING_MMA_AVAILABLE
893908
const int * Axi = (const int *) A.x;
894909
const int * Bxi = (const int *) B.x;
@@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
940955
#endif // TURING_MMA_AVAILABLE
941956
}
942957

958+
template <data_layout dl_ab, data_layout dl_d>
943959
static __device__ __forceinline__ void mma(
944-
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
960+
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
945961
#if defined(AMD_WMMA_AVAILABLE)
946962
#if defined(RDNA4)
947963
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
967983
#endif // AMPERE_MMA_AVAILABLE
968984
}
969985

986+
template <data_layout dl_d, data_layout dl_ab>
970987
static __device__ __forceinline__ void mma(
971-
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
988+
tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
972989
#if defined(AMD_MFMA_AVAILABLE)
973990
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
974991
int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
11221139
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
11231140
}
11241141

1125-
static __device__ __forceinline__ void mma(
1126-
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
1142+
template <data_layout dl_d, data_layout dl_ab>
1143+
static __device__ __forceinline__ void mma(
1144+
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
11271145
#if defined(AMD_WMMA_AVAILABLE)
11281146
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
11291147
int32x8_t * acc = (int32x8_t *) D.x;

0 commit comments

Comments
 (0)