tensor: Add output buffer to hide 2-cycle commit latency

Operands arrive and results commit at the same rate (one every 2 cycles),
so it is unnecessary to stall the dpu during the multi-cycle commit.
Staging the result tile in a buffer at the dpu output enables the dpu to
operate at its full throughput of 1 operand every 2 cycles.
This commit is contained in:
Hansung Kim
2024-05-16 20:07:30 -07:00
parent 317695a8d0
commit 5034d8d14b
3 changed files with 25 additions and 4 deletions

View File

@@ -391,7 +391,7 @@
// Tensor Core Latency
`ifndef LATENCY_HMMA
`define LATENCY_HMMA 2
`define LATENCY_HMMA 8
`endif
// Icache Configurable Knobs //////////////////////////////////////////////////

View File

@@ -326,8 +326,10 @@ module VX_tensor_octet #(
end
wire hmma_ready;
wire stall = result_valid && ~result_ready;
wire outbuf_ready_in;
// wire stall = result_valid && ~result_ready;
// backpressure from commit
wire stall = ~outbuf_ready_in;
assign operands_ready = ~stall;
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
// currently, dpu is fully-pipelined and allows concurrency between
@@ -349,6 +351,7 @@ module VX_tensor_octet #(
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile;
always @(*) begin
C_tile = {
@@ -360,6 +363,7 @@ module VX_tensor_octet #(
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
wire dpu_valid;
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
VX_tensor_dpu #(
@@ -377,8 +381,24 @@ module VX_tensor_octet #(
.B_tile(B_tile),
.C_tile(C_tile),
.valid_out(result_valid),
.D_tile(D_out)
.valid_out(dpu_valid),
.D_tile(D_tile)
);
// Output staging buffer between the dpu and the commit/writeback interface.
// It holds the dpu result tile (D_tile) for the 2 cycles that commit/writeback
// takes to complete, so the dpu itself does not have to be stalled while a
// result is being committed.
//
// Handshake wiring:
//   - input side : dpu_valid / outbuf_ready_in, data = D_tile
//   - output side: result_valid / result_ready, data = D_out
// outbuf_ready_in is the only source of backpressure seen upstream: the
// octet's `stall` signal is defined as ~outbuf_ready_in.
VX_stream_buffer #(
// payload width is taken from the output tile so it tracks D_out's type
.DATAW ($bits(D_out)),
.OUT_REG (1) // NOTE(review): original author was not sure this is necessary — confirm
) output_buffer (
.clk (clk),
.reset (reset),
.valid_in (dpu_valid),
.ready_in (outbuf_ready_in),
.data_in (D_tile),
.data_out (D_out),
.ready_out (result_ready),
.valid_out (result_valid)
);
endmodule
`endif

View File

@@ -40,6 +40,7 @@ module VX_tensor_dpu #(
// ready as soon as valid_out
assign ready_in = ready_reg || valid_out;
// fixed-latency model
VX_shift_register #(
.DATAW (1 + $bits(D_tile)),
.DEPTH (`LATENCY_HMMA),