Handle BLOCK_SIZE != 1 in dispatch_unit

+ change ALU and FPU unit to use it as well
This commit is contained in:
Hansung Kim
2024-05-30 23:20:21 -07:00
parent a02773eb92
commit 52bb827a46
3 changed files with 128 additions and 17 deletions

View File

@@ -42,7 +42,7 @@ module VX_alu_unit #(
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)

View File

@@ -78,20 +78,25 @@ module VX_dispatch_unit_sane import VX_gpu_pkg::*; #(
// assign batch_idx = 0;
// `UNUSED_VAR(batch_done)
// end
// group dispatch_valid by blocks
wire [BATCH_COUNT-1:0] batch_valids;
for (genvar i = 0; i < BATCH_COUNT; ++i) begin
assign batch_valids[i] = |(dispatch_valid[(BLOCK_SIZE * i) +: BLOCK_SIZE]);
end
// elect the leftmost-valid batch for the dispatch
wire dispatch_any_valid;
VX_lzc_rr #(
.N (`ISSUE_WIDTH)
.N (BATCH_COUNT)
) batch_select (
.clk (clk),
.reset (reset),
.data_in (dispatch_valid),
.data_in (batch_valids),
.data_out (batch_idx),
.valid_out (dispatch_any_valid)
);
`STATIC_ASSERT ((BLOCK_SIZE == 1), ("dispatch_unit_sane only supports BLOCK_SIZE == 1 for now"))
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
wire [ISSUE_W-1:0] issue_idx = ISSUE_W'(batch_idx * BLOCK_SIZE) + ISSUE_W'(block_idx);
@@ -99,16 +104,122 @@ module VX_dispatch_unit_sane import VX_gpu_pkg::*; #(
wire valid_p, ready_p;
assign valid_p = dispatch_valid[issue_idx];
assign block_tmask[block_idx] = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
assign block_regs[block_idx][0] = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][1] = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][2] = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_pid[block_idx] = '0;
assign block_sop[block_idx] = 1'b1;
assign block_eop[block_idx] = 1'b1;
assign block_ready[block_idx] = ready_p;
assign block_done[block_idx] = ~valid_p || ready_p;
if (`NUM_THREADS != NUM_LANES) begin
reg [NUM_PACKETS-1:0] sent_mask_p;
wire [PID_WIDTH-1:0] start_p_n, start_p, end_p;
wire dispatch_valid_r;
reg is_first_p;
wire fire_p = valid_p && ready_p;
wire is_last_p = (start_p == end_p);
wire fire_eop = fire_p && is_last_p;
always @(posedge clk) begin
if (reset) begin
sent_mask_p <= '0;
is_first_p <= 1;
end else begin
if ((BATCH_COUNT != 1) ? batch_done : fire_eop) begin
sent_mask_p <= '0;
is_first_p <= 1;
end else if (fire_p) begin
sent_mask_p[start_p] <= 1;
is_first_p <= 0;
end
end
end
wire [NUM_PACKETS-1:0][NUM_LANES-1:0] per_packet_tmask;
wire [NUM_PACKETS-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] per_packet_regs;
wire [`NUM_THREADS-1:0] dispatch_tmask = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs1_data = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs2_data = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs3_data = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
for (genvar j = 0; j < NUM_LANES; ++j) begin
localparam k = i * NUM_LANES + j;
assign per_packet_tmask[i][j] = dispatch_tmask[k];
assign per_packet_regs[i][0][j] = dispatch_rs1_data[k];
assign per_packet_regs[i][1][j] = dispatch_rs2_data[k];
assign per_packet_regs[i][2][j] = dispatch_rs3_data[k];
end
end
wire [NUM_PACKETS-1:0] packet_valids;
wire [NUM_PACKETS-1:0][PID_WIDTH-1:0] packet_ids;
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
assign packet_valids[i] = (| per_packet_tmask[i]);
assign packet_ids[i] = PID_WIDTH'(i);
end
VX_find_first #(
.N (NUM_PACKETS),
.DATAW (PID_WIDTH),
.REVERSE (0)
) find_first (
.valid_in (packet_valids & ~sent_mask_p),
.data_in (packet_ids),
.data_out (start_p_n),
`UNUSED_PIN (valid_out)
);
VX_find_first #(
.N (NUM_PACKETS),
.DATAW (PID_WIDTH),
.REVERSE (1)
) find_last (
.valid_in (packet_valids),
.data_in (packet_ids),
.data_out (end_p),
`UNUSED_PIN (valid_out)
);
VX_pipe_register #(
.DATAW (1 + PID_WIDTH),
.RESETW (1),
.DEPTH (FANOUT_ENABLE ? 1 : 0)
) pipe_reg (
.clk (clk),
.reset (reset || fire_p), // should flush on fire
.enable (1'b1),
.data_in ({dispatch_valid[issue_idx], start_p_n}),
.data_out ({dispatch_valid_r, start_p})
);
wire [NUM_LANES-1:0] tmask_p = per_packet_tmask[start_p];
wire [2:0][NUM_LANES-1:0][`XLEN-1:0] regs_p = per_packet_regs[start_p];
wire block_enable = (BATCH_COUNT == 1 || ~(& sent_mask_p));
assign valid_p = dispatch_valid_r && block_enable;
assign block_tmask[block_idx] = tmask_p;
assign block_regs[block_idx] = regs_p;
assign block_pid[block_idx] = start_p;
assign block_sop[block_idx] = is_first_p;
assign block_eop[block_idx] = is_last_p;
if (FANOUT_ENABLE) begin
assign block_ready[block_idx] = dispatch_valid_r && ready_p && block_enable;
end else begin
assign block_ready[block_idx] = ready_p && block_enable;
end
assign block_done[block_idx] = ~dispatch_valid[issue_idx] || fire_eop;
end else begin
assign valid_p = dispatch_valid[issue_idx];
assign block_tmask[block_idx] = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
assign block_regs[block_idx][0] = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][1] = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][2] = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_pid[block_idx] = '0;
assign block_sop[block_idx] = 1'b1;
assign block_eop[block_idx] = 1'b1;
assign block_ready[block_idx] = ready_p;
assign block_done[block_idx] = ~valid_p || ready_p;
end
wire [ISSUE_ISW_W-1:0] isw;
if (BATCH_COUNT != 1) begin

View File

@@ -39,7 +39,7 @@ module VX_fpu_unit import VX_fpu_pkg::*; #(
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)