tensor: Split flops into structural module

This gives the flops their own module in the design hierarchy, so that their area/power numbers can be reported separately.
This commit is contained in:
Hansung Kim
2024-07-26 16:26:48 -07:00
parent 7f43bab0aa
commit 01f6024a76
2 changed files with 103 additions and 56 deletions

View File

@@ -51,7 +51,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
); );
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_warp #( VX_tensor_core_block #(
.ISW(1) // FIXME: not block_idx .ISW(1) // FIXME: not block_idx
) tensor_core ( ) tensor_core (
.clk(clk), .clk(clk),
@@ -64,7 +64,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
endmodule endmodule
module VX_tensor_core_warp import VX_gpu_pkg::*; #( module VX_tensor_core_block import VX_gpu_pkg::*; #(
parameter ISW parameter ISW
) ( ) (
input clk, input clk,
@@ -82,15 +82,16 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA; localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
wire [1:0] step = 2'(execute_if.data.op_type); wire [1:0] step = 2'(execute_if.data.op_type);
// op_mod is reused to indicate instruction's id in pair // op_mod is reused to indicate if instruction is the last substep inside
// a step (pair of substeps)
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid; wire [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready; wire [NUM_OCTETS-1:0] octet_operands_ready;
// FIXME: should be NUM_LANES? // FIXME: should be NUM_LANES?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NW_WIDTH-1:0] wb_wid; wire [`NW_WIDTH-1:0] wb_wid;
// valid signal synced between the functional units (octet) and the // valid signal synced between the functional units (octet) and the
@@ -113,9 +114,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4] execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
}; };
logic [3:0][3:0][31:0] octet_D; wire [3:0][3:0][31:0] octet_D;
logic result_valid; wire result_valid;
logic result_ready; wire result_ready;
VX_tensor_octet #( VX_tensor_octet #(
.ISW(ISW), .ISW(ISW),
@@ -285,15 +286,39 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
end end
end end
VX_tensor_reg #(
.N(1)
) staging_subcommit (
.clk(clk),
.reset(reset),
.d(subcommit_n),
.en(1'b1),
.q(subcommit)
);
endmodule
module VX_tensor_reg #(
parameter N
) (
input clk,
input reset,
input [N-1:0] d,
input en,
output [N-1:0] q
);
logic [N-1:0] data;
always @(posedge clk) begin always @(posedge clk) begin
if (reset) begin if (reset) begin
subcommit <= '0; data <= '0;
end end else begin
else begin if (en) begin
subcommit <= subcommit_n; data <= d;
end
end end
end end
assign q = data;
endmodule endmodule
module VX_tensor_octet #( module VX_tensor_octet #(
@@ -337,7 +362,6 @@ module VX_tensor_octet #(
logic [3:0][31:0] B_half_buf; logic [3:0][31:0] B_half_buf;
logic [7:0][31:0] C_half_buf; logic [7:0][31:0] C_half_buf;
logic [`NUM_WARPS-1:0] substeps; logic [`NUM_WARPS-1:0] substeps;
logic [`NUM_WARPS-1:0] substeps_n; logic [`NUM_WARPS-1:0] substeps_n;
@@ -353,6 +377,7 @@ module VX_tensor_octet #(
assign A_in_buf = A_in; assign A_in_buf = A_in;
assign B_in_buf = B_in; assign B_in_buf = B_in;
assign C_in_buf = C_in; assign C_in_buf = C_in;
// TODO: merge *_buf/*
assign operands_step_buf = operands_step; assign operands_step_buf = operands_step;
assign operands_wid_buf = operands_wid; assign operands_wid_buf = operands_wid;
assign operands_last_in_pair_buf = operands_last_in_pair; assign operands_last_in_pair_buf = operands_last_in_pair;
@@ -408,6 +433,18 @@ module VX_tensor_octet #(
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); // wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf); wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
// Staging buffer for the A/B/C half-tiles that will later be assembled
// with the other half tiles coming in on the input ports.
VX_tensor_reg #(
.N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
) staging_abc (
.clk(clk),
.reset(reset),
.d({A_buffer_n, B_buffer_n, C_buffer_n}),
.en(1'b1),
.q({A_buffer, B_buffer, C_buffer})
);
always @(*) begin always @(*) begin
A_buffer_n = A_buffer; A_buffer_n = A_buffer;
B_buffer_n = B_buffer; B_buffer_n = B_buffer;
@@ -426,20 +463,15 @@ module VX_tensor_octet #(
end end
end end
always @(posedge clk) begin VX_tensor_reg #(
if (reset) begin .N($bits(substeps))
A_buffer <= '0; ) staging_substeps (
B_buffer <= '0; .clk(clk),
C_buffer <= '0; .reset(reset),
substeps <= '0; .d(substeps_n),
end .en(1'b1),
else begin .q(substeps)
A_buffer <= A_buffer_n; );
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substeps <= substeps_n;
end
end
wire outbuf_ready_in; wire outbuf_ready_in;
wire hmma_ready; wire hmma_ready;
@@ -458,8 +490,8 @@ module VX_tensor_octet #(
}; };
// C is 4x4 fp32 matrix // C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile; wire [3:0][3:0][31:0] D_tile;
logic [`NW_WIDTH-1:0] D_wid_dpu; wire [`NW_WIDTH-1:0] D_wid_dpu;
always @(*) begin always @(*) begin
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] }; C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };

View File

@@ -62,11 +62,11 @@ module VX_tensor_dpu #(
// stalling // stalling
// assign ready_in = ready_out; // assign ready_in = ready_out;
logic synced_fire; wire synced_fire;
assign synced_fire = valid_in && ready_in; assign synced_fire = valid_in && ready_in;
logic [1:0] threadgroup_valids; wire [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys; wire [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13 // B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #( VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
@@ -187,7 +187,7 @@ module VX_tensor_threadgroup #(
`UNUSED_PIN(size) `UNUSED_PIN(size)
); );
logic [3:0] fedp_valids; wire [3:0] fedp_valids;
wire fedp_valid_out = &(fedp_valids); wire fedp_valid_out = &(fedp_valids);
wire fedp_ready_out = !stall; wire fedp_ready_out = !stall;
wire fedp_fire_out = fedp_valid_out && fedp_ready_out; wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
@@ -198,14 +198,27 @@ module VX_tensor_threadgroup #(
// 0: FEDP uses first half from input_buffer // 0: FEDP uses first half from input_buffer
// 1: FEDP uses last half and pops input_buffer // 1: FEDP uses last half and pops input_buffer
logic step_in; wire step_in;
// 0: FEDP produces first half of D_frag // 0: FEDP produces first half of D_frag
// 1: FEDP produces last half of D_frag and asserts valid_out // 1: FEDP produces last half of D_frag and asserts valid_out
logic step_out; wire step_out;
assign ready_buf = fedp_fire_in && (step_in == 1'b1); assign ready_buf = fedp_fire_in && (step_in == 1'b1);
wire [3:0][31:0] D_reg;
logic [3:0][31:0] D_reg_n;
// Staging buffer that latches the D half-tile.
VX_tensor_reg #(
.N($bits(D_reg))
) staging_d (
.clk(clk),
.reset(reset),
.d(D_reg_n),
.en(1'b1),
.q(D_reg)
);
// latch the first-half result of D_frag // latch the first-half result of D_frag
logic [3:0][31:0] D_reg, D_reg_n;
wire [3:0][31:0] D_half; wire [3:0][31:0] D_half;
always @(*) begin always @(*) begin
D_reg_n = D_reg; D_reg_n = D_reg;
@@ -216,23 +229,25 @@ module VX_tensor_threadgroup #(
end end
end end
always @(posedge clk) begin // flip step_in/step_out on FEDP in/out fire, respectively
if (reset) begin VX_tensor_reg #(
step_in <= '0; .N(1)
step_out <= '0; ) staging_step_in (
.clk(clk),
D_reg <= '0; .reset(reset),
end else begin .d(~step_in),
if (fedp_fire_in) begin .en(fedp_fire_in),
step_in <= ~step_in; .q(step_in)
end );
if (fedp_fire_out) begin VX_tensor_reg #(
step_out <= ~step_out; .N(1)
end ) staging_step_out (
.clk(clk),
D_reg <= D_reg_n; .reset(reset),
end .d(~step_out),
end .en(fedp_fire_out),
.q(step_out)
);
// TODO: Instead of latching half-result and constructing a full D tile, // TODO: Instead of latching half-result and constructing a full D tile,
// we should be able to send these half fragments down to commit stage // we should be able to send these half fragments down to commit stage