From 288f0c976be58cd38c6b0e851fee0fbdf5cd9ce8 Mon Sep 17 00:00:00 2001 From: Blaise Tine Date: Tue, 15 Jun 2021 16:06:54 -0400 Subject: [PATCH] fix split/join hardware --- hw/rtl/VX_warp_sched.v | 43 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/hw/rtl/VX_warp_sched.v b/hw/rtl/VX_warp_sched.v index 438ddb0b..03b8a9ce 100644 --- a/hw/rtl/VX_warp_sched.v +++ b/hw/rtl/VX_warp_sched.v @@ -21,7 +21,7 @@ module VX_warp_sched #( `UNUSED_PARAM (CORE_ID) - wire join_fall; + wire join_else; wire [31:0] join_pc; wire [`NUM_THREADS-1:0] join_tm; @@ -45,8 +45,6 @@ module VX_warp_sched #( reg [`NW_BITS-1:0] scheduled_warp; wire warp_scheduled; - reg didnt_split; - wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready; always @(*) begin @@ -82,7 +80,6 @@ module VX_warp_sched #( schedule_table[0] <= 1; // set first warp as ready thread_masks[0] <= 1; // Activating first thread in first warp stalled_warps <= 0; - didnt_split <= 0; fetch_lock <= 0; for (integer i = 1; i < `NUM_WARPS; i++) begin @@ -107,19 +104,10 @@ module VX_warp_sched #( end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask; stalled_warps[warp_ctl_if.wid] <= 0; - end else if (join_if.valid && !didnt_split) begin - if (!join_fall) begin - warp_pcs[join_if.wid] <= join_pc; - end - thread_masks[join_if.wid] <= join_tm; - didnt_split <= 0; end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin stalled_warps[warp_ctl_if.wid] <= 0; if (warp_ctl_if.split.diverged) begin thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_mask; - didnt_split <= 0; - end else begin - didnt_split <= 1; end end @@ -150,6 +138,14 @@ module VX_warp_sched #( warp_pcs[ifetch_rsp_if.wid] <= ifetch_rsp_if.PC + 4; end + // join handling + if (join_if.valid) begin + if (join_else) begin + warp_pcs[join_if.wid] <= join_pc; + end + thread_masks[join_if.wid] <= join_tm; + end + active_warps <= active_warps_n; // reset 'schedule_table' when it goes to zero @@ -174,22 +170,21 @@ module VX_warp_sched #( end end - // split/join stack management + // split/join stack management wire [(1+32+`NUM_THREADS-1):0] ipdom [`NUM_WARPS-1:0]; - wire [(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0, thread_masks[warp_ctl_if.wid]}; - wire [(1+32+`NUM_THREADS-1):0] q2 = {1'b0, warp_ctl_if.split.pc, warp_ctl_if.split.else_mask}; - - assign {join_fall, join_pc, join_tm} = ipdom [join_if.wid]; - + for (genvar i = 0; i < `NUM_WARPS; i++) begin wire push = warp_ctl_if.valid - && warp_ctl_if.split.valid - && warp_ctl_if.split.diverged + && warp_ctl_if.split.valid && (i == warp_ctl_if.wid); wire pop = join_if.valid && (i == join_if.wid); + wire [`NUM_THREADS-1:0] else_mask = warp_ctl_if.split.diverged ? warp_ctl_if.split.else_mask : thread_masks[warp_ctl_if.wid]; + wire [(1+32+`NUM_THREADS-1):0] q_end = {1'b0, 32'b0, thread_masks[warp_ctl_if.wid]}; + wire [(1+32+`NUM_THREADS-1):0] q_else = {1'b1, warp_ctl_if.split.pc, else_mask}; + VX_ipdom_stack #( .WIDTH (1+32+`NUM_THREADS), .DEPTH (2 ** (`NT_BITS+1)) @@ -198,14 +193,16 @@ module VX_warp_sched #( .reset (reset), .push (push), .pop (pop), - .q1 (q1), - .q2 (q2), + .q1 (q_end), + .q2 (q_else), .d (ipdom[i]), `UNUSED_PIN (empty), `UNUSED_PIN (full) ); end + assign {join_else, join_pc, join_tm} = ipdom [join_if.wid]; + // calculate next warp schedule reg [`NUM_THREADS-1:0] thread_mask;