tensor: Split flops into structural module
to get separate area/power numbers in hierarchical
This commit is contained in:
@@ -51,7 +51,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
|||||||
);
|
);
|
||||||
|
|
||||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||||
VX_tensor_core_warp #(
|
VX_tensor_core_block #(
|
||||||
.ISW(1) // FIXME: not block_idx
|
.ISW(1) // FIXME: not block_idx
|
||||||
) tensor_core (
|
) tensor_core (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
@@ -64,7 +64,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
||||||
parameter ISW
|
parameter ISW
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
@@ -82,15 +82,16 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
||||||
|
|
||||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||||
// op_mod is reused to indicate instruction's id in pair
|
// op_mod is reused to indicate if instruction is the last substep inside
|
||||||
|
// a step (pair of substeps)
|
||||||
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
||||||
|
|
||||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
wire [NUM_OCTETS-1:0] octet_results_valid;
|
||||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
wire [NUM_OCTETS-1:0] octet_operands_ready;
|
||||||
// FIXME: should be NUM_LANES?
|
// FIXME: should be NUM_LANES?
|
||||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||||
wire [`NW_WIDTH-1:0] wb_wid;
|
wire [`NW_WIDTH-1:0] wb_wid;
|
||||||
|
|
||||||
// valid signal synced between the functional units (octet) and the
|
// valid signal synced between the functional units (octet) and the
|
||||||
@@ -113,9 +114,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
|
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
|
||||||
};
|
};
|
||||||
|
|
||||||
logic [3:0][3:0][31:0] octet_D;
|
wire [3:0][3:0][31:0] octet_D;
|
||||||
logic result_valid;
|
wire result_valid;
|
||||||
logic result_ready;
|
wire result_ready;
|
||||||
|
|
||||||
VX_tensor_octet #(
|
VX_tensor_octet #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
@@ -285,15 +286,39 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
VX_tensor_reg #(
|
||||||
|
.N(1)
|
||||||
|
) staging_subcommit (
|
||||||
|
.clk(clk),
|
||||||
|
.reset(reset),
|
||||||
|
.d(subcommit_n),
|
||||||
|
.en(1'b1),
|
||||||
|
.q(subcommit)
|
||||||
|
);
|
||||||
|
endmodule
|
||||||
|
|
||||||
|
module VX_tensor_reg #(
|
||||||
|
parameter N
|
||||||
|
) (
|
||||||
|
input clk,
|
||||||
|
input reset,
|
||||||
|
input [N-1:0] d,
|
||||||
|
input en,
|
||||||
|
output [N-1:0] q
|
||||||
|
);
|
||||||
|
logic [N-1:0] data;
|
||||||
|
|
||||||
always @(posedge clk) begin
|
always @(posedge clk) begin
|
||||||
if (reset) begin
|
if (reset) begin
|
||||||
subcommit <= '0;
|
data <= '0;
|
||||||
end
|
end else begin
|
||||||
else begin
|
if (en) begin
|
||||||
subcommit <= subcommit_n;
|
data <= d;
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
assign q = data;
|
||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
module VX_tensor_octet #(
|
module VX_tensor_octet #(
|
||||||
@@ -337,7 +362,6 @@ module VX_tensor_octet #(
|
|||||||
logic [3:0][31:0] B_half_buf;
|
logic [3:0][31:0] B_half_buf;
|
||||||
logic [7:0][31:0] C_half_buf;
|
logic [7:0][31:0] C_half_buf;
|
||||||
|
|
||||||
|
|
||||||
logic [`NUM_WARPS-1:0] substeps;
|
logic [`NUM_WARPS-1:0] substeps;
|
||||||
logic [`NUM_WARPS-1:0] substeps_n;
|
logic [`NUM_WARPS-1:0] substeps_n;
|
||||||
|
|
||||||
@@ -353,6 +377,7 @@ module VX_tensor_octet #(
|
|||||||
assign A_in_buf = A_in;
|
assign A_in_buf = A_in;
|
||||||
assign B_in_buf = B_in;
|
assign B_in_buf = B_in;
|
||||||
assign C_in_buf = C_in;
|
assign C_in_buf = C_in;
|
||||||
|
// TODO: merge *_buf/*
|
||||||
assign operands_step_buf = operands_step;
|
assign operands_step_buf = operands_step;
|
||||||
assign operands_wid_buf = operands_wid;
|
assign operands_wid_buf = operands_wid;
|
||||||
assign operands_last_in_pair_buf = operands_last_in_pair;
|
assign operands_last_in_pair_buf = operands_last_in_pair;
|
||||||
@@ -408,6 +433,18 @@ module VX_tensor_octet #(
|
|||||||
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
||||||
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
|
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
|
||||||
|
|
||||||
|
// Staging buffer for the A/B/C half-tiles that will later be assembled
|
||||||
|
// with the other half tiles coming in on the input ports.
|
||||||
|
VX_tensor_reg #(
|
||||||
|
.N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
|
||||||
|
) staging_abc (
|
||||||
|
.clk(clk),
|
||||||
|
.reset(reset),
|
||||||
|
.d({A_buffer_n, B_buffer_n, C_buffer_n}),
|
||||||
|
.en(1'b1),
|
||||||
|
.q({A_buffer, B_buffer, C_buffer})
|
||||||
|
);
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
A_buffer_n = A_buffer;
|
A_buffer_n = A_buffer;
|
||||||
B_buffer_n = B_buffer;
|
B_buffer_n = B_buffer;
|
||||||
@@ -426,20 +463,15 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
always @(posedge clk) begin
|
VX_tensor_reg #(
|
||||||
if (reset) begin
|
.N($bits(substeps))
|
||||||
A_buffer <= '0;
|
) staging_substeps (
|
||||||
B_buffer <= '0;
|
.clk(clk),
|
||||||
C_buffer <= '0;
|
.reset(reset),
|
||||||
substeps <= '0;
|
.d(substeps_n),
|
||||||
end
|
.en(1'b1),
|
||||||
else begin
|
.q(substeps)
|
||||||
A_buffer <= A_buffer_n;
|
);
|
||||||
B_buffer <= B_buffer_n;
|
|
||||||
C_buffer <= C_buffer_n;
|
|
||||||
substeps <= substeps_n;
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
wire outbuf_ready_in;
|
wire outbuf_ready_in;
|
||||||
wire hmma_ready;
|
wire hmma_ready;
|
||||||
@@ -458,8 +490,8 @@ module VX_tensor_octet #(
|
|||||||
};
|
};
|
||||||
// C is 4x4 fp32 matrix
|
// C is 4x4 fp32 matrix
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
logic [3:0][3:0][31:0] D_tile;
|
wire [3:0][3:0][31:0] D_tile;
|
||||||
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||||
|
|||||||
@@ -62,11 +62,11 @@ module VX_tensor_dpu #(
|
|||||||
// stalling
|
// stalling
|
||||||
// assign ready_in = ready_out;
|
// assign ready_in = ready_out;
|
||||||
|
|
||||||
logic synced_fire;
|
wire synced_fire;
|
||||||
assign synced_fire = valid_in && ready_in;
|
assign synced_fire = valid_in && ready_in;
|
||||||
|
|
||||||
logic [1:0] threadgroup_valids;
|
wire [1:0] threadgroup_valids;
|
||||||
logic [1:0] threadgroup_readys;
|
wire [1:0] threadgroup_readys;
|
||||||
// B_tile is shared across the two threadgroups; see Figure 13
|
// B_tile is shared across the two threadgroups; see Figure 13
|
||||||
VX_tensor_threadgroup #(
|
VX_tensor_threadgroup #(
|
||||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
||||||
@@ -187,7 +187,7 @@ module VX_tensor_threadgroup #(
|
|||||||
`UNUSED_PIN(size)
|
`UNUSED_PIN(size)
|
||||||
);
|
);
|
||||||
|
|
||||||
logic [3:0] fedp_valids;
|
wire [3:0] fedp_valids;
|
||||||
wire fedp_valid_out = &(fedp_valids);
|
wire fedp_valid_out = &(fedp_valids);
|
||||||
wire fedp_ready_out = !stall;
|
wire fedp_ready_out = !stall;
|
||||||
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
||||||
@@ -198,14 +198,27 @@ module VX_tensor_threadgroup #(
|
|||||||
|
|
||||||
// 0: FEDP uses first half from input_buffer
|
// 0: FEDP uses first half from input_buffer
|
||||||
// 1: FEDP uses last half and pops input_buffer
|
// 1: FEDP uses last half and pops input_buffer
|
||||||
logic step_in;
|
wire step_in;
|
||||||
// 0: FEDP produces first half of D_frag
|
// 0: FEDP produces first half of D_frag
|
||||||
// 1: FEDP produces last half of D_frag and asserts valid_out
|
// 1: FEDP produces last half of D_frag and asserts valid_out
|
||||||
logic step_out;
|
wire step_out;
|
||||||
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
|
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
|
||||||
|
|
||||||
|
wire [3:0][31:0] D_reg;
|
||||||
|
logic [3:0][31:0] D_reg_n;
|
||||||
|
|
||||||
|
// Staging buffer that latches the D half-tile.
|
||||||
|
VX_tensor_reg #(
|
||||||
|
.N($bits(D_reg))
|
||||||
|
) staging_d (
|
||||||
|
.clk(clk),
|
||||||
|
.reset(reset),
|
||||||
|
.d(D_reg_n),
|
||||||
|
.en(1'b1),
|
||||||
|
.q(D_reg)
|
||||||
|
);
|
||||||
|
|
||||||
// latch the first-half result of D_frag
|
// latch the first-half result of D_frag
|
||||||
logic [3:0][31:0] D_reg, D_reg_n;
|
|
||||||
wire [3:0][31:0] D_half;
|
wire [3:0][31:0] D_half;
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
D_reg_n = D_reg;
|
D_reg_n = D_reg;
|
||||||
@@ -216,23 +229,25 @@ module VX_tensor_threadgroup #(
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
always @(posedge clk) begin
|
// flip step_in/step_out on FEDP in/out fire, respectively
|
||||||
if (reset) begin
|
VX_tensor_reg #(
|
||||||
step_in <= '0;
|
.N(1)
|
||||||
step_out <= '0;
|
) staging_step_in (
|
||||||
|
.clk(clk),
|
||||||
D_reg <= '0;
|
.reset(reset),
|
||||||
end else begin
|
.d(~step_in),
|
||||||
if (fedp_fire_in) begin
|
.en(fedp_fire_in),
|
||||||
step_in <= ~step_in;
|
.q(step_in)
|
||||||
end
|
);
|
||||||
if (fedp_fire_out) begin
|
VX_tensor_reg #(
|
||||||
step_out <= ~step_out;
|
.N(1)
|
||||||
end
|
) staging_step_out (
|
||||||
|
.clk(clk),
|
||||||
D_reg <= D_reg_n;
|
.reset(reset),
|
||||||
end
|
.d(~step_out),
|
||||||
end
|
.en(fedp_fire_out),
|
||||||
|
.q(step_out)
|
||||||
|
);
|
||||||
|
|
||||||
// TODO: Instead of latching half-result and constructing a full D tile,
|
// TODO: Instead of latching half-result and constructing a full D tile,
|
||||||
// we should be able to send these half fragments down to commit stage
|
// we should be able to send these half fragments down to commit stage
|
||||||
|
|||||||
Reference in New Issue
Block a user