|
| 1 | +module VX_cta_dispatch import VX_gpu_pkg::*; |
| 2 | +( |
| 3 | + input wire clk, |
| 4 | + input wire reset, |
| 5 | + |
| 6 | + // from KMU |
| 7 | + VX_kmu_bus_if.slave kmu_bus_if, |
| 8 | + |
| 9 | + // from scheduler |
| 10 | + input wire[`NUM_WARPS-1:0] active_warps, |
| 11 | + |
| 12 | + // to scheduler |
| 13 | + output wire dispatch_fire, |
| 14 | + output wire[`CLOG2(`NUM_WARPS)-1:0] cta_dispatch_wid, |
| 15 | + output reg[`XLEN-1:0] pc, |
| 16 | + output reg[`XLEN-1:0] param, |
| 17 | + output wire[`NUM_THREADS-1:0] tmask, |
| 18 | + |
| 19 | + // to CSR unit |
| 20 | + output wire cta_csr_valid, |
| 21 | + output wire [NW_WIDTH-1:0] cta_csr_wid, |
| 22 | + output wire cta_csr_data_t cta_csr_data |
| 23 | +); |
| 24 | + // Define states for the FSM |
| 25 | + typedef enum logic [0:0] { IDLE = 1'b0, DISPATCH = 1'b1 } state_t; |
| 26 | + state_t state; |
| 27 | + |
| 28 | + // Only handshake with KMU when the machine is in non-reset IDLE state |
| 29 | + assign kmu_bus_if.req_ready = (state == IDLE && ~reset); |
| 30 | + |
| 31 | + // store some kernel info |
| 32 | + logic [31:0] num_warps; |
| 33 | + reg [`NUM_THREADS-1:0] cur_remain_mask; |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | + // Generate indices for data_in |
| 38 | + logic [`NUM_WARPS-1:0][`CLOG2(`NUM_WARPS)-1:0] warp_indices; |
| 39 | + generate |
| 40 | + for (genvar i = 0; i < `NUM_WARPS; ++i) begin : gen_warp_index |
| 41 | + assign warp_indices[i] = i[`CLOG2(`NUM_WARPS)-1:0]; |
| 42 | + end |
| 43 | + endgenerate |
| 44 | + |
| 45 | + logic cta_dispatch_valid; |
| 46 | + |
| 47 | + // Find the first idle warp (active_warps[i] == 0) |
| 48 | + VX_find_first #( |
| 49 | + .N (`NUM_WARPS), |
| 50 | + .DATAW(`CLOG2 (`NUM_WARPS)), |
| 51 | + .REVERSE(0) |
| 52 | + ) find_first_idle ( |
| 53 | + .data_in (warp_indices), |
| 54 | + .valid_in (~active_warps), // invert to find first '0' |
| 55 | + .data_out (cta_dispatch_wid), |
| 56 | + .valid_out (cta_dispatch_valid) |
| 57 | + ); |
| 58 | + |
| 59 | + int warp_counter; |
| 60 | + // Only dispatch when 1. in DISPATCH state and 2. there are slot in activate_warps and 3. there |
| 61 | + // are some warp to be disptached |
| 62 | + assign dispatch_fire = (state == DISPATCH) && (cta_dispatch_valid) && (warp_counter < num_warps); |
| 63 | + |
| 64 | + // combinational logic to handle tmask, in this way, it is available when needed |
| 65 | + logic[`NUM_THREADS-1:0] tmask_n; |
| 66 | + always_comb begin |
| 67 | + if (warp_counter == num_warps - 1) |
| 68 | + tmask_n = cur_remain_mask; |
| 69 | + else |
| 70 | + tmask_n = {`NUM_THREADS{1'b1}}; |
| 71 | + end |
| 72 | + |
| 73 | + assign tmask = tmask_n; |
| 74 | + |
| 75 | + // FSM and dispatch logic |
| 76 | + always_ff @(posedge clk) begin |
| 77 | + if (reset) begin |
| 78 | + state <= IDLE; |
| 79 | + warp_counter <= 0; |
| 80 | + pc <= '0; |
| 81 | + param <= '0; |
| 82 | + cur_remain_mask <= '0; |
| 83 | + num_warps <= 0; |
| 84 | + cta_csr_valid <= 0; |
| 85 | + end else begin |
| 86 | + case (state) |
| 87 | + IDLE: begin |
| 88 | + warp_counter <= 0; |
| 89 | + if (kmu_bus_if.req_valid && kmu_bus_if.req_ready) begin |
| 90 | + /* When there is a handshake, store the kernel info from kmu |
| 91 | + then make the FSM do a transition ot DISPATCH state |
| 92 | + */ |
| 93 | + pc <= kmu_bus_if.req_data.start_pc; |
| 94 | + param <= kmu_bus_if.req_data.param; |
| 95 | + num_warps <= kmu_bus_if.req_data.num_warps; |
| 96 | + cur_remain_mask <= kmu_bus_if.req_data.remain_mask; |
| 97 | + cta_csr_data.cta_x <= kmu_bus_if.req_data.cta_x; |
| 98 | + cta_csr_data.cta_y <= kmu_bus_if.req_data.cta_y; |
| 99 | + cta_csr_data.cta_z <= kmu_bus_if.req_data.cta_z; |
| 100 | + cta_csr_data.cta_id <= kmu_bus_if.req_data.cta_id; |
| 101 | + state <= DISPATCH; |
| 102 | + end |
| 103 | + end |
| 104 | + |
| 105 | + DISPATCH: begin |
| 106 | + if (dispatch_fire) begin |
| 107 | + /* update counter, write the cta csr message. |
| 108 | + The warp is dispatched in the VX_schedule module |
| 109 | + at the same time |
| 110 | + */ |
| 111 | + warp_counter <= warp_counter + 1; |
| 112 | + cta_csr_data.local_warp_id <= warp_counter; |
| 113 | + cta_csr_wid <= cta_dispatch_wid; |
| 114 | + cta_csr_valid <= 1; |
| 115 | + end else begin |
| 116 | + /* If not dispatch_fire, no warp is dispatched in the coming cycle |
| 117 | + So set the cta csr valid to 0 for the coming cycle |
| 118 | + */ |
| 119 | + cta_csr_valid <= 0; |
| 120 | + end |
| 121 | + |
| 122 | + // Exit DISPATCH after all warps are dispatched |
| 123 | + if (warp_counter >= num_warps) begin |
| 124 | + state <= IDLE; |
| 125 | + end |
| 126 | + end |
| 127 | + |
| 128 | + default: state <= IDLE; |
| 129 | + endcase |
| 130 | + end |
| 131 | + end |
| 132 | + |
| 133 | +endmodule |
0 commit comments