Skip to content

Commit 8edf8fb

Browse files
committed
Merge branch 'kmu_rebase' of https://github.com/randyliu4345/vortex into bug_fixes
1 parent 4a4bc35 commit 8edf8fb

48 files changed

Lines changed: 1441 additions & 27 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

hw/VX_config.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ VM_ENABLED = "expr: 1 if $VM_ENABLE else 0"
3232
EXT_D_ENABLE = "expr: $XLEN_64"
3333
FLEN = "expr: 64 if $EXT_D_ENABLE else 32"
3434

35+
KMU_ENABLE = false
36+
KMU_ENABLED = "expr: 1 if $KMU_ENABLE else 0"
37+
3538
# extensions
3639
EXT_M_ENABLE = true
3740
EXT_F_ENABLE = true

hw/VX_types.toml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ VX_DCR_BASE_STARTUP_ARG1 = 0x004
1515
VX_DCR_BASE_MPM_CLASS = 0x005
1616
VX_DCR_BASE_STATE_END = 0x006
1717

18+
[dxr_kmu]
19+
VX_DCR_BASE_GRID_DIM0 = 0x006
20+
VX_DCR_BASE_GRID_DIM1 = 0x007
21+
VX_DCR_BASE_GRID_DIM2 = 0x008
22+
VX_DCR_BASE_BLOCK_DIM0 = 0x009
23+
VX_DCR_BASE_BLOCK_DIM1 = 0x00A
24+
VX_DCR_BASE_BLOCK_DIM2 = 0x00B
25+
VX_DCR_BASE_LMEM_SIZE = 0x00C
26+
VX_DCR_BASE_START_EXE = 0x00D
27+
VX_DCR_BASE_STATE_END = 0x00E
28+
1829
# DXA descriptor DCR mapping
1930
# access pattern:
2031
# VX_DCR_DXA_DESC_BASE + slot * VX_DCR_DXA_DESC_STRIDE + field_off
@@ -114,6 +125,20 @@ VX_CSR_NUM_CORES = 0xFC2
114125
VX_CSR_LOCAL_MEM_BASE = 0xFC3
115126
VX_CSR_NUM_BARRIERS = 0xFC4
116127

128+
[csr_kmu]
129+
VX_CSR_CTA_X = 0xCC6
130+
VX_CSR_CTA_Y = 0xCC7
131+
VX_CSR_CTA_Z = 0xCC8
132+
VX_CSR_CTA_ID = 0xCC9
133+
134+
# cta block dimension CSRs
135+
VX_CSR_BLOCK_DIM_X = 0xCCA
136+
VX_CSR_BLOCK_DIM_Y = 0xCCB
137+
VX_CSR_BLOCK_DIM_Z = 0xCCC
138+
139+
# warp local ID within CTA
140+
VX_CSR_CTA_WARP_ID = 0xCCD
141+
117142
[dcr_mpm_class]
118143
VX_DCR_MPM_CLASS_BASE = 0
119144
VX_DCR_MPM_CLASS_CORE = 1

hw/rtl/VX_cluster.sv

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ module VX_cluster import VX_gpu_pkg::*; #(
3333
// Memory
3434
VX_mem_bus_if.master mem_bus_if [`L2_MEM_PORTS],
3535

36+
// KMU bus
37+
VX_kmu_bus_if.slave kmu_bus_if[1],
38+
3639
// Status
3740
output wire busy
3841
);
@@ -51,6 +54,18 @@ module VX_cluster import VX_gpu_pkg::*; #(
5154
end
5255
`endif
5356

57+
VX_kmu_bus_if per_socket_kmu_bus_if[NUM_SOCKETS]();
58+
59+
VX_kmu_arb #(
60+
.NUM_INPUTS (1),
61+
.NUM_OUTPUTS (NUM_SOCKETS)
62+
) kmu_arb (
63+
.clk (clk),
64+
.reset (reset),
65+
.bus_in_if (kmu_bus_if),
66+
.bus_out_if (per_socket_kmu_bus_if)
67+
);
68+
5469
VX_gbar_bus_if per_socket_gbar_bus_if[NUM_SOCKETS]();
5570
VX_gbar_bus_if gbar_bus_if();
5671

@@ -251,6 +266,8 @@ module VX_cluster import VX_gpu_pkg::*; #(
251266
.per_core_bank_wr_if(per_core_bank_wr_if[socket_id * `SOCKET_SIZE +: `SOCKET_SIZE]),
252267
`endif
253268

269+
.kmu_bus_if (per_socket_kmu_bus_if[socket_id +: 1]),
270+
254271
.gbar_bus_if (per_socket_gbar_bus_if[socket_id]),
255272

256273
.busy (per_socket_busy[socket_id])

hw/rtl/VX_gpu_pkg.sv

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,21 @@ package VX_gpu_pkg;
533533
logic [7:0] mpm_class;
534534
} base_dcrs_t;
535535

536+
typedef struct packed {
537+
logic [`XLEN-1:0] pc;
538+
logic [2:0][31:0] grid_dim;
539+
logic [2:0][31:0] block_dim;
540+
logic [`XLEN-1:0] param;
541+
} kmu_data_t;
542+
543+
typedef struct packed {
544+
logic [31:0] cta_x;
545+
logic [31:0] cta_y;
546+
logic [31:0] cta_z;
547+
logic [31:0] cta_id;
548+
logic [31:0] local_warp_id;
549+
} cta_csr_data_t;
550+
536551
//////////////////////// instruction arguments ////////////////////////////
537552

538553
localparam INST_ARGS_BITS = 3 + ALU_TYPE_BITS + 20;

hw/rtl/VX_socket.sv

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ module VX_socket import VX_gpu_pkg::*; #(
4040
VX_dxa_bank_wr_if.slave per_core_bank_wr_if [`SOCKET_SIZE],
4141
`endif
4242

43-
// Barrier
43+
// KMU bus
44+
VX_kmu_bus_if.slave kmu_bus_if[1],
45+
46+
// Global barrier
4447
VX_gbar_bus_if.master gbar_bus_if,
4548

4649
// Status
@@ -52,6 +55,18 @@ module VX_socket import VX_gpu_pkg::*; #(
5255
`SCOPE_IO_SWITCH (`SOCKET_SIZE);
5356
`endif
5457

58+
VX_kmu_bus_if per_core_kmu_bus_if[`SOCKET_SIZE]();
59+
60+
VX_kmu_arb #(
61+
.NUM_INPUTS (1),
62+
.NUM_OUTPUTS (`SOCKET_SIZE)
63+
) kmu_arb (
64+
.clk (clk),
65+
.reset (reset),
66+
.bus_in_if (kmu_bus_if),
67+
.bus_out_if (per_core_kmu_bus_if[`SOCKET_SIZE-1:0])
68+
);
69+
5570
VX_gbar_bus_if per_core_gbar_bus_if[`SOCKET_SIZE]();
5671

5772
VX_gbar_arb #(
@@ -303,6 +318,8 @@ module VX_socket import VX_gpu_pkg::*; #(
303318
.dxa_bank_wr_if (per_core_bank_wr_if[core_id]),
304319
`endif
305320

321+
.kmu_bus_if (per_core_kmu_bus_if[core_id]),
322+
306323
.gbar_bus_if (per_core_gbar_bus_if[core_id]),
307324

308325
.busy (per_core_busy[core_id])

hw/rtl/Vortex.sv

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ module Vortex import VX_gpu_pkg::*, VX_trace_pkg::*; (
4848
`STATIC_ASSERT(`IS_POW2(`NUM_CORES), ("NUM_CORES must be a power of 2"));
4949
`STATIC_ASSERT(`IS_POW2(`SOCKET_SIZE), ("SOCKET_SIZE must be a power of 2"));
5050

51+
VX_kmu_bus_if kmu_bus_in[1]();
52+
VX_kmu_bus_if per_cluster_kmu_bus_if[`NUM_CLUSTERS]();
53+
54+
VX_kmu kmu(
55+
.clk (clk),
56+
.reset (reset),
57+
.dcr_wr_valid (dcr_wr_valid),
58+
.dcr_wr_addr (dcr_wr_addr),
59+
.dcr_wr_data (dcr_wr_data),
60+
.kmu_bus_if (kmu_bus_in[0])
61+
);
62+
5163
`ifdef SCOPE
5264
localparam scope_cluster = 0;
5365
`SCOPE_IO_SWITCH (`NUM_CLUSTERS);
@@ -135,6 +147,16 @@ module Vortex import VX_gpu_pkg::*, VX_trace_pkg::*; (
135147

136148
wire [`NUM_CLUSTERS-1:0] per_cluster_busy;
137149

150+
VX_kmu_arb #(
151+
.NUM_INPUTS (1),
152+
.NUM_OUTPUTS (`NUM_CLUSTERS)
153+
) kmu_arb (
154+
.clk (clk),
155+
.reset (reset),
156+
.bus_in_if (kmu_bus_in),
157+
.bus_out_if (per_cluster_kmu_bus_if[`NUM_CLUSTERS-1:0])
158+
);
159+
138160
// Generate all clusters
139161
for (genvar cluster_id = 0; cluster_id < `NUM_CLUSTERS; ++cluster_id) begin : g_clusters
140162

@@ -160,6 +182,8 @@ module Vortex import VX_gpu_pkg::*, VX_trace_pkg::*; (
160182

161183
.mem_bus_if (per_cluster_mem_bus_if[cluster_id * `L2_MEM_PORTS +: `L2_MEM_PORTS]),
162184

185+
.kmu_bus_if (per_cluster_kmu_bus_if[cluster_id +: 1]),
186+
163187
.busy (per_cluster_busy[cluster_id])
164188
);
165189
end

hw/rtl/core/VX_core.sv

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ module VX_core import VX_gpu_pkg::*; #(
4242
VX_dxa_bank_wr_if.slave dxa_bank_wr_if,
4343
`endif
4444

45+
// KMU bus
46+
VX_kmu_bus_if.slave kmu_bus_if,
47+
48+
// Global barrier
4549
VX_gbar_bus_if.master gbar_bus_if,
4650

4751
// Status
@@ -114,6 +118,8 @@ module VX_core import VX_gpu_pkg::*; #(
114118
.issue_sched_if (issue_sched_if),
115119
.commit_sched_if(commit_sched_if),
116120

121+
.kmu_bus_if (kmu_bus_if),
122+
117123
.schedule_if (schedule_if),
118124
.sched_csr_if (sched_csr_if),
119125
.gbar_bus_if (gbar_bus_if),

hw/rtl/core/VX_csr_data.sv

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ import VX_fpu_pkg::*;
5454
input wire [`NUM_WARPS-1:0] active_warps,
5555
input wire [`NUM_WARPS-1:0][`NUM_THREADS-1:0] thread_masks,
5656

57+
input wire cta_csr_valid,
58+
input wire [NW_WIDTH-1:0] cta_csr_wid,
59+
input cta_csr_data_t cta_csr_data,
60+
5761
input wire read_enable,
5862
input wire [UUID_WIDTH-1:0] read_uuid,
5963
input wire [NW_WIDTH-1:0] read_wid,
@@ -78,6 +82,7 @@ import VX_fpu_pkg::*;
7882
// CSRs Write /////////////////////////////////////////////////////////////
7983

8084
reg [`XLEN-1:0] mscratch;
85+
cta_csr_data_t [`NUM_WARPS-1:0] cta_csrs;
8186

8287
`ifdef EXT_F_ENABLE
8388
reg [`NUM_WARPS-1:0][INST_FRM_BITS+`FP_FLAGS_BITS-1:0] fcsr, fcsr_n;
@@ -153,6 +158,9 @@ import VX_fpu_pkg::*;
153158
end
154159
endcase
155160
end
161+
if (cta_csr_valid) begin
162+
cta_csrs[cta_csr_wid] <= cta_csr_data;
163+
end
156164
end
157165

158166
// CSRs read //////////////////////////////////////////////////////////////
@@ -177,6 +185,11 @@ import VX_fpu_pkg::*;
177185
`endif
178186
`VX_CSR_MSCRATCH : read_data_rw_w = mscratch;
179187

188+
`VX_CSR_CTA_X : read_data_rw_w = cta_csrs[read_wid].cta_x;
189+
`VX_CSR_CTA_Y : read_data_rw_w = cta_csrs[read_wid].cta_y;
190+
`VX_CSR_CTA_Z : read_data_rw_w = cta_csrs[read_wid].cta_z;
191+
`VX_CSR_CTA_ID : read_data_rw_w = cta_csrs[read_wid].cta_id;
192+
180193
`VX_CSR_WARP_ID : read_data_ro_w = `XLEN'(read_wid);
181194
`VX_CSR_CORE_ID : read_data_ro_w = `XLEN'(CORE_ID);
182195
`VX_CSR_ACTIVE_THREADS: read_data_ro_w = `XLEN'(thread_masks[read_wid]);

hw/rtl/core/VX_csr_unit.sv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ module VX_csr_unit import VX_gpu_pkg::*; #(
8282
.active_warps (sched_csr_if.active_warps),
8383
.thread_masks (sched_csr_if.thread_masks),
8484

85+
.cta_csr_valid (sched_csr_if.cta_csr_valid),
86+
.cta_csr_wid (sched_csr_if.cta_csr_wid),
87+
.cta_csr_data (sched_csr_if.cta_csr_data),
88+
8589
`ifdef EXT_F_ENABLE
8690
.fpu_csr_if (fpu_csr_if),
8791
`endif

hw/rtl/core/VX_cta_dispatch.sv

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
module VX_cta_dispatch import VX_gpu_pkg::*;
2+
(
3+
input wire clk,
4+
input wire reset,
5+
6+
// from KMU
7+
VX_kmu_bus_if.slave kmu_bus_if,
8+
9+
// from scheduler
10+
input wire[`NUM_WARPS-1:0] active_warps,
11+
12+
// to scheduler
13+
output wire dispatch_fire,
14+
output wire[`CLOG2(`NUM_WARPS)-1:0] cta_dispatch_wid,
15+
output reg[`XLEN-1:0] pc,
16+
output reg[`XLEN-1:0] param,
17+
output wire[`NUM_THREADS-1:0] tmask,
18+
19+
// to CSR unit
20+
output wire cta_csr_valid,
21+
output wire [NW_WIDTH-1:0] cta_csr_wid,
22+
output wire cta_csr_data_t cta_csr_data
23+
);
24+
// Define states for the FSM
25+
typedef enum logic [0:0] { IDLE = 1'b0, DISPATCH = 1'b1 } state_t;
26+
state_t state;
27+
28+
// Only handshake with KMU when the machine is in non-reset IDLE state
29+
assign kmu_bus_if.req_ready = (state == IDLE && ~reset);
30+
31+
// store some kernel info
32+
logic [31:0] num_warps;
33+
reg [`NUM_THREADS-1:0] cur_remain_mask;
34+
35+
36+
37+
// Generate indices for data_in
38+
logic [`NUM_WARPS-1:0][`CLOG2(`NUM_WARPS)-1:0] warp_indices;
39+
generate
40+
for (genvar i = 0; i < `NUM_WARPS; ++i) begin : gen_warp_index
41+
assign warp_indices[i] = i[`CLOG2(`NUM_WARPS)-1:0];
42+
end
43+
endgenerate
44+
45+
logic cta_dispatch_valid;
46+
47+
// Find the first idle warp (active_warps[i] == 0)
48+
VX_find_first #(
49+
.N (`NUM_WARPS),
50+
.DATAW(`CLOG2 (`NUM_WARPS)),
51+
.REVERSE(0)
52+
) find_first_idle (
53+
.data_in (warp_indices),
54+
.valid_in (~active_warps), // invert to find first '0'
55+
.data_out (cta_dispatch_wid),
56+
.valid_out (cta_dispatch_valid)
57+
);
58+
59+
int warp_counter;
60+
// Only dispatch when 1. in DISPATCH state and 2. there are slot in activate_warps and 3. there
61+
// are some warp to be disptached
62+
assign dispatch_fire = (state == DISPATCH) && (cta_dispatch_valid) && (warp_counter < num_warps);
63+
64+
// combinational logic to handle tmask, in this way, it is available when needed
65+
logic[`NUM_THREADS-1:0] tmask_n;
66+
always_comb begin
67+
if (warp_counter == num_warps - 1)
68+
tmask_n = cur_remain_mask;
69+
else
70+
tmask_n = {`NUM_THREADS{1'b1}};
71+
end
72+
73+
assign tmask = tmask_n;
74+
75+
// FSM and dispatch logic
76+
always_ff @(posedge clk) begin
77+
if (reset) begin
78+
state <= IDLE;
79+
warp_counter <= 0;
80+
pc <= '0;
81+
param <= '0;
82+
cur_remain_mask <= '0;
83+
num_warps <= 0;
84+
cta_csr_valid <= 0;
85+
end else begin
86+
case (state)
87+
IDLE: begin
88+
warp_counter <= 0;
89+
if (kmu_bus_if.req_valid && kmu_bus_if.req_ready) begin
90+
/* When there is a handshake, store the kernel info from kmu
91+
then make the FSM do a transition ot DISPATCH state
92+
*/
93+
pc <= kmu_bus_if.req_data.start_pc;
94+
param <= kmu_bus_if.req_data.param;
95+
num_warps <= kmu_bus_if.req_data.num_warps;
96+
cur_remain_mask <= kmu_bus_if.req_data.remain_mask;
97+
cta_csr_data.cta_x <= kmu_bus_if.req_data.cta_x;
98+
cta_csr_data.cta_y <= kmu_bus_if.req_data.cta_y;
99+
cta_csr_data.cta_z <= kmu_bus_if.req_data.cta_z;
100+
cta_csr_data.cta_id <= kmu_bus_if.req_data.cta_id;
101+
state <= DISPATCH;
102+
end
103+
end
104+
105+
DISPATCH: begin
106+
if (dispatch_fire) begin
107+
/* update counter, write the cta csr message.
108+
The warp is dispatched in the VX_schedule module
109+
at the same time
110+
*/
111+
warp_counter <= warp_counter + 1;
112+
cta_csr_data.local_warp_id <= warp_counter;
113+
cta_csr_wid <= cta_dispatch_wid;
114+
cta_csr_valid <= 1;
115+
end else begin
116+
/* If not dispatch_fire, no warp is dispatched in the coming cycle
117+
So set the cta csr valid to 0 for the coming cycle
118+
*/
119+
cta_csr_valid <= 0;
120+
end
121+
122+
// Exit DISPATCH after all warps are dispatched
123+
if (warp_counter >= num_warps) begin
124+
state <= IDLE;
125+
end
126+
end
127+
128+
default: state <= IDLE;
129+
endcase
130+
end
131+
end
132+
133+
endmodule

0 commit comments

Comments
 (0)