Files
kernels/hw/rtl/VX_warp_sched.v
2020-08-22 00:22:04 -07:00

261 lines
9.0 KiB
Verilog

`include "VX_define.vh"
// VX_warp_sched: warp scheduler.
//
// Keeps per-warp PC, thread mask, and active/stall state, and each cycle
// selects one schedulable warp whose {valid, thread_mask, PC, wid} is
// registered onto ifetch_req_if.
//
// State is updated from:
//   - warp_ctl_if   : wspawn (activate a set of warps), tmc (set a warp's
//                     thread mask; an all-zero mask deactivates the warp),
//                     split (push onto that warp's IPDOM stack when divergent),
//                     barrier (park/release warps on a barrier id).
//   - join_if       : pops the warp's IPDOM stack and restores PC/thread mask.
//   - branch_ctl_if : redirects the PC on a taken branch and releases the
//                     warp's stall.
//   - wstall_if     : stalls a warp; a later warp_ctl barrier/tmc/split or a
//                     branch_ctl update for that warp releases it.
//   - ifetch_rsp_if : a valid&ready response releases the warp's fetch lock.
//
// busy is asserted while any warp is active.
module VX_warp_sched #(
    parameter CORE_ID = 0
) (
    input wire          clk,
    input wire          reset,

    VX_warp_ctl_if      warp_ctl_if,
    VX_wstall_if        wstall_if,
    VX_join_if          join_if,
    VX_branch_ctl_if    branch_ctl_if,

    VX_ifetch_rsp_if    ifetch_rsp_if,
    VX_ifetch_req_if    ifetch_req_if,

    output wire         busy
);
    // Top-of-stack entry of the joining warp's IPDOM stack, unpacked below.
    wire                    join_fall;
    wire [31:0]             join_pc;
    wire [`NUM_THREADS-1:0] join_tm;

    reg [`NUM_WARPS-1:0] warp_active;   // warps currently running (set by reset/wspawn, cleared by tmc with an empty mask)
    reg [`NUM_WARPS-1:0] warp_stalled;  // set by wstall_if; cleared by warp_ctl barrier/tmc/split or branch_ctl
    reg [`NUM_WARPS-1:0] warp_ready, warp_ready_n; // round-robin eligibility mask and its next-state value

    // Set when a warp is scheduled, cleared when its fetch response is
    // accepted, so a warp has at most one fetch in flight.
    reg [`NUM_WARPS-1:0] fetch_lock;

    reg [`NUM_THREADS-1:0] thread_masks[`NUM_WARPS-1:0];
    reg [31:0]             warp_pcs[`NUM_WARPS-1:0];

    // Barriers: per-barrier-id mask of warps waiting on it.
    reg [`NUM_WARPS-1:0] barrier_stall_mask[`NUM_BARRIERS-1:0];
    wire                 reached_barrier_limit; // requesting warp completes the barrier

    // wspawn: spawned warps (other than warp 0) take use_wspawn_pc and a
    // single-thread mask on their first schedule.
    reg [31:0]           use_wspawn_pc;
    reg [`NUM_WARPS-1:0] use_wspawn;

    wire [31:0]           warp_pc;
    wire [`NW_BITS-1:0]   warp_to_schedule;
    wire                  scheduled_warp;
    wire [`NUM_WARPS-1:0] total_warp_stalled;

    // Set by a non-divergent split so that the following join is ignored.
    // NOTE(review): this is a single bit shared by all warps and it is not
    // cleared when a join is skipped; a non-divergent split on one warp can
    // therefore suppress a join on another warp. The per-warp IPDOM stacks
    // below are also popped on every join regardless of this flag.
    // TODO confirm intended semantics.
    reg didnt_split;

    // Next round-robin mask: drop warps that are deactivated, stalled, or
    // just scheduled this cycle.
    always @(*) begin
        warp_ready_n = warp_ready;
        if (warp_ctl_if.valid
         && warp_ctl_if.tmc.valid
         && (0 == warp_ctl_if.tmc.thread_mask)) begin
            warp_ready_n[warp_ctl_if.wid] = 0;
        end
        if (wstall_if.wstall) begin
            warp_ready_n[wstall_if.wid] = 0;
        end
        if (scheduled_warp) begin
            warp_ready_n[warp_to_schedule] = 0;
        end
    end

    always @(posedge clk) begin
        if (reset) begin
            for (integer i = 0; i < `NUM_BARRIERS; i++) begin
                barrier_stall_mask[i] <= 0;
            end
            use_wspawn_pc   <= 0;
            use_wspawn      <= 0;
            warp_pcs[0]     <= `STARTUP_ADDR;
            warp_active[0]  <= 1; // activate the first warp
            warp_ready[0]   <= 1; // make it schedulable immediately
            thread_masks[0] <= 1; // with only its first thread enabled
            warp_stalled    <= 0;
            didnt_split     <= 0;
            fetch_lock      <= 0;
            for (integer i = 1; i < `NUM_WARPS; i++) begin
                warp_pcs[i]     <= 0;
                warp_active[i]  <= 0;
                warp_ready[i]   <= 0;
                thread_masks[i] <= 0;
            end
        end else begin
            // wspawn: the active set becomes exactly wmask; warp 0 keeps its
            // own PC/mask, the others start from wspawn.pc.
            if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
                warp_active   <= warp_ctl_if.wspawn.wmask;
                use_wspawn    <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
                use_wspawn_pc <= warp_ctl_if.wspawn.pc;
            end

            if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
                warp_stalled[warp_ctl_if.wid] <= 0;
                if (reached_barrier_limit) begin
                    // last arrival: release every warp parked on this barrier
                    barrier_stall_mask[warp_ctl_if.barrier.id] <= 0;
                end else begin
                    // park the requesting warp on this barrier
                    barrier_stall_mask[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
                end
            end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
                thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.thread_mask;
                warp_stalled[warp_ctl_if.wid] <= 0;
                if (0 == warp_ctl_if.tmc.thread_mask) begin
                    warp_active[warp_ctl_if.wid] <= 0;
                end
            end else if (join_if.is_join && !didnt_split) begin
                // Restore state from the IPDOM stack top; a non-fall-through
                // entry also redirects the PC.
                if (!join_fall) begin
                    warp_pcs[join_if.wid] <= join_pc;
                end
                thread_masks[join_if.wid] <= join_tm;
                didnt_split <= 0;
            end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
                warp_stalled[warp_ctl_if.wid] <= 0;
                if (warp_ctl_if.split.diverged) begin
                    thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_mask;
                    didnt_split <= 0;
                end else begin
                    didnt_split <= 1;
                end
            end

            // First schedule of a spawned warp: consume the wspawn override.
            if (use_wspawn[warp_to_schedule] && scheduled_warp) begin
                use_wspawn[warp_to_schedule]   <= 0;
                thread_masks[warp_to_schedule] <= 1;
            end

            // Stall the requested warp from further scheduling.
            if (wstall_if.wstall) begin
                warp_stalled[wstall_if.wid] <= 1;
            end

            // Advance the scheduled warp's PC to the next sequential
            // instruction. Later assignments in this block take precedence.
            if (scheduled_warp) begin
                warp_pcs[warp_to_schedule] <= warp_pc + 4;
            end

            // Branch resolution: redirect on taken and release the stall.
            if (branch_ctl_if.valid) begin
                if (branch_ctl_if.taken) begin
                    warp_pcs[branch_ctl_if.wid] <= branch_ctl_if.dest;
                end
                warp_stalled[branch_ctl_if.wid] <= 0;
            end

            // Lock a warp from the time it is scheduled until its fetch
            // response is accepted.
            if (scheduled_warp) begin
                fetch_lock[warp_to_schedule] <= 1;
            end
            if (ifetch_rsp_if.valid && ifetch_rsp_if.ready) begin
                fetch_lock[ifetch_rsp_if.wid] <= 0;
            end

            // Once every warp in the round has been consumed, start a new
            // round from the active, non-stalled warps.
            warp_ready <= (| warp_ready_n) ? warp_ready_n : (warp_active & ~total_warp_stalled);
        end
    end

    // Count warps already parked on the barrier named by the incoming request.
    `IGNORE_WARNINGS_BEGIN
    wire [`NW_BITS:0] active_barrier_count;
    `IGNORE_WARNINGS_END

    VX_countones #(
        .N(`NUM_WARPS)
    ) barrier_count (
        .valids (barrier_stall_mask[warp_ctl_if.barrier.id]),
        .count  (active_barrier_count)
    );

    // The requesting warp completes the barrier when the number of warps
    // already waiting equals barrier.size_m1 (presumably barrier size minus
    // one, per the name — TODO confirm). The net is declared at the top of the
    // module, so it is driven here with a continuous assignment rather than
    // re-declared.
    assign reached_barrier_limit = (active_barrier_count[`NW_BITS-1:0] == warp_ctl_if.barrier.size_m1);

    // Union of all barriers' waiting-warp masks.
    reg [`NUM_WARPS-1:0] total_barrier_stall;
    always @(*) begin
        total_barrier_stall = barrier_stall_mask[0];
        for (integer i = 1; i < `NUM_BARRIERS; ++i) begin
            total_barrier_stall |= barrier_stall_mask[i];
        end
    end

    // Split/join (IPDOM) stack management.
    // Entry layout: {fall, pc[31:0], thread_mask}.
    //   q1: fall-through entry carrying the warp's current thread mask.
    //   q2: entry that redirects to split.pc with split.else_mask.
    wire [(1+32+`NUM_THREADS-1):0] ipdom[`NUM_WARPS-1:0];
    wire [(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0, thread_masks[warp_ctl_if.wid]};
    wire [(1+32+`NUM_THREADS-1):0] q2 = {1'b0, warp_ctl_if.split.pc, warp_ctl_if.split.else_mask};

    assign {join_fall, join_pc, join_tm} = ipdom[join_if.wid];

    for (genvar i = 0; i < `NUM_WARPS; i++) begin
        // push on a divergent split of this warp
        wire push = warp_ctl_if.valid
                 && warp_ctl_if.split.valid
                 && warp_ctl_if.split.diverged
                 && (i == warp_ctl_if.wid);

        // pop on any join of this warp
        wire pop = join_if.is_join
                && (i == join_if.wid);

        VX_ipdom_stack #(
            .WIDTH (1+32+`NUM_THREADS),
            .DEPTH (`NT_BITS+1)
        ) ipdom_stack (
            .clk   (clk),
            .reset (reset),
            .push  (push),
            .pop   (pop),
            .q1    (q1),
            .q2    (q2),
            .d     (ipdom[i]),
            `UNUSED_PIN (empty),
            `UNUSED_PIN (full)
        );
    end

    // Next-warp selection: a warp is eligible when it is in the current
    // round-robin set and not stalled by wstall, a barrier, or a fetch lock.
    wire schedule;

    assign total_warp_stalled = warp_stalled | total_barrier_stall | fetch_lock;

    wire [`NUM_WARPS-1:0] use_ready = warp_ready & ~total_warp_stalled;

    // NOTE(review): arbitration policy is defined by VX_fixed_arbiter
    // (not visible here).
    VX_fixed_arbiter #(
        .N(`NUM_WARPS)
    ) choose_schedule (
        .clk         (clk),
        .reset       (reset),
        .requests    (use_ready),
        .grant_index (warp_to_schedule),
        .grant_valid (schedule),
        `UNUSED_PIN  (grant_onehot)
    );

    // Do not issue while the fetch request is back-pressured, while the
    // chosen warp is being stalled or redirected by a taken branch this
    // cycle, or while any join is in progress.
    wire stall_out = ~ifetch_req_if.ready && ifetch_req_if.valid;

    wire branch_hazard = branch_ctl_if.valid
                      && branch_ctl_if.taken
                      && (branch_ctl_if.wid == warp_to_schedule);

    wire wstall_this_cycle = wstall_if.wstall && (wstall_if.wid == warp_to_schedule);

    wire stall = stall_out || wstall_this_cycle || branch_hazard || join_if.is_join;

    assign scheduled_warp = schedule && ~stall;

    // A warp with a pending wspawn override fetches from use_wspawn_pc with
    // a single-thread mask.
    wire [`NUM_THREADS-1:0] thread_mask = use_wspawn[warp_to_schedule] ? `NUM_THREADS'(1) : thread_masks[warp_to_schedule];

    assign warp_pc = use_wspawn[warp_to_schedule] ? use_wspawn_pc : warp_pcs[warp_to_schedule];

    // Register the fetch request; stalled while the downstream has not
    // accepted the current one.
    VX_generic_register #(
        .N(1 + `NUM_THREADS + 32 + `NW_BITS)
    ) fetch_reg (
        .clk   (clk),
        .reset (reset),
        .stall (stall_out),
        .flush (1'b0),
        .in    ({scheduled_warp,      thread_mask,                 warp_pc,               warp_to_schedule}),
        .out   ({ifetch_req_if.valid, ifetch_req_if.thread_mask, ifetch_req_if.curr_PC, ifetch_req_if.wid})
    );

    assign busy = (warp_active != 0);

endmodule