diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index 5ef71794..65d56e8a 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -391,7 +391,7 @@ // Tensor Core Latency `ifndef LATENCY_HMMA -`define LATENCY_HMMA 2 +`define LATENCY_HMMA 8 `endif // Icache Configurable Knobs ////////////////////////////////////////////////// diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 185218fc..29bfb98c 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -326,8 +326,10 @@ module VX_tensor_octet #( end wire hmma_ready; - wire stall = result_valid && ~result_ready; + wire outbuf_ready_in; // connected to output_buffer.ready_in; low means stall + // previously: wire stall = result_valid && ~result_ready; // backpressure from commit + wire stall = ~outbuf_ready_in; assign operands_ready = ~stall; // TODO: Below line is to only allow 1 warp to occupy the octet at a time; // currently, dpu is fully-pipelined and allows concurrency between @@ -349,6 +351,7 @@ module VX_tensor_octet #( }; // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; + logic [3:0][3:0][31:0] D_tile; always @(*) begin C_tile = { @@ -360,6 +363,7 @@ module VX_tensor_octet #( end wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); + wire dpu_valid; // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet VX_tensor_dpu #( @@ -377,8 +381,24 @@ module VX_tensor_octet #( .B_tile(B_tile), .C_tile(C_tile), - .valid_out(result_valid), - .D_tile(D_out) + .valid_out(dpu_valid), + .D_tile(D_tile) + ); + + // buffer to stage the result tile for 2 cycles until commit/writeback is + // complete + VX_stream_buffer #( + .DATAW ($bits(D_out)), + .OUT_REG (1) // NOTE(review): unclear whether OUT_REG is required here - confirm + ) output_buffer ( + .clk (clk), + .reset (reset), + .valid_in (dpu_valid), + .ready_in (outbuf_ready_in), + .data_in (D_tile), + .data_out (D_out), + .ready_out (result_ready), + .valid_out (result_valid) ); endmodule `endif diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index
63d35ae7..4130fb98 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -40,6 +40,7 @@ module VX_tensor_dpu #( // ready as soon as valid_out assign ready_in = ready_reg || valid_out; + // fixed-latency model VX_shift_register #( .DATAW (1 + $bits(D_tile)), .DEPTH (`LATENCY_HMMA),