From d4d18c28231ee43e118e1bb5a361938b53a2cfcb Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 27 Jul 2024 20:53:56 -0700 Subject: [PATCH] tensor: spurious assert, doc, remove unused param --- hw/rtl/core/VX_tensor_core.sv | 25 +++++++++++++++---------- hw/rtl/fpu/VX_tensor_dpu.sv | 11 +++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index a1f4c937..d52150ac 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -324,9 +324,6 @@ endmodule module VX_tensor_octet #( parameter ISW, parameter OCTET, - // RESULT_BUFFER_DEPTH = 2 gives good performance by absorbing commit - // backpressure (result_ready), although the value is arbitrary. - // RESULT_BUFFER_DEPTH = 0 eliminates result buffering. parameter RESULT_BUFFER_DEPTH = 2 ) ( input clk, @@ -385,8 +382,11 @@ module VX_tensor_octet #( assign operands_ready = operands_ready_buf; typedef struct { + // single column of A logic [3:0][31:0] A_half; + // single row of B logic [3:0][31:0] B_half; + // interleaved elements of C logic [7:0][31:0] C_half; } half_t; @@ -477,18 +477,20 @@ module VX_tensor_octet #( wire hmma_ready; assign operands_ready_buf = hmma_ready; - // A is 4x2 fp32 matrix + // all *_tiles below are row-major + // A is a 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { { halves_buf.A_half[3], A_buffer[operands_wid_buf][3] }, { halves_buf.A_half[2], A_buffer[operands_wid_buf][2] }, { halves_buf.A_half[1], A_buffer[operands_wid_buf][1] }, { halves_buf.A_half[0], A_buffer[operands_wid_buf][0] } }; - // B is 2x4 fp32 matrix + // B is a 2x4 fp32 matrix wire [1:0][3:0][31:0] B_tile = { - halves_buf.B_half, B_buffer[operands_wid_buf] + halves_buf.B_half, + B_buffer[operands_wid_buf] }; - // C is 4x4 fp32 matrix + // C is a 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; wire [3:0][3:0][31:0] D_tile; wire [`NW_WIDTH-1:0] D_wid_dpu; @@ -538,7 +540,10 @@ module VX_tensor_octet #( // 
commit/writeback is complete. This decouples the irregular dpu // output traffic from the regular, every-2-cycle commit traffic to // ensure the commit pipeline is used more efficiently. - // FIXME: unnecessary? + // + // @perf: RESULT_BUFFER_DEPTH == 2 gives good performance by + // completely dampening commit backpressure (result_ready). + // RESULT_BUFFER_DEPTH = 0 removes the fifo queue altogether. VX_fifo_queue #( .DATAW ($bits(D_wid) + $bits(D_out)), .DEPTH (RESULT_BUFFER_DEPTH) // 2 works good @@ -556,8 +561,8 @@ module VX_tensor_octet #( `UNUSED_PIN(size) ); - // FIXME: overly strict; this firing doesn't mean a bug - `RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!")) + // for perf debug + // `RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!")) end else begin // XXX: this depends on the assumption that commit stage only asserts // result_ready when result_valid is true diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index aabe0105..c196ffec 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -60,12 +60,11 @@ module VX_tensor_dpu #( wire empty; wire full; - // sync between operand buffer and wid buffer + // sync operand buffer and wid buffer assign ready_in = !full && !wid_full; wire [1:0] threadgroup_valids_out; wire [1:0] threadgroup_readys_in; - // sync operand queue and wid queue wire threadgroup_valid_in = !empty; wire threadgroup_fire_in = threadgroup_valid_in && &(threadgroup_readys_in); @@ -98,11 +97,9 @@ module VX_tensor_dpu #( ); // Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two - // threadgroups - // - // B_tile is shared across the two threadgroups; see Figure 13 + // threadgroups; B_tile is shared across the two threadgroups. 
See Figure + // 13 in the paper VX_tensor_threadgroup #( - .OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH) ) threadgroup_0 ( .clk (clk), .reset (reset), @@ -116,7 +113,6 @@ module VX_tensor_dpu #( .D_frag (D_tile[1:0]) ); VX_tensor_threadgroup #( - .OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH) ) threadgroup_1 ( .clk (clk), .reset (reset), @@ -167,7 +163,6 @@ endmodule // does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // see Figure 10(b) of the paper. module VX_tensor_threadgroup #( - parameter OPERAND_BUFFER_DEPTH ) ( input clk, input reset,