Skip to content

Commit ac6ce45

Browse files
author
Randy L
committed
add blockidx, blockdim, threadidx intrinsics and change vecadd kernel.
1 parent 8edf8fb commit ac6ce45

5 files changed

Lines changed: 184 additions & 7 deletions

File tree

hw/rtl/VX_gpu_pkg.sv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ package VX_gpu_pkg;
531531
logic [`XLEN-1:0] startup_addr;
532532
logic [`XLEN-1:0] startup_arg;
533533
logic [7:0] mpm_class;
534+
logic [2:0][31:0] block_dim;
534535
} base_dcrs_t;
535536

536537
typedef struct packed {

hw/rtl/core/VX_csr_data.sv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ import VX_fpu_pkg::*;
189189
`VX_CSR_CTA_Y : read_data_rw_w = cta_csrs[read_wid].cta_y;
190190
`VX_CSR_CTA_Z : read_data_rw_w = cta_csrs[read_wid].cta_z;
191191
`VX_CSR_CTA_ID : read_data_rw_w = cta_csrs[read_wid].cta_id;
192+
`VX_CSR_CTA_WARP_ID: read_data_ro_w = cta_csrs[read_wid].local_warp_id;
193+
194+
`VX_CSR_BLOCK_DIM_X: read_data_ro_w = `XLEN'(base_dcrs.block_dim[0]);
195+
`VX_CSR_BLOCK_DIM_Y: read_data_ro_w = `XLEN'(base_dcrs.block_dim[1]);
196+
`VX_CSR_BLOCK_DIM_Z: read_data_ro_w = `XLEN'(base_dcrs.block_dim[2]);
192197

193198
`VX_CSR_WARP_ID : read_data_ro_w = `XLEN'(read_wid);
194199
`VX_CSR_CORE_ID : read_data_ro_w = `XLEN'(CORE_ID);

hw/rtl/core/VX_dcr_data.sv

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ module VX_dcr_data import VX_gpu_pkg::*; (
4040
`VX_DCR_BASE_STARTUP_ARG1 : dcrs.startup_arg[63:32] <= dcr_bus_if.write_data;
4141
`endif
4242
`VX_DCR_BASE_MPM_CLASS : dcrs.mpm_class <= dcr_bus_if.write_data[7:0];
43+
`VX_DCR_BASE_BLOCK_DIM0 : dcrs.block_dim[0] <= dcr_bus_if.write_data;
44+
`VX_DCR_BASE_BLOCK_DIM1 : dcrs.block_dim[1] <= dcr_bus_if.write_data;
45+
`VX_DCR_BASE_BLOCK_DIM2 : dcrs.block_dim[2] <= dcr_bus_if.write_data;
4346
default:;
4447
endcase
4548
end

kernel/include/vx_intrinsics.h

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,132 @@ inline void vx_barrier_wait(int barrier_id, int phase) {
469469

470470
#ifdef __cplusplus
471471
}
472+
473+
// CTA Index and Dimension Proxy Structures (blockIdx, blockDim, threadIdx)
474+
// These allow the .x, .y and .z members of blockIdx, blockDim and threadIdx to be used directly
475+
// without function call syntax, reading from RISC-V CSRs automatically
476+
477+
#ifndef VX_CSR_CTA_X
478+
#define VX_CSR_CTA_X 0xCC6
479+
#endif
480+
481+
#ifndef VX_CSR_CTA_Y
482+
#define VX_CSR_CTA_Y 0xCC7
483+
#endif
484+
485+
#ifndef VX_CSR_CTA_Z
486+
#define VX_CSR_CTA_Z 0xCC8
487+
#endif
488+
489+
#ifndef VX_CSR_BLOCK_DIM_X
490+
#define VX_CSR_BLOCK_DIM_X 0xCCA
472491
#endif
473492

493+
#ifndef VX_CSR_BLOCK_DIM_Y
494+
#define VX_CSR_BLOCK_DIM_Y 0xCCB
495+
#endif
496+
497+
#ifndef VX_CSR_BLOCK_DIM_Z
498+
#define VX_CSR_BLOCK_DIM_Z 0xCCC
499+
#endif
500+
501+
#ifndef VX_CSR_CTA_WARP_ID
502+
#define VX_CSR_CTA_WARP_ID 0xCCD
503+
#endif
504+
505+
// Proxy structure for blockIdx with x, y, z members
506+
struct BlockIdx {
507+
struct X {
508+
// Implicit conversion to unsigned int triggers the CSR read
509+
inline operator unsigned int() const {
510+
unsigned int val;
511+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_CTA_X));
512+
return val;
513+
}
514+
} x;
515+
516+
struct Y {
517+
inline operator unsigned int() const {
518+
unsigned int val;
519+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_CTA_Y));
520+
return val;
521+
}
522+
} y;
523+
524+
struct Z {
525+
inline operator unsigned int() const {
526+
unsigned int val;
527+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_CTA_Z));
528+
return val;
529+
}
530+
} z;
531+
};
532+
533+
// Create a global instance of blockIdx
534+
// Marking it static ensures no linker errors if included in multiple files.
535+
// The struct holds no actual data, so the compiler will optimize it away.
536+
static const BlockIdx blockIdx;
537+
538+
// Proxy structure for blockDim with x, y, z members
539+
struct BlockDim {
540+
struct X {
541+
// Implicit conversion to unsigned int triggers the CSR read
542+
inline operator unsigned int() const {
543+
unsigned int val;
544+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_BLOCK_DIM_X));
545+
return val;
546+
}
547+
} x;
548+
549+
struct Y {
550+
inline operator unsigned int() const {
551+
unsigned int val;
552+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_BLOCK_DIM_Y));
553+
return val;
554+
}
555+
} y;
556+
557+
struct Z {
558+
inline operator unsigned int() const {
559+
unsigned int val;
560+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(val) : "i"(VX_CSR_BLOCK_DIM_Z));
561+
return val;
562+
}
563+
} z;
564+
};
565+
566+
// Create a global instance of blockDim
567+
// Marking it static ensures no linker errors if included in multiple files.
568+
// The struct holds no actual data, so the compiler will optimize it away.
569+
static const BlockDim blockDim;
570+
571+
// Proxy structure for threadIdx with x, y, z members
572+
// threadIdx.x gives the flat thread index within the CTA:
573+
// warp_local_id * NUM_THREADS + thread_id_within_warp
574+
struct ThreadIdx {
575+
struct X {
576+
inline operator unsigned int() const {
577+
unsigned int warp_local_id;
578+
__asm__ __volatile__ ("csrr %0, %1" : "=r"(warp_local_id) : "i"(VX_CSR_CTA_WARP_ID));
579+
return warp_local_id * vx_num_threads() + vx_thread_id();
580+
}
581+
} x;
582+
583+
struct Y {
584+
inline operator unsigned int() const {
585+
return 0;
586+
}
587+
} y;
588+
589+
struct Z {
590+
inline operator unsigned int() const {
591+
return 0;
592+
}
593+
} z;
594+
};
595+
596+
static const ThreadIdx threadIdx;
597+
598+
#endif // __cplusplus
599+
474600
#endif // __VX_INTRINSICS_H__

tests/regression/vecadd/kernel.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,57 @@
11
#include <vx_spawn.h>
22
#include "common.h"
33

4-
void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
5-
auto src0_ptr = reinterpret_cast<TYPE*>(arg->src0_addr);
6-
auto src1_ptr = reinterpret_cast<TYPE*>(arg->src1_addr);
7-
auto dst_ptr = reinterpret_cast<TYPE*>(arg->dst_addr);
4+
// void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
5+
// auto src0_ptr = reinterpret_cast<TYPE*>(arg->src0_addr);
6+
// auto src1_ptr = reinterpret_cast<TYPE*>(arg->src1_addr);
7+
// auto dst_ptr = reinterpret_cast<TYPE*>(arg->dst_addr);
88

9-
dst_ptr[blockIdx.x] = src0_ptr[blockIdx.x] + src1_ptr[blockIdx.x];
10-
}
9+
// dst_ptr[blockIdx.x] = src0_ptr[blockIdx.x] + src1_ptr[blockIdx.x];
10+
// }
1111

1212
int main() {
1313
kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
14-
return vx_spawn_threads(1, &arg->num_points, nullptr, (vx_kernel_func_cb)kernel_body, arg);
14+
// return vx_spawn_threads(1, &arg->num_points, nullptr, (vx_kernel_func_cb)kernel_body, arg);
15+
//int warpId = static_cast<int>(csr_read(VX_CSR_CTA_ID));
16+
//int warpSize = vx_num_threads();
17+
//int threadId = vx_thread_id();
18+
19+
int core_id = static_cast<int>(csr_read(VX_CSR_CORE_ID));
20+
21+
// int cond = (warpId == 0) && (threadId == 0);
22+
// int sp = vx_split(cond);
23+
24+
// if (cond) {
25+
// asm volatile(
26+
// "la a0, _edata\n\t"
27+
// "la a2, _end\n\t"
28+
// "sub a2, a2, a0\n\t"
29+
// "li a1, 0\n\t"
30+
// "call memset\n\t"
31+
// :
32+
// :
33+
// : "memory", "ra",
34+
// "a0","a1","a2","a3","a4","a5","a6","a7",
35+
// "t0","t1","t2","t3","t4","t5","t6");
36+
// }
37+
38+
// vx_join(sp);
39+
40+
41+
42+
// Calculate global thread ID
43+
// threadIdx.x gives the flat thread index within the CTA (warp_local_id * NUM_THREADS + thread_id)
44+
// globalId = blockIdx.x * blockDim.x + threadIdx.x
45+
uint32_t globalId = blockIdx.x * blockDim.x + threadIdx.x;
46+
47+
vx_printf("block id x: %d, threadIdx.x: %d, global id: %d\n",
48+
blockIdx.x, threadIdx.x, globalId);
49+
50+
auto src0_ptr = reinterpret_cast<TYPE*>(arg->src0_addr);
51+
auto src1_ptr = reinterpret_cast<TYPE*>(arg->src1_addr);
52+
auto dst_ptr = reinterpret_cast<TYPE*>(arg->dst_addr);
53+
54+
dst_ptr[globalId] = src0_ptr[globalId] + src1_ptr[globalId];
55+
56+
return 0;
1557
}

0 commit comments

Comments
 (0)