flash: Add flag in SMEM for dependency check on O

TODO: results unverified.
Stalls O rescale until GEMM II finishes.
This commit is contained in:
Hansung Kim
2024-09-10 13:37:32 -07:00
parent 88760596cb
commit 90e03894fc
2 changed files with 41 additions and 12 deletions

View File

@@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
// if the number of warps doesn't exactly divide the number of rows,
// early-exit to prevent out-of-bounds access
// if (row >= B_ROW) {
// // WARNING: the number of barrier calls have to exactly match that in the
// // outside of the branch to prevent stalls!! FIXME better proof this.
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// continue;
// }
const uint32_t first_thread_offset = B_COL * row;
// rowmax
@@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_end_%=:" ::);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);

View File

@@ -112,6 +112,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
// reversed!
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
@@ -158,6 +159,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_cursor += smem_scratchpad_size;
float *smem_scratchpad_1 = smem_cursor;
smem_cursor += smem_scratchpad_size;
uint32_t *smem_O_flag = reinterpret_cast<uint32_t *>(smem_cursor);
smem_cursor += 1 /* 4Byte */;
static_assert(sizeof(elem_t) == sizeof(float));
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
@@ -332,7 +335,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "inner loop" along the columns of K^T
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
for (uint32_t tile_k = 0;
tile_k <
(1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
tile_k++) {
if constexpr (DEBUG || true) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
@@ -371,16 +376,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
asm volatile ("dbuf_sel_end_%=:" :: );
// GEMM II: O = O + P*V
// --------------------
// This is done *before* GEMM I in the software pipeline, working on the
// online softmax result tile from the previous iteration
if (vx_warp_id() == 0 /* warp 0 in every core */) {
if (tile_k >= 2) // delay by 2 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 2;
// GEMM II: O = O + P*V
// --------------------
// This is done *before* GEMM I in the software pipeline, working on the
// online softmax result tile from the previous iteration
asm volatile("gemm_pv_start_%=:" ::);
if (tid_in_warpgroup == 0) {
@@ -427,11 +432,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// fence to GEMM II completion
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
// const uint32_t opcode = 3 * (tile_k & 1);
@@ -448,7 +458,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// // reconverge after mmio
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
@@ -534,11 +543,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// verify S = Q*K
if (warpgroup_id == 0) {
if (tile_k == 0) {
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k == 1) {
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
@@ -579,9 +588,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// FIXME: put synchronization with GEMM II here
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) {
while ((*smem_O_flag) != 1)
;
// set it back to 0 for the next tile iteration
*smem_O_flag = 0;
vx_fence();
}
#if 0
// fence GEMM II to make sure dependency on O tile is settled
if (tid_in_warpgroup == 0) {
gemmini_fence();
gemmini_fence();