tensor: Fix sync for dpu warp queue as well

This commit is contained in:
Hansung Kim
2024-05-30 18:22:36 -07:00
parent 0a032ab400
commit 83f9f6d84f
2 changed files with 16 additions and 13 deletions

View File

@@ -471,7 +471,7 @@ module VX_tensor_octet #(
VX_tensor_dpu #( VX_tensor_dpu #(
.ISW(ISW), .ISW(ISW),
.OCTET(OCTET), .OCTET(OCTET),
.ISSUE_QUEUE_DEPTH(2) .ISSUE_QUEUE_DEPTH(4)
) dpu ( ) dpu (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),

View File

@@ -39,13 +39,6 @@ module VX_tensor_dpu #(
end end
end end
// ready as soon as valid_out
// assign ready_in = ready_reg;
// fully pipelined; ready_in is coupled to ready_out by immediately
// stalling
// assign ready_in = ready_out;
// // fixed-latency queue // // fixed-latency queue
// VX_shift_register #( // VX_shift_register #(
// .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/), // .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
@@ -59,6 +52,16 @@ module VX_tensor_dpu #(
// .data_out ({valid_out, D_wid/*, D_tile */}) // .data_out ({valid_out, D_wid/*, D_tile */})
// ); // );
// ready as soon as valid_out
// assign ready_in = ready_reg || valid_out;
// fully pipelined; ready_in is coupled to ready_out by immediately
// stalling
// assign ready_in = ready_out;
logic synced_fire;
assign synced_fire = valid_in && ready_in;
logic [1:0] threadgroup_valids; logic [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys; logic [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13 // B_tile is shared across the two threadgroups; see Figure 13
@@ -67,7 +70,7 @@ module VX_tensor_dpu #(
) threadgroup_0 ( ) threadgroup_0 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.valid_in (valid_in), .valid_in (synced_fire),
.ready_in (threadgroup_readys[0]), .ready_in (threadgroup_readys[0]),
.stall (!ready_out), .stall (!ready_out),
.A_frag (A_tile[1:0]), .A_frag (A_tile[1:0]),
@@ -81,7 +84,7 @@ module VX_tensor_dpu #(
) threadgroup_1 ( ) threadgroup_1 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.valid_in (valid_in), .valid_in (synced_fire),
.ready_in (threadgroup_readys[1]), .ready_in (threadgroup_readys[1]),
.stall (!ready_out), .stall (!ready_out),
.A_frag (A_tile[3:2]), .A_frag (A_tile[3:2]),
@@ -102,7 +105,7 @@ module VX_tensor_dpu #(
// need to pass along warp id's to do multithreading // need to pass along warp id's to do multithreading
VX_fifo_queue #( VX_fifo_queue #(
.DATAW ($bits(wid)), .DATAW ($bits(wid)),
.DEPTH (ISSUE_QUEUE_DEPTH) .DEPTH (ISSUE_QUEUE_DEPTH + ISSUE_QUEUE_DEPTH)
) wid_queue ( ) wid_queue (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
@@ -117,8 +120,8 @@ module VX_tensor_dpu #(
`UNUSED_PIN(size) `UNUSED_PIN(size)
); );
// `RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out), `RUNTIME_ASSERT(reset || !(deq && empty),
// ("FEDP and metadata queue went out of sync!")) ("dequeueing from empty warp id queue!"))
endmodule endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // does (m,n,k) = (2,4,2) matmul compute over 2 cycles.