flash: Add flag in SMEM for dependency check on O
TODO: results unverified. Stalls O rescale until GEMM II finishes.
This commit is contained in:
@@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
for (int row_offset = 0; row_offset < B_ROW;
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
// if the number of warps doesn't exactly divide the number of rows,
|
||||||
|
// early-exit to prevent out-of-bounds access
|
||||||
|
// if (row >= B_ROW) {
|
||||||
|
// // WARNING: the number of barrier calls have to exactly match that in the
|
||||||
|
// // outside of the branch to prevent stalls!! FIXME better proof this.
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// threadblock_barrier(1, 7);
|
||||||
|
// continue;
|
||||||
|
// }
|
||||||
const uint32_t first_thread_offset = B_COL * row;
|
const uint32_t first_thread_offset = B_COL * row;
|
||||||
|
|
||||||
// rowmax
|
// rowmax
|
||||||
@@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||||
|
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
// threadblock_barrier(threadblock_id_in_cluster,
|
||||||
// warps_per_threadblock_per_core);
|
// warps_per_threadblock_per_core);
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||||
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||||
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||||
|
// reversed!
|
||||||
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
||||||
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
||||||
|
|
||||||
@@ -158,6 +159,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
smem_cursor += smem_scratchpad_size;
|
smem_cursor += smem_scratchpad_size;
|
||||||
float *smem_scratchpad_1 = smem_cursor;
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
smem_cursor += smem_scratchpad_size;
|
smem_cursor += smem_scratchpad_size;
|
||||||
|
uint32_t *smem_O_flag = reinterpret_cast<uint32_t *>(smem_cursor);
|
||||||
|
smem_cursor += 1 /* 4Byte */;
|
||||||
|
|
||||||
static_assert(sizeof(elem_t) == sizeof(float));
|
static_assert(sizeof(elem_t) == sizeof(float));
|
||||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||||
@@ -332,7 +335,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
|
for (uint32_t tile_k = 0;
|
||||||
|
tile_k <
|
||||||
|
(1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
|
||||||
tile_k++) {
|
tile_k++) {
|
||||||
if constexpr (DEBUG || true) {
|
if constexpr (DEBUG || true) {
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
@@ -371,16 +376,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
|
const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile
|
||||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||||
|
|
||||||
// GEMM II: O = O + P*V
|
|
||||||
// --------------------
|
|
||||||
// This is done *before* GEMM I in the software pipeline, working on the
|
|
||||||
// online softmax result tile from the previous iteration
|
|
||||||
|
|
||||||
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||||
if (tile_k >= 2) // delay by 2 iters for pipelining
|
if (tile_k >= 2) // delay by 2 iters for pipelining
|
||||||
{
|
{
|
||||||
const uint32_t tile_k_ = tile_k - 2;
|
const uint32_t tile_k_ = tile_k - 2;
|
||||||
|
|
||||||
|
// GEMM II: O = O + P*V
|
||||||
|
// --------------------
|
||||||
|
// This is done *before* GEMM I in the software pipeline, working on the
|
||||||
|
// online softmax result tile from the previous iteration
|
||||||
|
|
||||||
asm volatile("gemm_pv_start_%=:" ::);
|
asm volatile("gemm_pv_start_%=:" ::);
|
||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
@@ -427,11 +432,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
asm volatile("gemm_qk_start_%=:" ::);
|
asm volatile("gemm_qk_start_%=:" ::);
|
||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
|
// fence to GEMM II completion
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
|
// signal that GEMM II is finished to O rescale step
|
||||||
|
*smem_O_flag = 1;
|
||||||
|
vx_fence();
|
||||||
|
|
||||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||||
// const uint32_t opcode = 3 * (tile_k & 1);
|
// const uint32_t opcode = 3 * (tile_k & 1);
|
||||||
@@ -448,7 +458,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
}
|
}
|
||||||
// // reconverge after mmio
|
// // reconverge after mmio
|
||||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
@@ -534,11 +543,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// verify S = Q*K
|
// verify S = Q*K
|
||||||
|
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k == 0) {
|
if (tile_k_ == 0) {
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
} else if (tile_k == 1) {
|
} else if (tile_k_ == 1) {
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
@@ -579,9 +588,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: put synchronization with GEMM II here
|
// check flag to make sure GEMM II finished and read-after-write
|
||||||
|
// dependency on O tile is settled for rescale
|
||||||
|
if (tid_in_warpgroup_simt == 0) {
|
||||||
|
while ((*smem_O_flag) != 1)
|
||||||
|
;
|
||||||
|
// set it back to 0 for the next tile iteration
|
||||||
|
*smem_O_flag = 0;
|
||||||
|
vx_fence();
|
||||||
|
}
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// fence GEMM II to make sure dependency on O tile is settled
|
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|||||||
Reference in New Issue
Block a user