diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index e37f5016..b6b11754 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -83,6 +83,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); wire [1:0] step = 2'(execute_if.data.op_type); + wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); + logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_operands_ready; @@ -111,6 +113,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( logic [3:0][3:0][31:0] octet_D; logic result_valid; logic result_ready; + + // op_mod is reused to indicate instruction's id in pair VX_tensor_octet #( .ISW(ISW), .OCTET(i) @@ -122,6 +126,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( .B_in(octet_B), .C_in(octet_C), .operands_valid(execute_if.valid), + .operands_wid(execute_if.data.wid), + .operands_last_in_pair(operands_last_in_pair), .operands_ready(octet_operands_ready[i]), .step(step), @@ -245,11 +251,14 @@ module VX_tensor_octet #( input clk, input reset, - input [7:0][31:0] A_in, - input [7:0][31:0] B_in, - input [7:0][31:0] C_in, - input operands_valid, // we have to backpressure due to there potentially being contention over commit - output operands_ready, + input [7:0][31:0] A_in, + input [7:0][31:0] B_in, + input [7:0][31:0] C_in, + input operands_valid, + input [`NW_WIDTH-1:0] operands_wid, + input operands_last_in_pair, + // we have to backpressure due to there potentially being contention over commit + output operands_ready, input [1:0] step, @@ -258,9 +267,9 @@ module VX_tensor_octet #( input result_ready ); // 512 bits/octet * 4 octets per warp - logic [3:0][31:0] A_buffer, A_buffer_n; - logic [3:0][31:0] B_buffer, B_buffer_n; - logic [7:0][31:0] C_buffer, C_buffer_n; + logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n; + logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n; + logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n; // half the inputs are buffered, half are not (instead coming straight // from operand bus) unlike the real tensor core. @@ -268,6 +277,10 @@ module VX_tensor_octet #( logic [3:0][31:0] A_half; logic [3:0][31:0] B_half; logic [7:0][31:0] C_half; + + logic [`NUM_WARPS-1:0] substeps; + logic [`NUM_WARPS-1:0] substeps_n; + always @(*) begin // note that not all lanes participate at every step case (step) @@ -296,18 +309,29 @@ module VX_tensor_octet #( end logic substep; - wire substep_n = (operands_ready && operands_valid) ? ~substep : substep; + wire operands_fire = operands_ready && operands_valid; + wire substep_n = operands_fire && operands_last_in_pair; always @(*) begin A_buffer_n = A_buffer; B_buffer_n = B_buffer; C_buffer_n = C_buffer; + substeps_n = substeps; - if (substep == 1'b0) begin - A_buffer_n = A_half; - B_buffer_n = B_half; - C_buffer_n = C_half; + if (operands_fire) begin + substeps_n[operands_wid] = ~substeps[operands_wid]; + if (!operands_last_in_pair) begin + A_buffer_n[operands_wid] = A_half; + B_buffer_n[operands_wid] = B_half; + C_buffer_n[operands_wid] = C_half; + end end + + // if (operands_fire && (substep == 1'b0)) begin + // A_buffer_n[operands_wid] = A_half; + // B_buffer_n[operands_wid] = B_half; + // C_buffer_n[operands_wid] = C_half; + // end end always @(posedge clk) begin @@ -315,13 +339,17 @@ module VX_tensor_octet #( A_buffer <= '0; B_buffer <= '0; C_buffer <= '0; + substep <= '0; + substeps <= '0; end else begin A_buffer <= A_buffer_n; B_buffer <= B_buffer_n; C_buffer <= C_buffer_n; + substep <= substep_n; + substeps <= substeps_n; end end @@ -330,39 +358,38 @@ module VX_tensor_octet #( // wire stall = result_valid && ~result_ready; // backpressure from commit wire stall = ~outbuf_ready_in; - assign operands_ready = ~stall; + // assign operands_ready = ~stall; // TODO: Below line is to only allow 1 warp to occupy the octet at a time; // currently, dpu is fully-pipelined and allows concurrency between // multiple warps. This seems to be not a problem though given that the // RF operand read takes >=2 cycles, which should be the end-to-end // latency of the DPU anyways - // assign operands_ready = hmma_ready && ~stall; + assign operands_ready = hmma_ready && ~stall; // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { - { A_half[3], A_buffer[3] }, - { A_half[2], A_buffer[2] }, - { A_half[1], A_buffer[1] }, - { A_half[0], A_buffer[0] } + { A_half[3], A_buffer[operands_wid][3] }, + { A_half[2], A_buffer[operands_wid][2] }, + { A_half[1], A_buffer[operands_wid][1] }, + { A_half[0], A_buffer[operands_wid][0] } }; // B is 2x4 fp32 matrix wire [1:0][3:0][31:0] B_tile = { - B_half, B_buffer + B_half, B_buffer[operands_wid] }; // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] D_tile; always @(*) begin - C_tile = { - C_half[7], C_buffer[7], C_half[5], C_buffer[5], - C_half[6], C_buffer[6], C_half[4], C_buffer[4], - C_half[3], C_buffer[3], C_half[1], C_buffer[1], - C_half[2], C_buffer[2], C_half[0], C_buffer[0] - }; + C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] }; + C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] }; + C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] }; + C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] }; end - wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); + // wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1); + wire do_hmma = operands_fire && operands_last_in_pair; wire dpu_valid; // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet