diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index 29ca22df..6a810555 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -347,7 +347,7 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s // A is M * K, B is K * M, C is M * M, D is M * M #define M 4 -#define K 2 +#define K 2 // FIXME: 4x4x1 / cycle / octet! // all row major float c_A_tile[M][K]; @@ -551,7 +551,7 @@ void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBi } steps[wid] += 1; - if (steps[wid] % 64 == 0) { + if (steps[wid] % 32 == 0) { steps[wid] = 0; std::cout << "warp " << wid << " finished wmma\n"; std::cout << "A tile" << "\n"; diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index 65d56e8a..5ef71794 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -391,7 +391,7 @@ // Tensor Core Latency `ifndef LATENCY_HMMA -`define LATENCY_HMMA 8 +`define LATENCY_HMMA 2 `endif // Icache Configurable Knobs ////////////////////////////////////////////////// diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 14d8175b..185218fc 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -32,6 +32,10 @@ module VX_tensor_core import VX_gpu_pkg::*; #( .execute_if (execute_if) ); + // FIXME: when multiple warps are running, step0_0 from multiple warps can + // get interleaved before the first warp advances to step0_1, fucking + // everything up + VX_commit_if #( .NUM_LANES (NUM_LANES) ) commit_block_if[BLOCK_SIZE](); @@ -175,6 +179,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( execute_if.data.PC, execute_if.data.wb, execute_if.data.rd + // pid/sop/eop set later }; wire [DATAW-1:0] execute_if_data_deq; @@ -320,8 +325,16 @@ module VX_tensor_octet #( end end + wire hmma_ready; wire stall = result_valid && ~result_ready; + // backpressure from commit assign operands_ready = ~stall; + // TODO: Below line is to only allow 1 warp to occupy the octet at a time; + // currently, dpu is fully-pipelined and allows concurrency between + // multiple warps. This seems to be not a problem though given that the + // RF operand read takes >=2 cycles, which should be the end-to-end + // latency of the DPU anyways + // assign operands_ready = hmma_ready && ~stall; // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { @@ -359,6 +372,7 @@ module VX_tensor_octet #( .stall(stall), .valid_in(do_hmma), + .ready_in(hmma_ready), .A_tile(A_tile), .B_tile(B_tile), .C_tile(C_tile), diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index cfc5f507..63d35ae7 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -11,6 +11,7 @@ module VX_tensor_dpu #( input stall, input valid_in, + output ready_in, input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, input [3:0][3:0][31:0] C_tile, @@ -24,12 +25,20 @@ module VX_tensor_dpu #( dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); end + logic ready_reg; always @(posedge clk) begin - if (~reset && valid_in) begin + if (reset) begin + ready_reg <= '1; + end else if (valid_in) begin + ready_reg <= '0; dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma); + end else if (valid_out) begin + ready_reg <= '1; end end - + + // ready as soon as valid_out + assign ready_in = ready_reg || valid_out; VX_shift_register #( .DATAW (1 + $bits(D_tile)),