diff --git a/hw/rtl/VX_ipdom_stack.sv b/hw/rtl/VX_ipdom_stack.sv index 4cdb1317..2c0cc322 100644 --- a/hw/rtl/VX_ipdom_stack.sv +++ b/hw/rtl/VX_ipdom_stack.sv @@ -6,11 +6,13 @@ module VX_ipdom_stack #( ) ( input wire clk, input wire reset, + input wire pair, input wire [WIDTH - 1:0] q1, input wire [WIDTH - 1:0] q2, output wire [WIDTH - 1:0] d, input wire push, input wire pop, + output wire index, output wire empty, output wire full ); @@ -52,15 +54,15 @@ module VX_ipdom_stack #( always @(posedge clk) begin if (push) begin - is_part[wr_ptr] <= 0; + is_part[wr_ptr] <= ~pair; end else if (pop) begin is_part[rd_ptr] <= 1; end end - wire p = is_part[rd_ptr]; - assign d = p ? d1 : d2; - assign empty = ~(| wr_ptr); + assign index = is_part[rd_ptr]; + assign d = index ? d1 : d2; + assign empty = (ADDRW'(0) == wr_ptr); assign full = (ADDRW'(DEPTH-1) == wr_ptr); endmodule \ No newline at end of file diff --git a/hw/rtl/VX_warp_sched.sv b/hw/rtl/VX_warp_sched.sv index e206c09e..9495c001 100644 --- a/hw/rtl/VX_warp_sched.sv +++ b/hw/rtl/VX_warp_sched.sv @@ -24,7 +24,7 @@ module VX_warp_sched #( wire join_else; wire [31:0] join_pc; - wire [`NUM_THREADS-1:0] join_tm; + wire [`NUM_THREADS-1:0] join_tmask; reg [`NUM_WARPS-1:0] active_warps, active_warps_n; // real active warps (updated when a warp is activated or disabled) reg [`NUM_WARPS-1:0] stalled_warps; // asserted when a branch/gpgpu instructions are issued @@ -132,7 +132,7 @@ module VX_warp_sched #( if (join_else) begin warp_pcs[join_if.wid] <= join_pc; end - thread_masks[join_if.wid] <= join_tm; + thread_masks[join_if.wid] <= join_tmask; end active_warps <= active_warps_n; @@ -162,9 +162,8 @@ module VX_warp_sched #( // split/join stack management - wire [(1+32+`NUM_THREADS)-1:0] ipdom [`NUM_WARPS-1:0]; - - wire [`NUM_THREADS-1:0] curr_tmask = thread_masks[warp_ctl_if.wid]; + wire [(32+`NUM_THREADS)-1:0] ipdom_data [`NUM_WARPS-1:0]; + wire ipdom_index [`NUM_WARPS-1:0]; for (genvar i = 0; i < `NUM_WARPS; i++) begin wire push = warp_ctl_if.valid @@ -173,27 +172,32 @@ module VX_warp_sched #( wire pop = join_if.valid && (i == join_if.wid); - wire [`NUM_THREADS-1:0] else_tmask = warp_ctl_if.split.diverged ? warp_ctl_if.split.else_tmask : curr_tmask; - wire [(1+32+`NUM_THREADS)-1:0] q_end = {1'b0, 32'b0, curr_tmask}; - wire [(1+32+`NUM_THREADS)-1:0] q_else = {1'b1, warp_ctl_if.split.pc, else_tmask}; + wire [`NUM_THREADS-1:0] else_tmask = warp_ctl_if.split.else_tmask; + wire [`NUM_THREADS-1:0] orig_tmask = thread_masks[warp_ctl_if.wid]; + + wire [(32+`NUM_THREADS)-1:0] q_else = {warp_ctl_if.split.pc, else_tmask}; + wire [(32+`NUM_THREADS)-1:0] q_end = {32'b0, orig_tmask}; VX_ipdom_stack #( - .WIDTH (1+32+`NUM_THREADS), + .WIDTH (32+`NUM_THREADS), .DEPTH (2 ** (`NT_BITS+1)) ) ipdom_stack ( .clk (clk), .reset (reset), .push (push), .pop (pop), + .pair (warp_ctl_if.split.diverged), .q1 (q_end), .q2 (q_else), - .d (ipdom[i]), + .d (ipdom_data[i]), + .index (ipdom_index[i]), `UNUSED_PIN (empty), `UNUSED_PIN (full) ); end - assign {join_else, join_pc, join_tm} = ipdom [join_if.wid]; + assign {join_pc, join_tmask} = ipdom_data[join_if.wid]; + assign join_else = ~ipdom_index[join_if.wid]; // schedule the next ready warp