RTL code refactoring

This commit is contained in:
Blaise Tine
2020-04-21 01:52:12 -04:00
parent 20ae78f434
commit 43a8bf4326
2 changed files with 72 additions and 60 deletions

View File

@@ -1,66 +1,65 @@
`include "VX_define.vh" `include "VX_define.vh"
module VX_warp_sched ( module VX_warp_sched (
input wire clk, // Clock input wire clk, // Clock
input wire reset, input wire reset,
input wire stall, input wire stall,
// Wspawn // Wspawn
input wire wspawn, input wire wspawn,
input wire[31:0] wsapwn_pc, input wire[31:0] wsapwn_pc,
input wire[`NUM_WARPS-1:0] wspawn_new_active, input wire[`NUM_WARPS-1:0] wspawn_new_active,
// CTM // CTM
input wire ctm, input wire ctm,
input wire[`NUM_THREADS-1:0] ctm_mask, input wire[`NUM_THREADS-1:0] ctm_mask,
input wire[`NW_BITS-1:0] ctm_warp_num, input wire[`NW_BITS-1:0] ctm_warp_num,
// WHALT // WHALT
input wire whalt, input wire whalt,
input wire[`NW_BITS-1:0] whalt_warp_num, input wire[`NW_BITS-1:0] whalt_warp_num,
input wire is_barrier, input wire is_barrier,
`DEBUG_BEGIN `DEBUG_BEGIN
input wire[31:0] barrier_id, input wire[31:0] barrier_id,
`DEBUG_END `DEBUG_END
input wire[$clog2(`NUM_WARPS):0] num_warps, input wire[$clog2(`NUM_WARPS):0] num_warps,
input wire[`NW_BITS-1:0] barrier_warp_num, input wire[`NW_BITS-1:0] barrier_warp_num,
// WSTALL // WSTALL
input wire wstall, input wire wstall,
input wire[`NW_BITS-1:0] wstall_warp_num, input wire[`NW_BITS-1:0] wstall_warp_num,
// Split // Split
input wire is_split, input wire is_split,
input wire dont_split, input wire dont_split,
input wire[`NUM_THREADS-1:0] split_new_mask, input wire[`NUM_THREADS-1:0] split_new_mask,
input wire[`NUM_THREADS-1:0] split_later_mask, input wire[`NUM_THREADS-1:0] split_later_mask,
input wire[31:0] split_save_pc, input wire[31:0] split_save_pc,
input wire[`NW_BITS-1:0] split_warp_num, input wire[`NW_BITS-1:0] split_warp_num,
// Join // Join
input wire is_join, input wire is_join,
input wire[`NW_BITS-1:0] join_warp_num, input wire[`NW_BITS-1:0] join_warp_num,
// JAL // JAL
input wire jal, input wire jal,
input wire[31:0] jal_dest, input wire[31:0] jal_dest,
input wire[`NW_BITS-1:0] jal_warp_num, input wire[`NW_BITS-1:0] jal_warp_num,
// Branch // Branch
input wire branch_valid, input wire branch_valid,
input wire branch_dir, input wire branch_dir,
input wire[31:0] branch_dest, input wire[31:0] branch_dest,
input wire[`NW_BITS-1:0] branch_warp_num, input wire[`NW_BITS-1:0] branch_warp_num,
output wire[`NUM_THREADS-1:0] thread_mask, output wire[`NUM_THREADS-1:0] thread_mask,
output wire[`NW_BITS-1:0] warp_num, output wire[`NW_BITS-1:0] warp_num,
output wire[31:0] warp_pc, output wire[31:0] warp_pc,
output wire ebreak, output wire ebreak,
output wire scheduled_warp, output wire scheduled_warp,
input wire[`NW_BITS-1:0] icache_stage_wid,
input wire[`NUM_THREADS-1:0] icache_stage_valids
input wire[`NW_BITS-1:0] icache_stage_wid,
input wire[`NUM_THREADS-1:0] icache_stage_valids
); );
wire update_use_wspawn; wire update_use_wspawn;
wire update_visible_active; wire update_visible_active;
@@ -226,16 +225,21 @@ module VX_warp_sched (
end end
end end
VX_countones #(.N(`NUM_WARPS)) barrier_count( VX_countones #(
.N(`NUM_WARPS)
) barrier_count (
.valids(curr_barrier_mask), .valids(curr_barrier_mask),
.count (curr_barrier_count) .count (curr_barrier_count)
); );
wire[$clog2(`NUM_WARPS):0] count_visible_active; wire [$clog2(`NUM_WARPS):0] count_visible_active;
VX_countones #(.N(`NUM_WARPS)) num_visible(
VX_countones #(
.N(`NUM_WARPS)
) num_visible (
.valids(visible_active), .valids(visible_active),
.count (count_visible_active) .count (count_visible_active)
); );
// assign curr_barrier_count = $countones(curr_barrier_mask); // assign curr_barrier_count = $countones(curr_barrier_mask);
@@ -254,17 +258,13 @@ module VX_warp_sched (
// end // end
// end // end
assign update_visible_active = (count_visible_active < 1) && !(stall || wstall_this_cycle || hazard || is_join); assign update_visible_active = (count_visible_active < 1) && !(stall || wstall_this_cycle || hazard || is_join);
wire[(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0 , thread_masks[split_warp_num]}; wire[(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0 , thread_masks[split_warp_num]};
wire[(1+32+`NUM_THREADS-1):0] q2 = {1'b0, split_save_pc , split_later_mask}; wire[(1+32+`NUM_THREADS-1):0] q2 = {1'b0, split_save_pc , split_later_mask};
assign {join_fall, join_pc, join_tm} = d[join_warp_num]; assign {join_fall, join_pc, join_tm} = d[join_warp_num];
genvar curr_warp; genvar curr_warp;
generate generate
for (curr_warp = 0; curr_warp < `NUM_WARPS; curr_warp = curr_warp + 1) begin : stacks for (curr_warp = 0; curr_warp < `NUM_WARPS; curr_warp = curr_warp + 1) begin : stacks
@@ -273,7 +273,11 @@ module VX_warp_sched (
wire push = (is_split && !dont_split) && correct_warp_s; wire push = (is_split && !dont_split) && correct_warp_s;
wire pop = is_join && correct_warp_j; wire pop = is_join && correct_warp_j;
VX_generic_stack #(.WIDTH(1+32+`NUM_THREADS), .DEPTH($clog2(`NUM_THREADS)+1)) ipdom_stack(
VX_generic_stack #(
.WIDTH(1+32+`NUM_THREADS),
.DEPTH($clog2(`NUM_THREADS)+1)
) ipdom_stack(
.clk (clk), .clk (clk),
.reset(reset), .reset(reset),
.push (push), .push (push),
@@ -308,11 +312,12 @@ module VX_warp_sched (
assign new_pc = warp_pc + 4; assign new_pc = warp_pc + 4;
assign use_active = (count_visible_active < 1) ? (warp_active & (~warp_stalled) & (~total_barrier_stall) & (~warp_lock)) : visible_active; assign use_active = (count_visible_active < 1) ? (warp_active & (~warp_stalled) & (~total_barrier_stall) & (~warp_lock)) : visible_active;
// Choosing a warp to schedule // Choosing a warp to schedule
VX_priority_encoder choose_schedule( VX_priority_encoder #(
.N(`NUM_WARPS)
) choose_schedule (
.valids(use_active), .valids(use_active),
.index (warp_to_schedule), .index (warp_to_schedule),
.found (schedule) .found (schedule)

View File

@@ -1,21 +1,28 @@
`include "VX_define.vh" `include "VX_define.vh"
module VX_priority_encoder ( module VX_priority_encoder #(
input wire[`NUM_WARPS-1:0] valids, parameter N
output reg[`NW_BITS-1:0] index, ) (
output reg found input wire [N-1:0] valids,
output wire [`LOG2UP(N)-1:0] index,
output wire found
); );
reg [`LOG2UP(N)-1:0] index_r;
reg found_r;
integer i; integer i;
always @(*) begin always @(*) begin
index = 0; index_r = 0;
found = 0; found_r = 0;
for (i = `NUM_WARPS-1; i >= 0; i = i - 1) begin for (i = `NUM_WARPS-1; i >= 0; i = i - 1) begin
if (valids[i]) begin if (valids[i]) begin
index = i[`NW_BITS-1:0]; index_r = i[`NW_BITS-1:0];
found = 1; found_r = 1;
end end
end end
end end
assign index = index_r;
assign found = found_r;
endmodule endmodule