diff --git a/hw/rtl/VX_decode.v b/hw/rtl/VX_decode.v index 54257679..e1d0d221 100644 --- a/hw/rtl/VX_decode.v +++ b/hw/rtl/VX_decode.v @@ -221,6 +221,7 @@ module VX_decode #( use_rd = 1; use_imm = 1; use_PC = 1; + is_wstall = 1; imm = 32'd4; `USED_IREG (rd); end @@ -406,8 +407,9 @@ module VX_decode #( assign join_if.valid = ifetch_rsp_fire && is_join; assign join_if.wid = ifetch_rsp_if.wid; - assign wstall_if.valid = ifetch_rsp_fire && is_wstall; + assign wstall_if.valid = ifetch_rsp_fire; assign wstall_if.wid = ifetch_rsp_if.wid; + assign wstall_if.stalled = is_wstall; assign ifetch_rsp_if.ready = decode_if.ready; diff --git a/hw/rtl/VX_fetch.v b/hw/rtl/VX_fetch.v index 30f786e3..765d9b72 100644 --- a/hw/rtl/VX_fetch.v +++ b/hw/rtl/VX_fetch.v @@ -44,7 +44,6 @@ module VX_fetch #( .branch_ctl_if (branch_ctl_if), .ifetch_req_if (ifetch_req_if), - .ifetch_rsp_if (ifetch_rsp_if), .fetch_to_csr_if (fetch_to_csr_if), diff --git a/hw/rtl/VX_print_instr.vh b/hw/rtl/VX_print_instr.vh index 4559c199..691c4ac8 100644 --- a/hw/rtl/VX_print_instr.vh +++ b/hw/rtl/VX_print_instr.vh @@ -38,7 +38,7 @@ task print_ex_op ( `BR_MRET: $write("MRET"); `BR_SRET: $write("SRET"); `BR_DRET: $write("DRET"); - default: $write("?"); + default: $write("?"); endcase end else if (`ALU_IS_MUL(op_mod)) begin case (`MUL_BITS'(op_type)) diff --git a/hw/rtl/VX_warp_sched.v b/hw/rtl/VX_warp_sched.v index 2e708a70..125f1685 100644 --- a/hw/rtl/VX_warp_sched.v +++ b/hw/rtl/VX_warp_sched.v @@ -13,7 +13,6 @@ module VX_warp_sched #( VX_join_if join_if, VX_branch_ctl_if branch_ctl_if, - VX_ifetch_rsp_if ifetch_rsp_if, VX_ifetch_req_if ifetch_req_if, VX_fetch_to_csr_if fetch_to_csr_if, @@ -30,26 +29,25 @@ module VX_warp_sched #( reg [`NUM_WARPS-1:0] active_warps, active_warps_n; // real active warps (updated when a warp is activated or disabled) reg [`NUM_WARPS-1:0] stalled_warps; // asserted when a branch/gpgpu instructions are issued - // Lock warp until instruction decode to resolve branches - reg [`NUM_WARPS-1:0] fetch_lock; reg [`NUM_WARPS-1:0][`NUM_THREADS-1:0] thread_masks; reg [`NUM_WARPS-1:0][31:0] warp_pcs, warp_next_pcs; // barriers - reg [`NUM_BARRIERS-1:0][`NUM_WARPS-1:0] barrier_stall_mask; // warps waiting on barrier + reg [`NUM_BARRIERS-1:0][`NUM_WARPS-1:0] barrier_masks; // warps waiting on barrier wire reached_barrier_limit; // the expected number of warps reached the barrier // wspawn - reg [31:0] use_wspawn_pc; + reg [31:0] wspawn_pc; reg [`NUM_WARPS-1:0] use_wspawn; - wire [`NW_BITS-1:0] schedule_warp; + wire [`NW_BITS-1:0] schedule_wid; + wire [`NUM_THREADS-1:0] schedule_tmask; + wire [31:0] schedule_pc; + wire schedule_valid; wire warp_scheduled; wire ifetch_req_fire = ifetch_req_if.valid && ifetch_req_if.ready; - wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready; - wire tmc_active = (warp_ctl_if.tmc.tmask != 0); always @(*) begin @@ -64,55 +62,44 @@ module VX_warp_sched #( always @(posedge clk) begin if (reset) begin - for (integer i = 0; i < `NUM_BARRIERS; i++) begin - barrier_stall_mask[i] <= 0; - end + barrier_masks <= 0; + use_wspawn <= 0; + stalled_warps <= 0; + warp_pcs <= '0; + active_warps <= '0; + thread_masks <= '0; - use_wspawn_pc <= 0; - use_wspawn <= 0; - warp_pcs[0] <= `STARTUP_ADDR; - active_warps[0] <= 1; // Activating first warp - thread_masks[0] <= 1; // Activating first thread in first warp - stalled_warps <= 0; - fetch_lock <= 0; - - for (integer i = 1; i < `NUM_WARPS; i++) begin - warp_pcs[i] <= 0; - active_warps[i] <= 0; - thread_masks[i] <= 0; - end + // activate first warp + warp_pcs[0] <= `STARTUP_ADDR; + active_warps[0] <= '1; + thread_masks[0] <= '1; end else begin if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin - use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); - use_wspawn_pc <= warp_ctl_if.wspawn.pc; + use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1)); + wspawn_pc <= warp_ctl_if.wspawn.pc; end if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin stalled_warps[warp_ctl_if.wid] <= 0; if (reached_barrier_limit) begin - barrier_stall_mask[warp_ctl_if.barrier.id] <= 0; + barrier_masks[warp_ctl_if.barrier.id] <= 0; end else begin - barrier_stall_mask[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; + barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; end - end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin + end + + if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; stalled_warps[warp_ctl_if.wid] <= 0; - end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin + end + + if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin stalled_warps[warp_ctl_if.wid] <= 0; if (warp_ctl_if.split.diverged) begin thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask; end - end - - if (use_wspawn[schedule_warp] && warp_scheduled) begin - use_wspawn[schedule_warp] <= 0; - thread_masks[schedule_warp] <= 1; end - // Stalling the scheduling of warps - if (wstall_if.valid) begin - stalled_warps[wstall_if.wid] <= 1; - end // Branch if (branch_ctl_if.valid) begin @@ -122,18 +109,24 @@ module VX_warp_sched #( stalled_warps[branch_ctl_if.wid] <= 0; end - // Lock warp until instruction decode to resolve branches if (warp_scheduled) begin - fetch_lock[schedule_warp] <= 1; + // stall the warp until decode stage + stalled_warps[schedule_wid] <= 1; + + // release wspawn + use_wspawn[schedule_wid] <= 0; + if (use_wspawn[schedule_wid]) begin + thread_masks[schedule_wid] <= 1; + end end if (ifetch_req_fire) begin warp_next_pcs[ifetch_req_if.wid] <= ifetch_req_if.PC + 4; end - - if (ifetch_rsp_fire) begin - fetch_lock[ifetch_rsp_if.wid] <= 0; - warp_pcs[ifetch_rsp_if.wid] <= warp_next_pcs[ifetch_rsp_if.wid]; + + if (wstall_if.valid) begin + stalled_warps[wstall_if.wid] <= wstall_if.stalled; + warp_pcs[wstall_if.wid] <= warp_next_pcs[wstall_if.wid]; end // join handling @@ -156,15 +149,15 @@ module VX_warp_sched #( `IGNORE_UNUSED_BEGIN wire [`NW_BITS:0] active_barrier_count; `IGNORE_UNUSED_END - assign active_barrier_count = $countones(barrier_stall_mask[warp_ctl_if.barrier.id]); + assign active_barrier_count = $countones(barrier_masks[warp_ctl_if.barrier.id]); assign reached_barrier_limit = (active_barrier_count[`NW_BITS-1:0] == warp_ctl_if.barrier.size_m1); - reg [`NUM_WARPS-1:0] total_barrier_stall; + reg [`NUM_WARPS-1:0] barrier_stalls; always @(*) begin - total_barrier_stall = barrier_stall_mask[0]; + barrier_stalls = barrier_masks[0]; for (integer i = 1; i < `NUM_BARRIERS; ++i) begin - total_barrier_stall |= barrier_stall_mask[i]; + barrier_stalls |= barrier_masks[i]; end end @@ -205,22 +198,27 @@ module VX_warp_sched #( // round-robin warp scheduling - wire schedule_valid; + wire [`NUM_WARPS-1:0] ready_warps = active_warps & ~(stalled_warps | barrier_stalls); VX_rr_arbiter #( .NUM_REQS (`NUM_WARPS) ) rr_arbiter ( .clk (clk), .reset (reset), - .requests (active_warps & ~(stalled_warps | total_barrier_stall | fetch_lock)), - .grant_index (schedule_warp), + .requests (ready_warps), + .grant_index (schedule_wid), .grant_valid (schedule_valid), `UNUSED_PIN (grant_onehot), `UNUSED_PIN (enable) ); - wire [`NUM_THREADS-1:0] thread_mask = use_wspawn[schedule_warp] ? `NUM_THREADS'(1) : thread_masks[schedule_warp]; - wire [31:0] warp_pc = use_wspawn[schedule_warp] ? use_wspawn_pc : warp_pcs[schedule_warp]; + wire [`NUM_WARPS-1:0][(`NUM_THREADS + 32)-1:0] schedule_data; + for (genvar i = 0; i < `NUM_WARPS; ++i) begin + assign schedule_data[i] = {(use_wspawn[i] ? `NUM_THREADS'(1) : thread_masks[i]), + (use_wspawn[i] ? wspawn_pc : warp_pcs[i])}; + end + + assign {schedule_tmask, schedule_pc} = schedule_data[schedule_wid]; wire stall_out = ~ifetch_req_if.ready && ifetch_req_if.valid; @@ -233,17 +231,17 @@ module VX_warp_sched #( .clk (clk), .reset (reset), .enable (!stall_out), - .data_in ({schedule_valid, thread_mask, warp_pc, schedule_warp}), + .data_in ({schedule_valid, schedule_tmask, schedule_pc, schedule_wid}), .data_out ({ifetch_req_if.valid, ifetch_req_if.tmask, ifetch_req_if.PC, ifetch_req_if.wid}) ); assign busy = (active_warps != 0); - `SCOPE_ASSIGN (wsched_scheduled_warp, warp_scheduled); + `SCOPE_ASSIGN (wsched_scheduled, warp_scheduled); `SCOPE_ASSIGN (wsched_active_warps, active_warps); - `SCOPE_ASSIGN (wsched_schedule_table, schedule_table); - `SCOPE_ASSIGN (wsched_schedule_ready, schedule_ready); - `SCOPE_ASSIGN (wsched_warp_to_schedule, schedule_warp); - `SCOPE_ASSIGN (wsched_warp_pc, warp_pc); + `SCOPE_ASSIGN (wsched_stalled_warps, stalled_warps); + `SCOPE_ASSIGN (wsched_schedule_wid, schedule_wid); + `SCOPE_ASSIGN (wsched_schedule_tmask, schedule_tmask); + `SCOPE_ASSIGN (wsched_schedule_pc, schedule_pc); endmodule \ No newline at end of file diff --git a/hw/rtl/interfaces/VX_wstall_if.v b/hw/rtl/interfaces/VX_wstall_if.v index b50d0711..e8e0e249 100644 --- a/hw/rtl/interfaces/VX_wstall_if.v +++ b/hw/rtl/interfaces/VX_wstall_if.v @@ -7,6 +7,7 @@ interface VX_wstall_if(); wire valid; wire [`NW_BITS-1:0] wid; + wire stalled; endinterface diff --git a/hw/scripts/scope.json b/hw/scripts/scope.json index 96db6498..4c411131 100644 --- a/hw/scripts/scope.json +++ b/hw/scripts/scope.json @@ -140,9 +140,9 @@ "afu/vortex/cluster/core/pipeline/fetch/warp_sched": { "?wsched_scheduled_warp": 1, "wsched_active_warps": "`NUM_WARPS", - "wsched_schedule_table": "`NUM_WARPS", - "wsched_schedule_ready": "`NUM_WARPS", - "wsched_warp_to_schedule": "`NW_BITS", + "wsched_stalled_warps": "`NUM_WARPS", + "wsched_schedule_tmask": "`NUM_THREADS", + "wsched_schedule_wid": "`NW_BITS", "wsched_warp_pc": "32" }, "afu/vortex/cluster/core/pipeline/execute/gpu_unit": {