From 9f9ec109604ad6d21c366015d74538cd318c987a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 8 May 2024 11:26:09 -0700 Subject: [PATCH] tensor: Enable scaling NUM_THREADS by octets todo: lane-to-octet mapping is arbitrary atm --- hw/rtl/core/VX_tensor_core.sv | 38 +++++++++++---------- hw/rtl/core/VX_tensor_ucode_8lanes.vh | 49 +++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 hw/rtl/core/VX_tensor_ucode_8lanes.vh diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 9971d619..71ed8538 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -10,8 +10,6 @@ module VX_tensor_core #( VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], VX_commit_if.master commit_if [`ISSUE_WIDTH] ); - `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")")); - for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin VX_tensor_core_warp #( .ISW(i) @@ -35,29 +33,35 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( VX_dispatch_if.slave dispatch_if, VX_commit_if.master commit_if ); + localparam NUM_OCTETS = (`NUM_THREADS / 8); + // offet in the lane numbers that get mapped to the two threadgroups in an + // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16 + // FIXME: not sure this is the right logic. just filling in what works + localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); + wire [1:0] step = 2'(dispatch_if.data.op_type); - logic [3:0] octet_results_valid; - logic [3:0] octet_results_ready; - logic [3:0] octet_operands_ready; + logic [NUM_OCTETS-1:0] octet_results_valid; + logic [NUM_OCTETS-1:0] octet_results_ready; + logic [NUM_OCTETS-1:0] octet_operands_ready; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; assign dispatch_if.ready = &octet_operands_ready; `ifdef EXT_T_ENABLE - for (genvar i = 0; i < 4/*octets*/; ++i) begin + for (genvar i = 0; i < NUM_OCTETS; ++i) begin `else for (genvar i = 0; i < 0; ++i) begin `endif // lane-to-octet mapping; see figure 13 of the paper wire [7:0][31:0] octet_A = { - dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] + dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] }; wire [7:0][31:0] octet_B = { - dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] + dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] }; wire [7:0][31:0] octet_C = { - dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] + dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] }; logic [3:0][3:0][31:0] octet_D; @@ -100,15 +104,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( assign wb_data_1[4*i+2] = octet_D[0][3]; assign wb_data_1[4*i+3] = octet_D[1][3]; - assign wb_data_0[4*i+16+0] = octet_D[2][0]; - assign wb_data_0[4*i+16+1] = octet_D[3][0]; - assign wb_data_0[4*i+16+2] = octet_D[2][2]; - assign wb_data_0[4*i+16+3] = octet_D[3][2]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2]; - assign wb_data_1[4*i+16+0] = octet_D[2][1]; - assign wb_data_1[4*i+16+1] = octet_D[3][1]; - assign wb_data_1[4*i+16+2] = octet_D[2][3]; - assign wb_data_1[4*i+16+3] = octet_D[3][3]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3]; end /* commit_if.data_t parts that we need to keep around: diff --git a/hw/rtl/core/VX_tensor_ucode_8lanes.vh b/hw/rtl/core/VX_tensor_ucode_8lanes.vh new file mode 100644 index 00000000..41ec857e --- /dev/null +++ b/hw/rtl/core/VX_tensor_ucode_8lanes.vh @@ -0,0 +1,49 @@ +// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3 +HMMA_SET0_STEP0_0: begin + uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)}; +end +HMMA_SET0_STEP0_1: begin + uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)}; +end +HMMA_SET0_STEP1_0: begin + uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)}; +end +HMMA_SET0_STEP1_1: begin + uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)}; +end +HMMA_SET0_STEP2_0: begin + uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)}; +end +HMMA_SET0_STEP2_1: begin + uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)}; +end +HMMA_SET0_STEP3_0: begin + uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)}; +end +HMMA_SET0_STEP3_1: begin + uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)}; +end +HMMA_SET1_STEP0_0: begin + uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)}; +end +HMMA_SET1_STEP0_1: begin + uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)}; +end +HMMA_SET1_STEP1_0: begin + uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)}; +end +HMMA_SET1_STEP1_1: begin + uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)}; +end +HMMA_SET1_STEP2_0: begin + uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)}; +end +HMMA_SET1_STEP2_1: begin + uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)}; +end +HMMA_SET1_STEP3_0: begin + uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)}; +end +HMMA_SET1_STEP3_1: begin + uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)}; +end