tensor: Improve commit efficiency by decoupling dpu with fifo
This commit is contained in:
@@ -391,7 +391,7 @@
|
|||||||
|
|
||||||
// Tensor Core Latency
|
// Tensor Core Latency
|
||||||
`ifndef LATENCY_HMMA
|
`ifndef LATENCY_HMMA
|
||||||
`define LATENCY_HMMA 8
|
`define LATENCY_HMMA 2
|
||||||
`endif
|
`endif
|
||||||
|
|
||||||
// Icache Configurable Knobs //////////////////////////////////////////////////
|
// Icache Configurable Knobs //////////////////////////////////////////////////
|
||||||
|
|||||||
@@ -189,8 +189,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
|
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
|
||||||
|
|
||||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||||
wire enq = execute_if_fire && (execute_if.data.wid == i);
|
wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
||||||
wire deq = commit_if_fire && (wb_wid == i);
|
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
|
||||||
logic full;
|
logic full;
|
||||||
|
|
||||||
// execute_if request queue.
|
// execute_if request queue.
|
||||||
@@ -395,7 +395,7 @@ module VX_tensor_octet #(
|
|||||||
// C is 4x4 fp32 matrix
|
// C is 4x4 fp32 matrix
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
logic [3:0][3:0][31:0] D_tile;
|
logic [3:0][3:0][31:0] D_tile;
|
||||||
logic [`NW_WIDTH-1:0] D_warp_id;
|
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
|
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
|
||||||
@@ -423,27 +423,41 @@ module VX_tensor_octet #(
|
|||||||
.A_tile(A_tile),
|
.A_tile(A_tile),
|
||||||
.B_tile(B_tile),
|
.B_tile(B_tile),
|
||||||
.C_tile(C_tile),
|
.C_tile(C_tile),
|
||||||
.warp_id(operands_wid),
|
.wid(operands_wid),
|
||||||
|
|
||||||
.valid_out(dpu_valid),
|
.valid_out(dpu_valid),
|
||||||
.D_tile(D_tile),
|
.D_tile(D_tile),
|
||||||
.D_warp_id(D_warp_id)
|
.D_wid(D_wid_dpu)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
wire outbuf_empty;
|
||||||
|
wire outbuf_full;
|
||||||
|
assign outbuf_ready_in = ~outbuf_full;
|
||||||
|
assign result_valid = ~outbuf_empty;
|
||||||
|
|
||||||
|
wire outbuf_enq = outbuf_ready_in && dpu_valid;
|
||||||
|
wire outbuf_deq = result_valid && result_ready;
|
||||||
|
|
||||||
// buffer to stage the result tile for 2 cycles until commit/writeback is
|
// buffer to stage the result tile for 2 cycles until commit/writeback is
|
||||||
// complete
|
// complete. This decouples the irregular dpu output traffic from the
|
||||||
VX_stream_buffer #(
|
// regular, every-2-cycle commit traffic and thereby ensures the commit
|
||||||
|
// pipeline is used more efficiently.
|
||||||
|
// TODO: This is probably oversized.
|
||||||
|
VX_fifo_queue #(
|
||||||
.DATAW ($bits(D_wid) + $bits(D_out)),
|
.DATAW ($bits(D_wid) + $bits(D_out)),
|
||||||
.OUT_REG (1) // not sure this is necessary
|
.DEPTH (8 /* FIXME: arbitrary */)
|
||||||
) output_buffer (
|
) output_buffer (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.valid_in (dpu_valid),
|
.push (outbuf_enq),
|
||||||
.ready_in (outbuf_ready_in),
|
.pop (outbuf_deq),
|
||||||
.data_in ({D_warp_id, D_tile}),
|
.data_in ({D_wid_dpu, D_tile}),
|
||||||
.data_out ({D_wid, D_out}),
|
.data_out ({D_wid, D_out}),
|
||||||
.ready_out (result_ready),
|
.empty (outbuf_empty),
|
||||||
.valid_out (result_valid)
|
`UNUSED_PIN(alm_empty),
|
||||||
|
.full (outbuf_full), // should be impossible to overflow
|
||||||
|
`UNUSED_PIN(alm_full),
|
||||||
|
`UNUSED_PIN(size)
|
||||||
);
|
);
|
||||||
|
|
||||||
`ifdef PERF_ENABLE
|
`ifdef PERF_ENABLE
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ module VX_tensor_dpu #(
|
|||||||
input [3:0][1:0][31:0] A_tile,
|
input [3:0][1:0][31:0] A_tile,
|
||||||
input [1:0][3:0][31:0] B_tile,
|
input [1:0][3:0][31:0] B_tile,
|
||||||
input [3:0][3:0][31:0] C_tile,
|
input [3:0][3:0][31:0] C_tile,
|
||||||
input [`NW_WIDTH-1:0] warp_id,
|
input [`NW_WIDTH-1:0] wid,
|
||||||
|
|
||||||
output valid_out,
|
output valid_out,
|
||||||
output [3:0][3:0][31:0] D_tile,
|
output [3:0][3:0][31:0] D_tile,
|
||||||
output [`NW_WIDTH-1:0] D_warp_id
|
output [`NW_WIDTH-1:0] D_wid
|
||||||
);
|
);
|
||||||
logic [3:0][3:0][31:0] result_hmma;
|
logic [3:0][3:0][31:0] result_hmma;
|
||||||
|
|
||||||
@@ -44,15 +44,15 @@ module VX_tensor_dpu #(
|
|||||||
|
|
||||||
// fixed-latency model
|
// fixed-latency model
|
||||||
VX_shift_register #(
|
VX_shift_register #(
|
||||||
.DATAW (1 + $bits(warp_id) + $bits(D_tile)),
|
.DATAW (1 + $bits(wid) + $bits(D_tile)),
|
||||||
.DEPTH (`LATENCY_HMMA),
|
.DEPTH (`LATENCY_HMMA),
|
||||||
.RESETW (1)
|
.RESETW (1)
|
||||||
) shift_reg (
|
) shift_reg (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.enable (~stall),
|
.enable (~stall),
|
||||||
.data_in ({valid_in, warp_id, result_hmma}),
|
.data_in ({valid_in, wid, result_hmma}),
|
||||||
.data_out ({valid_out, D_warp_id, D_tile})
|
.data_out ({valid_out, D_wid, D_tile})
|
||||||
);
|
);
|
||||||
endmodule
|
endmodule
|
||||||
`endif
|
`endif
|
||||||
|
|||||||
Reference in New Issue
Block a user