tensor: Share B operand buffer between threadgroups
The two threadgroups use the same B fragment, so no need to duplicately store them in the operand buffer. To do this, pull the operand buffer out of the threadgroups to the octet-level.
This commit is contained in:
@@ -287,7 +287,7 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N(1)
|
.DATAW(1)
|
||||||
) staging_subcommit (
|
) staging_subcommit (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -298,15 +298,15 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
|||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
module VX_tensor_reg #(
|
module VX_tensor_reg #(
|
||||||
parameter N
|
parameter DATAW
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
input [N-1:0] d,
|
input [DATAW-1:0] d,
|
||||||
input en,
|
input en,
|
||||||
output [N-1:0] q
|
output [DATAW-1:0] q
|
||||||
);
|
);
|
||||||
logic [N-1:0] data;
|
logic [DATAW-1:0] data;
|
||||||
|
|
||||||
always @(posedge clk) begin
|
always @(posedge clk) begin
|
||||||
if (reset) begin
|
if (reset) begin
|
||||||
@@ -436,7 +436,7 @@ module VX_tensor_octet #(
|
|||||||
// Staging buffer for the A/B/C half-tiles that will later be assembled
|
// Staging buffer for the A/B/C half-tiles that will later be assembled
|
||||||
// with the other half tiles coming in on the input ports.
|
// with the other half tiles coming in on the input ports.
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
|
.DATAW($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
|
||||||
) staging_abc (
|
) staging_abc (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -464,7 +464,7 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N($bits(substeps))
|
.DATAW($bits(substeps))
|
||||||
) staging_substeps (
|
) staging_substeps (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -506,7 +506,7 @@ module VX_tensor_octet #(
|
|||||||
VX_tensor_dpu #(
|
VX_tensor_dpu #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
.OCTET(OCTET),
|
.OCTET(OCTET),
|
||||||
.ISSUE_QUEUE_DEPTH(4 /*@perf: arbtirary*/)
|
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/)
|
||||||
) dpu (
|
) dpu (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ module VX_tensor_dpu #(
|
|||||||
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
||||||
// the pipeline length of FEDPs in order to make sure there are enough
|
// the pipeline length of FEDPs in order to make sure there are enough
|
||||||
// entries to fully saturate the pipeline, but this is still rough
|
// entries to fully saturate the pipeline, but this is still rough
|
||||||
parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA
|
parameter OPERAND_BUFFER_DEPTH = `LATENCY_HMMA
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
@@ -51,65 +51,111 @@ module VX_tensor_dpu #(
|
|||||||
// stalling
|
// stalling
|
||||||
// assign ready_in = ready_out;
|
// assign ready_in = ready_out;
|
||||||
|
|
||||||
wire synced_fire;
|
wire [3:0][1:0][31:0] A_tile_buf;
|
||||||
assign synced_fire = valid_in && ready_in;
|
wire [1:0][3:0][31:0] B_tile_buf;
|
||||||
|
wire [3:0][3:0][31:0] C_tile_buf;
|
||||||
|
|
||||||
wire [1:0] threadgroup_valids;
|
wire wid_empty;
|
||||||
wire [1:0] threadgroup_readys;
|
wire wid_full;
|
||||||
// B_tile is shared across the two threadgroups; see Figure 13
|
|
||||||
VX_tensor_threadgroup #(
|
|
||||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
|
||||||
) threadgroup_0 (
|
|
||||||
.clk (clk),
|
|
||||||
.reset (reset),
|
|
||||||
.valid_in (synced_fire),
|
|
||||||
.ready_in (threadgroup_readys[0]),
|
|
||||||
.stall (!ready_out),
|
|
||||||
.A_frag (A_tile[1:0]),
|
|
||||||
.B_frag (B_tile),
|
|
||||||
.C_frag (C_tile[1:0]),
|
|
||||||
.valid_out (threadgroup_valids[0]),
|
|
||||||
.D_frag (D_tile[1:0])
|
|
||||||
);
|
|
||||||
VX_tensor_threadgroup #(
|
|
||||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
|
||||||
) threadgroup_1 (
|
|
||||||
.clk (clk),
|
|
||||||
.reset (reset),
|
|
||||||
.valid_in (synced_fire),
|
|
||||||
.ready_in (threadgroup_readys[1]),
|
|
||||||
.stall (!ready_out),
|
|
||||||
.A_frag (A_tile[3:2]),
|
|
||||||
.B_frag (B_tile),
|
|
||||||
.C_frag (C_tile[3:2]),
|
|
||||||
.valid_out (threadgroup_valids[1]),
|
|
||||||
.D_frag (D_tile[3:2])
|
|
||||||
);
|
|
||||||
|
|
||||||
wire empty;
|
wire empty;
|
||||||
wire full;
|
wire full;
|
||||||
wire enq = valid_in && ready_in;
|
// sync between operand buffer and wid buffer
|
||||||
wire deq = valid_out && ready_out;
|
assign ready_in = !full && !wid_full;
|
||||||
|
|
||||||
assign ready_in = &(threadgroup_readys) && !full;
|
wire [1:0] threadgroup_valids_out;
|
||||||
assign valid_out = &(threadgroup_valids);
|
wire [1:0] threadgroup_readys_in;
|
||||||
|
// sync operand queue and wid queue
|
||||||
|
wire threadgroup_valid_in = !empty;
|
||||||
|
wire threadgroup_fire_in = threadgroup_valid_in && &(threadgroup_readys_in);
|
||||||
|
|
||||||
|
wire enq = valid_in && ready_in;
|
||||||
|
wire deq = threadgroup_fire_in;
|
||||||
|
|
||||||
|
// Operand buffer for the dot product units.
|
||||||
|
//
|
||||||
|
// This exists to decouple the execution of the dot-product unit from
|
||||||
|
// the operand arrival. Operands from the upstream execute_if can arrive
|
||||||
|
// intermittently depending on the frontend's behavior, whereas downstream
|
||||||
|
// writeback happens at a regular cadence. Therefore to achieve full
|
||||||
|
// throughput of the dpu, we need to decouple the operand arrival from the
|
||||||
|
// direct input to the dpu.
|
||||||
|
VX_fifo_queue #(
|
||||||
|
.DATAW ($bits(A_tile) + $bits(B_tile) + $bits(C_tile)),
|
||||||
|
.DEPTH (OPERAND_BUFFER_DEPTH)
|
||||||
|
) operand_buffer (
|
||||||
|
.clk (clk),
|
||||||
|
.reset (reset),
|
||||||
|
.push (enq),
|
||||||
|
.pop (deq),
|
||||||
|
.data_in ({A_tile, B_tile, C_tile}),
|
||||||
|
.data_out ({A_tile_buf, B_tile_buf, C_tile_buf}),
|
||||||
|
.empty (empty),
|
||||||
|
`UNUSED_PIN(alm_empty),
|
||||||
|
.full (full),
|
||||||
|
`UNUSED_PIN(alm_full),
|
||||||
|
`UNUSED_PIN(size)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two
|
||||||
|
// threadgroups
|
||||||
|
//
|
||||||
|
// B_tile is shared across the two threadgroups; see Figure 13
|
||||||
|
VX_tensor_threadgroup #(
|
||||||
|
.OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH)
|
||||||
|
) threadgroup_0 (
|
||||||
|
.clk (clk),
|
||||||
|
.reset (reset),
|
||||||
|
.valid_in (threadgroup_valid_in),
|
||||||
|
.ready_in (threadgroup_readys_in[0]),
|
||||||
|
.stall (!ready_out),
|
||||||
|
.A_frag (A_tile_buf[1:0]),
|
||||||
|
.B_frag (B_tile_buf),
|
||||||
|
.C_frag (C_tile_buf[1:0]),
|
||||||
|
.valid_out (threadgroup_valids_out[0]),
|
||||||
|
.D_frag (D_tile[1:0])
|
||||||
|
);
|
||||||
|
VX_tensor_threadgroup #(
|
||||||
|
.OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH)
|
||||||
|
) threadgroup_1 (
|
||||||
|
.clk (clk),
|
||||||
|
.reset (reset),
|
||||||
|
.valid_in (threadgroup_valid_in),
|
||||||
|
.ready_in (threadgroup_readys_in[1]),
|
||||||
|
.stall (!ready_out),
|
||||||
|
.A_frag (A_tile_buf[3:2]),
|
||||||
|
.B_frag (B_tile_buf),
|
||||||
|
.C_frag (C_tile_buf[3:2]),
|
||||||
|
.valid_out (threadgroup_valids_out[1]),
|
||||||
|
.D_frag (D_tile[3:2])
|
||||||
|
);
|
||||||
|
|
||||||
|
`RUNTIME_ASSERT(&(threadgroup_valids_out) == |(threadgroup_valids_out),
|
||||||
|
("threadgroups went out of sync!"))
|
||||||
|
`RUNTIME_ASSERT(&(threadgroup_readys_in) == |(threadgroup_readys_in),
|
||||||
|
("threadgroups went out of sync!"))
|
||||||
|
|
||||||
|
wire wid_enq = valid_in && ready_in;
|
||||||
|
wire wid_deq = valid_out && ready_out;
|
||||||
|
|
||||||
|
assign valid_out = &(threadgroup_valids_out);
|
||||||
|
|
||||||
// need to pass along warp id's to do multithreading
|
// need to pass along warp id's to do multithreading
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW ($bits(wid)),
|
.DATAW ($bits(wid)),
|
||||||
// @perf: seems to require deeper depth than the FEDP issue queues to
|
// @perf: seems to require deeper depth than the FEDP issue queues to
|
||||||
// not cause stalls.
|
// not cause stalls.
|
||||||
.DEPTH (2 * ISSUE_QUEUE_DEPTH)
|
.DEPTH (2 * OPERAND_BUFFER_DEPTH)
|
||||||
) wid_queue (
|
) wid_queue (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.push (enq),
|
.push (wid_enq),
|
||||||
.pop (deq),
|
.pop (wid_deq),
|
||||||
.data_in (wid),
|
.data_in (wid),
|
||||||
.data_out (D_wid),
|
.data_out (D_wid),
|
||||||
.empty (empty),
|
.empty (wid_empty),
|
||||||
`UNUSED_PIN(alm_empty),
|
`UNUSED_PIN(alm_empty),
|
||||||
.full (full),
|
.full (wid_full),
|
||||||
`UNUSED_PIN(alm_full),
|
`UNUSED_PIN(alm_full),
|
||||||
`UNUSED_PIN(size)
|
`UNUSED_PIN(size)
|
||||||
);
|
);
|
||||||
@@ -119,9 +165,9 @@ module VX_tensor_dpu #(
|
|||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
||||||
// matches Figure 10(b) of the paper.
|
// see Figure 10(b) of the paper.
|
||||||
module VX_tensor_threadgroup #(
|
module VX_tensor_threadgroup #(
|
||||||
parameter ISSUE_QUEUE_DEPTH
|
parameter OPERAND_BUFFER_DEPTH
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
@@ -136,69 +182,37 @@ module VX_tensor_threadgroup #(
|
|||||||
output valid_out,
|
output valid_out,
|
||||||
output [1:0][3:0][31:0] D_frag
|
output [1:0][3:0][31:0] D_frag
|
||||||
);
|
);
|
||||||
wire [1:0][1:0][31:0] A_frag_buf;
|
wire fedp_valid_in;
|
||||||
wire [1:0][3:0][31:0] B_frag_buf;
|
wire fedp_ready_in;
|
||||||
wire [1:0][3:0][31:0] C_frag_buf;
|
wire fedp_fire_in = fedp_valid_in && fedp_ready_in;
|
||||||
|
|
||||||
wire valid_buf;
|
|
||||||
wire ready_buf;
|
|
||||||
|
|
||||||
wire enq = valid_in && ready_in;
|
|
||||||
wire deq = valid_buf && ready_buf;
|
|
||||||
wire empty;
|
|
||||||
wire full;
|
|
||||||
assign ready_in = !full;
|
|
||||||
assign valid_buf = !empty;
|
|
||||||
|
|
||||||
// 'Issue queue' for the FEDP units.
|
|
||||||
// This exists to decouple the execution of the dot-product unit from
|
|
||||||
// the operand arrival. Operands from execute_if can arrive
|
|
||||||
// intermittently according to the frontend's behavior, and since the dpu
|
|
||||||
// can also stall for a fixed initiation latency, we need to decouple the
|
|
||||||
// two to efficiently feed the dpu.
|
|
||||||
//
|
|
||||||
// TODO: better queue design possible; e.g. B_frag is shared by two
|
|
||||||
// threadgroups, so we need only 1 queue per octet for B
|
|
||||||
VX_fifo_queue #(
|
|
||||||
.DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)),
|
|
||||||
.DEPTH (ISSUE_QUEUE_DEPTH)
|
|
||||||
) input_buffer (
|
|
||||||
.clk (clk),
|
|
||||||
.reset (reset),
|
|
||||||
.push (enq),
|
|
||||||
.pop (deq),
|
|
||||||
.data_in ({A_frag, B_frag, C_frag}),
|
|
||||||
.data_out ({A_frag_buf, B_frag_buf, C_frag_buf}),
|
|
||||||
.empty (empty),
|
|
||||||
`UNUSED_PIN(alm_empty),
|
|
||||||
.full (full),
|
|
||||||
`UNUSED_PIN(alm_full),
|
|
||||||
`UNUSED_PIN(size)
|
|
||||||
);
|
|
||||||
|
|
||||||
wire [3:0] fedp_valids;
|
wire [3:0] fedp_valids;
|
||||||
wire fedp_valid_out = &(fedp_valids);
|
wire fedp_valid_out = &(fedp_valids);
|
||||||
wire fedp_ready_out = !stall;
|
wire fedp_ready_out = !stall;
|
||||||
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
||||||
|
|
||||||
wire fedp_valid_in = valid_buf;
|
assign fedp_valid_in = valid_in;
|
||||||
wire fedp_ready_in = fedp_ready_out; // coupled
|
// coupled ready; backpressure immediately reaches input from output
|
||||||
wire fedp_fire_in = fedp_valid_in && fedp_ready_in;
|
assign fedp_ready_in = fedp_ready_out;
|
||||||
|
|
||||||
// 0: FEDP uses first half from input_buffer
|
// The dot product units take 2 cycles to finish computing A_frag * B_frag
|
||||||
// 1: FEDP uses last half and pops input_buffer
|
// + C_frag. step_in and step_out keeps track of which cycle they're at
|
||||||
|
// & when they have to pop from input queue and push to result queue.
|
||||||
|
//
|
||||||
|
// step_in == 0: FEDP uses first half from operand buffer
|
||||||
|
// step_in == 1: FEDP uses last half and pops from operand buffer
|
||||||
wire step_in;
|
wire step_in;
|
||||||
// 0: FEDP produces first half of D_frag
|
// step_out == 0: FEDP produces first half of D_frag
|
||||||
// 1: FEDP produces last half of D_frag and asserts valid_out
|
// step_out == 1: FEDP produces last half of D_frag and asserts valid_out
|
||||||
wire step_out;
|
wire step_out;
|
||||||
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
|
assign ready_in = fedp_fire_in && (step_in == 1'b1);
|
||||||
|
|
||||||
wire [3:0][31:0] D_reg;
|
wire [3:0][31:0] D_reg;
|
||||||
logic [3:0][31:0] D_reg_n;
|
logic [3:0][31:0] D_reg_n;
|
||||||
|
|
||||||
// Staging buffer that latches the D half-tile.
|
// staging buffer that latches the D half-tile
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N($bits(D_reg))
|
.DATAW($bits(D_reg))
|
||||||
) staging_d (
|
) staging_d (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -220,7 +234,7 @@ module VX_tensor_threadgroup #(
|
|||||||
|
|
||||||
// flip step_in/step_out on FEDP in/out fire, respectively
|
// flip step_in/step_out on FEDP in/out fire, respectively
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N(1)
|
.DATAW(1)
|
||||||
) staging_step_in (
|
) staging_step_in (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -229,7 +243,7 @@ module VX_tensor_threadgroup #(
|
|||||||
.q(step_in)
|
.q(step_in)
|
||||||
);
|
);
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.N(1)
|
.DATAW(1)
|
||||||
) staging_step_out (
|
) staging_step_out (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -259,15 +273,15 @@ module VX_tensor_threadgroup #(
|
|||||||
.clock (clk),
|
.clock (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.io_in_valid (fedp_fire_in),
|
.io_in_valid (fedp_fire_in),
|
||||||
.io_in_bits_a_0 (A_frag_buf[d_row][0]),
|
.io_in_bits_a_0 (A_frag[d_row][0]),
|
||||||
.io_in_bits_a_1 (A_frag_buf[d_row][1]),
|
.io_in_bits_a_1 (A_frag[d_row][1]),
|
||||||
.io_in_bits_a_2 (32'h0),
|
.io_in_bits_a_2 (32'h0),
|
||||||
.io_in_bits_a_3 (32'h0),
|
.io_in_bits_a_3 (32'h0),
|
||||||
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]),
|
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]),
|
||||||
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]),
|
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]),
|
||||||
.io_in_bits_b_2 (32'h0),
|
.io_in_bits_b_2 (32'h0),
|
||||||
.io_in_bits_b_3 (32'h0),
|
.io_in_bits_b_3 (32'h0),
|
||||||
.io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]),
|
.io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]),
|
||||||
.io_stall (stall),
|
.io_stall (stall),
|
||||||
.io_out_valid (fedp_valids[i]),
|
.io_out_valid (fedp_valids[i]),
|
||||||
.io_out_bits_data (D_half[i])
|
.io_out_bits_data (D_half[i])
|
||||||
|
|||||||
Reference in New Issue
Block a user