bleh still not work
This commit is contained in:
@@ -52,6 +52,7 @@ module VX_commit import VX_gpu_pkg::*; #(
|
|||||||
wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid;
|
wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid;
|
||||||
wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask;
|
wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask;
|
||||||
wire [`ISSUE_WIDTH-1:0] commit_eop;
|
wire [`ISSUE_WIDTH-1:0] commit_eop;
|
||||||
|
wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel;
|
||||||
|
|
||||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ module VX_commit import VX_gpu_pkg::*; #(
|
|||||||
.data_out (commit_if[i].data),
|
.data_out (commit_if[i].data),
|
||||||
.valid_out (commit_if[i].valid),
|
.valid_out (commit_if[i].valid),
|
||||||
.ready_out (commit_if[i].ready),
|
.ready_out (commit_if[i].ready),
|
||||||
`UNUSED_PIN (sel_out)
|
.sel_out (commit_sel[i])
|
||||||
);
|
);
|
||||||
|
|
||||||
assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready;
|
assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready;
|
||||||
@@ -171,10 +172,24 @@ module VX_commit import VX_gpu_pkg::*; #(
|
|||||||
// Committed instructions
|
// Committed instructions
|
||||||
|
|
||||||
// temporary hack to not underflow the pending instructions buffer
|
// temporary hack to not underflow the pending instructions buffer
|
||||||
|
// relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions,
|
||||||
|
// so probably want to change this at some point
|
||||||
|
// (i.e. pass a "don't count this towards pending instructions" signal down the pipeline)
|
||||||
|
logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n;
|
||||||
wire [`ISSUE_WIDTH-1:0] final_hmma;
|
wire [`ISSUE_WIDTH-1:0] final_hmma;
|
||||||
`ifdef EXT_T_ENABLE
|
`ifdef EXT_T_ENABLE
|
||||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||||
assign final_hmma[i] = ~(tensor_commit_if[i].ready && tensor_commit_if[i].valid) || (tensor_commit_if[i].data.rd == `NR_BITS'(32 + 23));
|
assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i];
|
||||||
|
assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0);
|
||||||
|
end
|
||||||
|
|
||||||
|
always @(posedge clk) begin
|
||||||
|
if (reset) begin
|
||||||
|
hmma_ctr <= '0;
|
||||||
|
end
|
||||||
|
else begin
|
||||||
|
hmma_ctr <= hmma_ctr_n;
|
||||||
|
end
|
||||||
end
|
end
|
||||||
`else
|
`else
|
||||||
assign final_hmma = '1;
|
assign final_hmma = '1;
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
assign decode_if.ready = ibuf_ready_in[decode_isw];
|
assign decode_if.ready = ibuf_ready_in[decode_isw];
|
||||||
|
|
||||||
VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH];
|
VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]();
|
||||||
|
|
||||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||||
VX_elastic_buffer #(
|
VX_elastic_buffer #(
|
||||||
|
|||||||
@@ -288,6 +288,27 @@ module VX_operands import VX_gpu_pkg::*; #(
|
|||||||
.raddr (gpr_rd_addr),
|
.raddr (gpr_rd_addr),
|
||||||
.rdata (gpr_rd_data[j])
|
.rdata (gpr_rd_data[j])
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// blast read register file because printf is slowge
|
||||||
|
logic [31:0] cycle, cycle_n;
|
||||||
|
assign cycle_n = cycle + 32'd1;
|
||||||
|
always @(posedge clk) begin
|
||||||
|
if (reset) begin
|
||||||
|
cycle <= '0;
|
||||||
|
end
|
||||||
|
else begin
|
||||||
|
cycle <= cycle_n;
|
||||||
|
end
|
||||||
|
|
||||||
|
if (cycle == 32'd25000) begin
|
||||||
|
for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin
|
||||||
|
integer warp = i * ISSUE_RATIO + (k / `NUM_REGS);
|
||||||
|
integer thread = j;
|
||||||
|
integer register = k % `NUM_REGS;
|
||||||
|
$display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]);
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ module VX_tensor_core #(
|
|||||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||||
);
|
);
|
||||||
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32"));
|
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
|
||||||
|
|
||||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||||
VX_tensor_core_warp #(
|
VX_tensor_core_warp #(
|
||||||
@@ -34,20 +34,20 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
VX_dispatch_if.slave dispatch_if,
|
VX_dispatch_if.slave dispatch_if,
|
||||||
VX_commit_if.master commit_if
|
VX_commit_if.master commit_if
|
||||||
);
|
);
|
||||||
logic [1:0] step = 2'(dispatch_if.data.op_type);
|
wire [1:0] step = 2'(dispatch_if.data.op_type);
|
||||||
logic [3:0] octet_results_valid;
|
logic [3:0] octet_results_valid;
|
||||||
logic [3:0] octet_results_ready;
|
logic [3:0] octet_results_ready;
|
||||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||||
|
|
||||||
for (genvar i = 0; i < 4; ++i) begin
|
for (genvar i = 0; i < 4; ++i) begin
|
||||||
logic [7:0][31:0] octet_A = {
|
wire [7:0][31:0] octet_A = {
|
||||||
dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4]
|
dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4]
|
||||||
};
|
};
|
||||||
logic [7:0][31:0] octet_B = {
|
wire [7:0][31:0] octet_B = {
|
||||||
dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4]
|
dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4]
|
||||||
};
|
};
|
||||||
logic [7:0][31:0] octet_C = {
|
wire [7:0][31:0] octet_C = {
|
||||||
dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4]
|
dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4]
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -141,11 +141,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
);
|
);
|
||||||
|
|
||||||
logic subcommit, subcommit_n;
|
logic subcommit, subcommit_n;
|
||||||
logic all_valid = (& octet_results_valid);
|
wire all_valid = (& octet_results_valid);
|
||||||
assign commit_if.valid = all_valid;
|
assign commit_if.valid = all_valid;
|
||||||
|
|
||||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||||
logic [COMMIT_DATAW-1:0] commit_if_data = {
|
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||||
dispatch_if_data_deq,
|
dispatch_if_data_deq,
|
||||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
|
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
|
||||||
1'b0,
|
1'b0,
|
||||||
@@ -227,7 +227,7 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
logic substep;
|
logic substep;
|
||||||
logic substep_n = (operands_ready && operands_valid) ? ~substep : substep;
|
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
A_buffer_n = A_buffer;
|
A_buffer_n = A_buffer;
|
||||||
@@ -260,13 +260,13 @@ module VX_tensor_octet #(
|
|||||||
wire stall = result_valid && ~result_ready;
|
wire stall = result_valid && ~result_ready;
|
||||||
assign operands_ready = ~stall;
|
assign operands_ready = ~stall;
|
||||||
|
|
||||||
logic [3:0][1:0][31:0] A_tile = {
|
wire [3:0][1:0][31:0] A_tile = {
|
||||||
{ A_buffer[0], A_half[0] },
|
{ A_buffer[0], A_half[0] },
|
||||||
{ A_buffer[1], A_half[1] },
|
{ A_buffer[1], A_half[1] },
|
||||||
{ A_buffer[2], A_half[2] },
|
{ A_buffer[2], A_half[2] },
|
||||||
{ A_buffer[3], A_half[3] }
|
{ A_buffer[3], A_half[3] }
|
||||||
};
|
};
|
||||||
logic [1:0][3:0][31:0] B_tile = {
|
wire [1:0][3:0][31:0] B_tile = {
|
||||||
B_buffer, B_half
|
B_buffer, B_half
|
||||||
};
|
};
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
|
|||||||
96
hw/rtl/core/VX_tensor_ucode.vh
Normal file
96
hw/rtl/core/VX_tensor_ucode.vh
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
HMMA_SET0_STEP0_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP0_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP1_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP1_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP2_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP2_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP3_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
|
||||||
|
end
|
||||||
|
HMMA_SET0_STEP3_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP0_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP0_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP1_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP1_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP2_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP2_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP3_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
|
||||||
|
end
|
||||||
|
HMMA_SET1_STEP3_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP0_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(4), `FREG(12), `FREG(16)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP0_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(5), `FREG(13), `FREG(17)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP1_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(4), `FREG(12), `FREG(18)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP1_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(5), `FREG(13), `FREG(19)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP2_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(4), `FREG(12), `FREG(20)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP2_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(5), `FREG(13), `FREG(21)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP3_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET2_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(4), `FREG(12), `FREG(22)};
|
||||||
|
end
|
||||||
|
HMMA_SET2_STEP3_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(5), `FREG(13), `FREG(23)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP0_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(6), `FREG(14), `FREG(16)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP0_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(7), `FREG(15), `FREG(17)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP1_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(6), `FREG(14), `FREG(18)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP1_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(7), `FREG(15), `FREG(19)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP2_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(6), `FREG(14), `FREG(20)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP2_1: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(7), `FREG(15), `FREG(21)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP3_0: begin
|
||||||
|
uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)};
|
||||||
|
end
|
||||||
|
HMMA_SET3_STEP3_1: begin
|
||||||
|
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(7), `FREG(15), `FREG(23)};
|
||||||
|
end
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
`include "VX_define.vh"
|
`include "VX_define.vh"
|
||||||
|
|
||||||
`define FREG(x) {1'b1, `NRI_BITS'(`CLOG2(x))}
|
`define FREG(x) {1'b1, `NRI_BITS'(x)}
|
||||||
|
|
||||||
module VX_uop_sequencer import VX_gpu_pkg::*; (
|
module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||||
input clk,
|
input clk,
|
||||||
@@ -28,7 +28,6 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
|||||||
// reserve space at start of table for more uop sequences
|
// reserve space at start of table for more uop sequences
|
||||||
localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0);
|
localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0);
|
||||||
localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8);
|
localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8);
|
||||||
/*
|
|
||||||
localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9);
|
localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9);
|
||||||
localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10);
|
localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10);
|
||||||
localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11);
|
localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11);
|
||||||
@@ -62,49 +61,11 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
|||||||
localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36);
|
localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36);
|
||||||
localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37);
|
localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37);
|
||||||
localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38);
|
localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38);
|
||||||
*/
|
|
||||||
// register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C
|
// register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
case (upc)
|
case (upc)
|
||||||
HMMA_SET0_STEP0_0: begin
|
`include "VX_tensor_ucode.vh"
|
||||||
uop = {
|
|
||||||
NEXT,
|
|
||||||
HMMA_SET0_STEP0_1,
|
|
||||||
`EX_BITS'(`EX_TENSOR),
|
|
||||||
`INST_OP_BITS'(0), // denotes that the first step is being computed
|
|
||||||
`INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this)
|
|
||||||
1'b1, // write back
|
|
||||||
1'b0, // don't use PC
|
|
||||||
1'b0, // don't use immediate
|
|
||||||
32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace
|
|
||||||
32'b0, // immediate is unused
|
|
||||||
`FREG(16), // rd=f16
|
|
||||||
`FREG(0), // rs1=f0,
|
|
||||||
`FREG(8), // rs2=f8
|
|
||||||
`FREG(16) // rs3=f16
|
|
||||||
};
|
|
||||||
end
|
|
||||||
HMMA_SET0_STEP0_1: begin
|
|
||||||
uop = {
|
|
||||||
FINISH,
|
|
||||||
HMMA_SET0_STEP0_0,
|
|
||||||
`EX_BITS'(`EX_TENSOR),
|
|
||||||
`INST_OP_BITS'(0), // denotes that the first step is being computed
|
|
||||||
`INST_MOD_BITS'(1), // denotes that this is first substep (tensor core also tracks this)
|
|
||||||
1'b1, // write back
|
|
||||||
1'b0, // don't use PC
|
|
||||||
1'b0, // don't use immediate
|
|
||||||
32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace
|
|
||||||
32'b0, // immediate is unused
|
|
||||||
`FREG(17), // rd=f17
|
|
||||||
`FREG(1), // rs1=f1,
|
|
||||||
`FREG(9), // rs2=f9
|
|
||||||
`FREG(17) // rs3=f17
|
|
||||||
};
|
|
||||||
end
|
|
||||||
default: begin
|
default: begin
|
||||||
uop = '0;
|
uop = '0;
|
||||||
end
|
end
|
||||||
@@ -113,13 +74,15 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
|||||||
|
|
||||||
logic [UPC_BITS-1:0] upc, upc_r, upc_n;
|
logic [UPC_BITS-1:0] upc, upc_r, upc_n;
|
||||||
|
|
||||||
logic [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS];
|
wire [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS];
|
||||||
logic [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS];
|
wire [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS];
|
||||||
|
|
||||||
logic uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready;
|
|
||||||
logic uop_start = ~use_uop_1d && use_uop;
|
|
||||||
logic uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready;
|
|
||||||
logic use_uop, use_uop_1d;
|
logic use_uop, use_uop_1d;
|
||||||
|
wire uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready;
|
||||||
|
|
||||||
|
wire uop_start = ~use_uop_1d && use_uop;
|
||||||
|
wire uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready;
|
||||||
|
|
||||||
|
|
||||||
// merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint, but conceptually they should be linked
|
// merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint, but conceptually they should be linked
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
@@ -149,7 +112,7 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
|||||||
end
|
end
|
||||||
|
|
||||||
// copy UUID, wis, tmask from microcoded instruction
|
// copy UUID, wis, tmask from microcoded instruction
|
||||||
logic [IBUFFER_IF_DATAW-1:0] ibuffer_output = {
|
wire [IBUFFER_IF_DATAW-1:0] ibuffer_output = {
|
||||||
uop_sequencer_if.data.uuid,
|
uop_sequencer_if.data.uuid,
|
||||||
uop_sequencer_if.data.wis,
|
uop_sequencer_if.data.wis,
|
||||||
uop_sequencer_if.data.tmask,
|
uop_sequencer_if.data.tmask,
|
||||||
@@ -161,11 +124,18 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
|
|||||||
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
||||||
|
|
||||||
always @(posedge clk) begin
|
always @(posedge clk) begin
|
||||||
|
if (uop_start) begin
|
||||||
if (use_uop) begin
|
$display("UOP start @ %t", $time);
|
||||||
$display("unexpectedly used uop at %d", $time);
|
$display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready);
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if (uop_fire) begin
|
||||||
|
$display("UOP fire @ %t", $time);
|
||||||
|
end
|
||||||
|
|
||||||
|
if (uop_finish) begin
|
||||||
|
$display("UOP finish @ %t", $time);
|
||||||
|
end
|
||||||
|
|
||||||
if (reset) begin
|
if (reset) begin
|
||||||
upc_r <= '0;
|
upc_r <= '0;
|
||||||
|
|||||||
81
hw/rtl/core/generate_ucode.py
Normal file
81
hw/rtl/core/generate_ucode.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
num_sets = 4
|
||||||
|
num_steps = 4
|
||||||
|
num_substeps = 2
|
||||||
|
|
||||||
|
|
||||||
|
def set_step_substep(sequence_number):
|
||||||
|
set_num, step = divmod(sequence_number, num_steps * num_substeps)
|
||||||
|
step //= num_substeps
|
||||||
|
substep = sequence_number % 2
|
||||||
|
|
||||||
|
return set_num, step, substep
|
||||||
|
|
||||||
|
# set + substep -> rs1, rs2
|
||||||
|
rs1 = {
|
||||||
|
(0, 0): 0,
|
||||||
|
(0, 1): 1,
|
||||||
|
(1, 0): 2,
|
||||||
|
(1, 1): 3,
|
||||||
|
(2, 0): 4,
|
||||||
|
(2, 1): 5,
|
||||||
|
(3, 0): 6,
|
||||||
|
(3, 1): 7
|
||||||
|
}
|
||||||
|
|
||||||
|
rs2 = {
|
||||||
|
(0, 0): 8,
|
||||||
|
(0, 1): 9,
|
||||||
|
(1, 0): 10,
|
||||||
|
(1, 1): 11,
|
||||||
|
(2, 0): 12,
|
||||||
|
(2, 1): 13,
|
||||||
|
(3, 0): 14,
|
||||||
|
(3, 1): 15
|
||||||
|
}
|
||||||
|
|
||||||
|
# step + substep -> rs3, rd
|
||||||
|
rs3_rd = {
|
||||||
|
(0, 0): 16,
|
||||||
|
(0, 1): 17,
|
||||||
|
(1, 0): 18,
|
||||||
|
(1, 1): 19,
|
||||||
|
(2, 0): 20,
|
||||||
|
(2, 1): 21,
|
||||||
|
(3, 0): 22,
|
||||||
|
(3, 1): 23
|
||||||
|
}
|
||||||
|
|
||||||
|
with open('VX_tensor_ucode.vh', 'w') as f:
|
||||||
|
for sequence_number in range(num_sets * num_steps * num_substeps):
|
||||||
|
set_num, step, substep = set_step_substep(sequence_number)
|
||||||
|
|
||||||
|
|
||||||
|
next_sequence_num = (sequence_number + 1) % (num_sets * num_steps * num_substeps)
|
||||||
|
next_set_num, next_step, next_substep = set_step_substep(next_sequence_num)
|
||||||
|
finish = (next_sequence_num == 0)
|
||||||
|
|
||||||
|
name = "HMMA_SET{}_STEP{}_{}"
|
||||||
|
ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG({}), `FREG({}), `FREG({}), `FREG({})"
|
||||||
|
|
||||||
|
name = name.format(
|
||||||
|
set_num, step, substep,
|
||||||
|
)
|
||||||
|
|
||||||
|
ucode = ucode.format(
|
||||||
|
"FINISH" if finish else "NEXT",
|
||||||
|
next_set_num, next_step, next_substep,
|
||||||
|
step,
|
||||||
|
substep,
|
||||||
|
rs3_rd[(step, substep)],
|
||||||
|
rs1[(set_num, substep)],
|
||||||
|
rs2[(set_num, substep)],
|
||||||
|
rs3_rd[(step, substep)],
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = f"{name}: begin \n"
|
||||||
|
entry += "\tuop = {"
|
||||||
|
entry += ucode
|
||||||
|
entry += "}; \n"
|
||||||
|
entry += "end \n"
|
||||||
|
|
||||||
|
f.write(entry)
|
||||||
8
tests/kernel/tensor/Makefile
Normal file
8
tests/kernel/tensor/Makefile
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
PROJECT = tensor
|
||||||
|
|
||||||
|
SRCS = main.cpp
|
||||||
|
DEPS = a_matrix.h
|
||||||
|
DEPS += b_matrix.h
|
||||||
|
DEPS += c_matrix.h
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
94
tests/kernel/tensor/check_correctness.py
Normal file
94
tests/kernel/tensor/check_correctness.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import numpy as np
|
||||||
|
import struct
|
||||||
|
|
||||||
|
A_array = np.zeros((16, 8))
|
||||||
|
B_array = np.zeros((8, 16))
|
||||||
|
C_array = np.zeros((16, 16))
|
||||||
|
|
||||||
|
file = input("simulator output filename: ")
|
||||||
|
|
||||||
|
def hex2float(float_hex_str):
|
||||||
|
# print(float_hex_str.strip())
|
||||||
|
return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0]
|
||||||
|
|
||||||
|
def C_index(threadgroup, thread, register):
|
||||||
|
"""
|
||||||
|
col = ((tg % 4) / 2) * 8;
|
||||||
|
row = (tg * 8) % 16;
|
||||||
|
row += (tg / 4) * 4;
|
||||||
|
|
||||||
|
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
|
||||||
|
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
|
||||||
|
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
|
||||||
|
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
|
||||||
|
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
|
||||||
|
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
||||||
|
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
||||||
|
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
||||||
|
"""
|
||||||
|
|
||||||
|
col = ((threadgroup % 4) // 2) * 8
|
||||||
|
row = (threadgroup * 8) % 16
|
||||||
|
row += (threadgroup // 4) * 4
|
||||||
|
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||||
|
offset = offsets[register-16]
|
||||||
|
row += offset[0]
|
||||||
|
col += offset[1]
|
||||||
|
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||||
|
thread_offset = thread_offsets[thread % 4]
|
||||||
|
row += thread_offset[0]
|
||||||
|
col += thread_offset[1]
|
||||||
|
if C_array[row, col] != 0:
|
||||||
|
print("bad")
|
||||||
|
return (row, col)
|
||||||
|
|
||||||
|
|
||||||
|
with open(file) as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
line = line.strip()
|
||||||
|
if "warp" in line:
|
||||||
|
a, b, c = line.split(',')
|
||||||
|
_, a = a.split(' ')
|
||||||
|
_, b = b.strip().split(' ')
|
||||||
|
c, d = c.strip().split(':')
|
||||||
|
_, c = c.split(' ')
|
||||||
|
warp = int(a)
|
||||||
|
thread = int(b)
|
||||||
|
register = int(c)
|
||||||
|
value = d.strip()
|
||||||
|
|
||||||
|
if warp != 0:
|
||||||
|
continue
|
||||||
|
if not (32 <= register < 32+24):
|
||||||
|
continue
|
||||||
|
|
||||||
|
register = register - 32
|
||||||
|
|
||||||
|
# threadgroups 0, 4, 1, 5 have all elements of A
|
||||||
|
threadgroup = thread // 4
|
||||||
|
if threadgroup in [0, 4, 1, 5]:
|
||||||
|
row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4
|
||||||
|
if 0 <= register < 8:
|
||||||
|
A_array[row, register] = hex2float(value)
|
||||||
|
|
||||||
|
if threadgroup in [0, 4, 2, 6]:
|
||||||
|
col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4
|
||||||
|
if 8 <= register < 16:
|
||||||
|
B_array[register-8, col] = hex2float(value)
|
||||||
|
|
||||||
|
if 16 <= register < 24:
|
||||||
|
# print(value)
|
||||||
|
C_array[C_index(threadgroup, thread, register)] = hex2float(value)
|
||||||
|
|
||||||
|
|
||||||
|
expected = np.load("abc.npz")
|
||||||
|
expected_A = expected['A_array']
|
||||||
|
expected_B = expected['B_array']
|
||||||
|
expected_C = expected['C_array']
|
||||||
|
expected_C = expected_C + expected_A @ expected_B
|
||||||
|
|
||||||
|
print(expected_C - C_array)
|
||||||
|
|
||||||
|
assert np.allclose(expected_A, A_array)
|
||||||
|
assert np.allclose(expected_B, B_array)
|
||||||
|
assert np.allclose(expected_C, C_array)
|
||||||
29
tests/kernel/tensor/create_test_case.py
Normal file
29
tests/kernel/tensor/create_test_case.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import numpy as np
|
||||||
|
# A_array = np.random.rand(16, 8)
|
||||||
|
# B_array = np.random.rand(8, 16)
|
||||||
|
A_array = np.zeros((16, 8))
|
||||||
|
B_array = np.zeros((8, 16))
|
||||||
|
A_array[0,:] = 1.0
|
||||||
|
B_array[:,0] = 1.0
|
||||||
|
C_array = np.random.rand(16, 16)
|
||||||
|
|
||||||
|
|
||||||
|
with open('a_matrix.h', 'w') as f:
|
||||||
|
for i in range(A_array.shape[0]):
|
||||||
|
for j in range(A_array.shape[1]):
|
||||||
|
f.write(f'{A_array[i,j]}f, ')
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
with open('b_matrix.h', 'w') as f:
|
||||||
|
for i in range(B_array.shape[0]):
|
||||||
|
for j in range(B_array.shape[1]):
|
||||||
|
f.write(f'{B_array[i,j]}f, ')
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
with open('c_matrix.h', 'w') as f:
|
||||||
|
for i in range(C_array.shape[0]):
|
||||||
|
for j in range(C_array.shape[1]):
|
||||||
|
f.write(f'{C_array[i,j]}f, ')
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
||||||
96
tests/kernel/tensor/main.cpp
Normal file
96
tests/kernel/tensor/main.cpp
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
#define RISCV_CUSTOM3 0x7B
|
||||||
|
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <vx_print.h>
|
||||||
|
|
||||||
|
inline void vx_wmma() {
|
||||||
|
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "test_data.h"
|
||||||
|
|
||||||
|
void vx_wmma_load() {
|
||||||
|
int tid = vx_thread_id();
|
||||||
|
int tg = tid / 4;
|
||||||
|
|
||||||
|
// load A
|
||||||
|
int row = tid % 4;
|
||||||
|
row += (tg * 8) % 16;
|
||||||
|
row += (tg / 4) * 4;
|
||||||
|
|
||||||
|
asm volatile ("flw f0, %0" :: "m"(A[row][0]));
|
||||||
|
asm volatile ("flw f1, %0" :: "m"(A[row][1]));
|
||||||
|
asm volatile ("flw f2, %0" :: "m"(A[row][2]));
|
||||||
|
asm volatile ("flw f3, %0" :: "m"(A[row][3]));
|
||||||
|
asm volatile ("flw f4, %0" :: "m"(A[row][4]));
|
||||||
|
asm volatile ("flw f5, %0" :: "m"(A[row][5]));
|
||||||
|
asm volatile ("flw f6, %0" :: "m"(A[row][6]));
|
||||||
|
asm volatile ("flw f7, %0" :: "m"(A[row][7]));
|
||||||
|
|
||||||
|
// load B
|
||||||
|
int col = tid % 4;
|
||||||
|
col += ((tg % 4) / 2) * 8;
|
||||||
|
col += (tg / 4) * 4;
|
||||||
|
|
||||||
|
asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
|
||||||
|
asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
|
||||||
|
asm volatile ("flw f10, %0" :: "m"(B[2][col]));
|
||||||
|
asm volatile ("flw f11, %0" :: "m"(B[3][col]));
|
||||||
|
asm volatile ("flw f12, %0" :: "m"(B[4][col]));
|
||||||
|
asm volatile ("flw f13, %0" :: "m"(B[5][col]));
|
||||||
|
asm volatile ("flw f14, %0" :: "m"(B[6][col]));
|
||||||
|
asm volatile ("flw f15, %0" :: "m"(B[7][col]));
|
||||||
|
|
||||||
|
// load C
|
||||||
|
col = ((tg % 4) / 2) * 8;
|
||||||
|
row = (tg * 8) % 16;
|
||||||
|
row += (tg / 4) * 4;
|
||||||
|
|
||||||
|
row += (tid % 4) % 2;
|
||||||
|
col += ((tid % 4) / 2) * 2;
|
||||||
|
|
||||||
|
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
|
||||||
|
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
|
||||||
|
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
|
||||||
|
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
|
||||||
|
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
|
||||||
|
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
||||||
|
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
||||||
|
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
||||||
|
}
|
||||||
|
|
||||||
|
float results[32*8];
|
||||||
|
|
||||||
|
void store_wmma_result() {
|
||||||
|
int tid = vx_thread_id();
|
||||||
|
|
||||||
|
asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
|
||||||
|
asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
|
||||||
|
asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
|
||||||
|
asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
|
||||||
|
asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
|
||||||
|
asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
|
||||||
|
asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
|
||||||
|
asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_wmma_result() {
|
||||||
|
for (int tid = 0; tid < 32; tid += 1) {
|
||||||
|
for (int reg = 0; reg < 8; reg += 1) {
|
||||||
|
vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
vx_tmc(-1);
|
||||||
|
vx_wmma_load();
|
||||||
|
vx_wmma();
|
||||||
|
store_wmma_result();
|
||||||
|
vx_tmc(1);
|
||||||
|
print_wmma_result();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
11
tests/kernel/tensor/test_data.h
Normal file
11
tests/kernel/tensor/test_data.h
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
float A[16][8] = {
|
||||||
|
#include "a_matrix.h"
|
||||||
|
};
|
||||||
|
|
||||||
|
float B[8][16] = {
|
||||||
|
#include "b_matrix.h"
|
||||||
|
};
|
||||||
|
|
||||||
|
float C[16][16] = {
|
||||||
|
#include "c_matrix.h"
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user