Skip to content

Commit de6e0eb

Browse files
committed
Fix build; start RDNA support
1 parent c501e14 commit de6e0eb

5 files changed

Lines changed: 23 additions & 13 deletions

File tree

tensorforge/include/tensorforge_aux.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "tensorforge_aux.h"
2+
13
#include <hip/hip_runtime.h>
24
#include <iostream>
35

tensorforge/include/tensorforge_aux.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "tensorforge_aux.h"
2+
13
#include <cuda_runtime.h>
24
#include <iostream>
35

tensorforge/include/tensorforge_aux_sycl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "tensorforge_aux.h"
2+
13
#include <iostream>
24
#include <sycl/sycl.hpp>
35

tensorforge/include/tensorforge_aux_target.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "tensorforge_aux.h"
2+
13
#include <iostream>
24
#include <omp.h>
35

tensorforge/include/tensorforge_device/hip.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,26 +172,34 @@ __device__ __forceinline__ T swap(T value) {
172172
return value;
173173
}
174174

175+
#ifdef __GFX9__
176+
#define CMVCC ", vcc"
177+
#define CMFI ""
178+
#else
179+
#define CMVCC ""
180+
#define CMFI " fi:1 "
181+
#endif
182+
175183
// FMAC_DPP inline assembly
176184
// !! MAY DISREGARD WAIT STATES !!
177185

178186
#define ISTRINGIFY(x) #x
179187
#define STR(x) ISTRINGIFY(x)
180188
#define FMADPP4(pos, c, a, b) \
181189
__asm("v_fmac_f32_dpp %0, %1, %2 quad_perm:[" STR(pos) "," STR(pos) "," STR( \
182-
pos) "," STR(pos) "] row_mask:0xf bank_mask:0xf bound_ctrl:1" \
190+
pos) "," STR(pos) "] row_mask:0xf bank_mask:0xf bound_ctrl:1" CMFI \
183191
: "+v"(c) \
184192
: "v"(a), "v"(b) \
185193
:)
186194
#define FMADPP16(pos, c, a, b) \
187195
__asm("v_fmac_f32_dpp %0, %1, %2 row_newbcast:" STR( \
188-
pos) " row_mask:0xf bank_mask:0xf bound_ctrl:1" \
196+
pos) " row_mask:0xf bank_mask:0xf bound_ctrl:1" CMFI \
189197
: "+v"(c) \
190198
: "v"(a), "v"(b) \
191199
:)
192200
#define DMADPP16(pos, c, a, b) \
193201
__asm("v_fmac_f64_dpp %0, %1, %2 row_newbcast:" STR( \
194-
pos) " row_mask:0xf bank_mask:0xf bound_ctrl:1" \
202+
pos) " row_mask:0xf bank_mask:0xf bound_ctrl:1" CMFI \
195203
: "+v"(c) \
196204
: "v"(a), "v"(b) \
197205
:)
@@ -590,17 +598,13 @@ transpose16x16b32(T &w1, T &w2, T &w3, T &w4, T &w5, T &w6, T &w7, T &w8, T &w9,
590598
// const T u1 = dppUpdate<0x128, 0b1010, 0b1111, true>(v1, v5);
591599
}
592600

593-
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \
594-
defined(__gfx950__)
595-
596-
#define CMVCC ", vcc"
597-
598601
#define CM4STR(p1, p2, p3, p4, c, a, b) \
599-
"v_cndmask_b32_dpp " c ", " a ", " b CMVCC " quad_perm:[" STR(p1) "," STR( \
600-
p2) "," STR(p3) "," STR(p4) "] row_mask:0xf bank_mask:0xf bound_ctrl:1"
602+
"v_cndmask_b32_dpp " c ", " a ", " b CMVCC \
603+
" quad_perm:[" STR(p1) "," STR(p2) "," STR(p3) "," STR( \
604+
p4) "] row_mask:0xf bank_mask:0xf bound_ctrl:1" CMFI
601605
#define CMRSTR(cnt, c, a, b) \
602606
"v_cndmask_b32_dpp " c ", " a ", " b CMVCC \
603-
" row_ror:" STR(cnt) " row_mask:0xf bank_mask:0xf bound_ctrl:1"
607+
" row_ror:" STR(cnt) " row_mask:0xf bank_mask:0xf bound_ctrl:1" CMFI
604608

605609
template <typename T>
606610
__device__ __forceinline__ void transpose4x4b32(T &w1, T &w2, T &w3, T &w4,
@@ -643,8 +647,6 @@ __device__ __forceinline__ void transpose4x4b32(T &w1, T &w2, T &w3, T &w4,
643647
: "vcc");
644648
}
645649

646-
#endif
647-
648650
/*
649651
class Buffer {
650652
public:

0 commit comments

Comments
 (0)