diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 4295e69d..772d4db1 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -14,6 +14,7 @@ #define HEADDIM B_COL constexpr bool DEBUG = true; +constexpr bool DOUBLE_BUF = false; inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -26,12 +27,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, static_assert((B_ROW % NUM_THREADS) == 0, "B_ROW must be a multiple of NUM_THREADS"); - // FIXME: this shouldn't be necessary - static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * NUM_WARPS), + static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * + (NUM_WARPS / (DOUBLE_BUF ? 2 : 1))), "not enough warps to initialize rowmax/rowsum"); - constexpr uint32_t num_warps = B_ROW / NUM_THREADS; - if (warp_id < num_warps) { + // each thread initializes one element in rowmax/rowsum + // multiple warps participate for the whole vector + constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; + if (warp_id < needed_warps /* more warps in HW than needed? */) { uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; // mi, mi~, minew smem_rowmax[offset] = FLT_MIN; @@ -40,10 +43,10 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, smem_rowsum[offset] = 0.0f; } + // each warp clears out a row of smem_O // FIXME: dedup this pattern for (int warp_offset = 0; warp_offset < B_COL; warp_offset += warps_in_threadblock) { - // each warp clears out a row of smem_O const uint32_t row = warp_offset + warp_id; uint32_t thread_offset = HEADDIM * row + tid_in_warp; constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; @@ -58,7 +61,6 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, inline void thread_block_copy_rowmax(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster) { asm volatile("threadblock_copy_rowmax_start_%=:" ::); @@ -66,8 +68,10 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threadblocks_per_cluster; + warps_in_threadblock / CORES_PER_CLUSTER; + // each thread copies one element in rowmax + // multiple warps participate for the whole vector constexpr uint32_t num_warps = B_ROW / NUM_THREADS; if (warp_id < num_warps) { uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; @@ -83,7 +87,6 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster) { asm volatile("threadblock_copy_tile_start_%=:" ::); @@ -91,7 +94,7 @@ inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threadblocks_per_cluster; + warps_in_threadblock / CORES_PER_CLUSTER; // FIXME: dedup this pattern for (int warp_offset = 0; warp_offset < B_ROW; @@ -138,7 +141,6 @@ inline float exponential_taylor_term(const float x) { __attribute__((always_inline)) inline void thread_block_online_softmax( const float *smem_S, float *smem_O, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, float *smem_rowmax, float *smem_rowsum) { asm volatile("thread_block_online_softmax_start_%=:" ::); @@ -147,7 +149,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threadblocks_per_cluster; + warps_in_threadblock / CORES_PER_CLUSTER; // float ft[8]; // asm volatile("fmv.s %0, f16" : "=f"(ft[0])); @@ -402,7 +404,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif // FIXME: headdim not considered - uint32_t threads_per_threadblock = (B_ROW * B_COL) / (ELEM_PER_THREAD); + uint32_t threads_per_threadblock = + (B_ROW * B_COL) / (ELEM_PER_THREAD) / (DOUBLE_BUF ? 2 : 1); const uint32_t hw_threads_per_cluster = cores_per_cluster * vx_num_threads() * vx_num_warps(); // cap maximum threadblock size to # of HW threads in cluster, to prevent @@ -418,6 +421,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_id % threadblocks_per_cluster; const int tid_in_threadblock = task_id % threads_per_threadblock; + // FIXME do proper software pipelining + if (DOUBLE_BUF && threadblock_id != 0) { + return; + } + const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -528,7 +536,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { HEADDIM>(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, smem_S, tid_in_threadblock); // the above should be equivalent to: - // load_tile_to_smem(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/, // smem_S, tid_in_threadblock); @@ -541,7 +550,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_online_softmax( smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad, + threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum); // FIXME unnecessary? @@ -550,34 +559,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (tile_k == 0) { - thread_block_copy_tile( - smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); + thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); } else if (tile_k == k_tiles - 1) { - thread_block_copy_tile( - smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); + thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); + thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock, threads_per_threadblock, - threadblocks_per_cluster, threadblock_id_in_cluster); } @@ -601,12 +606,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_threadblock_per_core); thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, - tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster); // FIXME: wrong but fast // thread_block_gemm_single_tile