flash: Add flag in SMEM for dependency check on O
TODO: results are unverified. This change stalls the O rescale step until GEMM II finishes.
This commit is contained in:
@@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
// if the number of warps doesn't exactly divide the number of rows,
|
||||
// early-exit to prevent out-of-bounds access
|
||||
// if (row >= B_ROW) {
|
||||
// // WARNING: the number of barrier calls has to exactly match that in the
|
||||
// // outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// continue;
|
||||
// }
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
// rowmax
|
||||
@@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
|
||||
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
@@ -112,6 +112,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||
// reversed!
|
||||
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
||||
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
||||
|
||||
@@ -158,6 +159,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_cursor += smem_scratchpad_size;
|
||||
float *smem_scratchpad_1 = smem_cursor;
|
||||
smem_cursor += smem_scratchpad_size;
|
||||
uint32_t *smem_O_flag = reinterpret_cast<uint32_t *>(smem_cursor);
|
||||
smem_cursor += 1 /* 4Byte */;
|
||||
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
@@ -332,7 +335,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
|
||||
for (uint32_t tile_k = 0;
|
||||
tile_k <
|
||||
(1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
|
||||
tile_k++) {
|
||||
if constexpr (DEBUG || true) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
@@ -371,16 +376,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
|
||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
// --------------------
|
||||
// This is done *before* GEMM I in the software pipeline, working on the
|
||||
// online softmax result tile from the previous iteration
|
||||
|
||||
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||
if (tile_k >= 2) // delay by 2 iters for pipelining
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 2;
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
// --------------------
|
||||
// This is done *before* GEMM I in the software pipeline, working on the
|
||||
// online softmax result tile from the previous iteration
|
||||
|
||||
asm volatile("gemm_pv_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
@@ -427,11 +432,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
// fence to GEMM II completion
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// signal to the O rescale step that GEMM II has finished
|
||||
*smem_O_flag = 1;
|
||||
vx_fence();
|
||||
|
||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||
// const uint32_t opcode = 3 * (tile_k & 1);
|
||||
@@ -448,7 +458,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
}
|
||||
// // reconverge after mmio
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
@@ -534,11 +543,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// verify S = Q*K
|
||||
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k == 1) {
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
@@ -579,9 +588,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: put synchronization with GEMM II here
|
||||
// check flag to make sure GEMM II finished and read-after-write
|
||||
// dependency on O tile is settled for rescale
|
||||
if (tid_in_warpgroup_simt == 0) {
|
||||
while ((*smem_O_flag) != 1)
|
||||
;
|
||||
// set it back to 0 for the next tile iteration
|
||||
*smem_O_flag = 0;
|
||||
vx_fence();
|
||||
}
|
||||
|
||||
#if 0
|
||||
// fence GEMM II to make sure dependency on O tile is settled
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
Reference in New Issue
Block a user