flash: Supply correct tile dims to single_tile
This commit is contained in:
@@ -141,7 +141,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
const uint32_t threadblocks_per_cluster,
|
const uint32_t threadblocks_per_cluster,
|
||||||
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
const uint32_t threadblock_id_in_cluster, float *smem_scratchpad,
|
||||||
float *smem_rowmax, float *smem_rowsum) {
|
float *smem_rowmax, float *smem_rowsum) {
|
||||||
asm volatile("thread_block_flashattn_start_%=:" ::);
|
asm volatile("thread_block_online_softmax_start_%=:" ::);
|
||||||
|
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
@@ -250,20 +250,11 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
// broadcast updated rowmax to all threads in the warp
|
// broadcast updated rowmax to all threads in the warp
|
||||||
const float rowmax_new = smem_rowmax[row];
|
const float rowmax_new = smem_rowmax[row];
|
||||||
|
|
||||||
// each thread computes two fp32 elements, downconverts it to fp16, then
|
|
||||||
// packs them into one fp32
|
|
||||||
constexpr uint32_t elem_per_thread = 1;
|
|
||||||
static_assert((B_COL % (elem_per_thread * NUM_THREADS)) == 0,
|
|
||||||
"B_COL condition not met for P compute");
|
|
||||||
|
|
||||||
thread_offset = first_thread_offset + (elem_per_thread * tid_in_warp);
|
|
||||||
constexpr uint32_t exp_per_row_iter =
|
|
||||||
B_COL / (elem_per_thread * NUM_THREADS);
|
|
||||||
|
|
||||||
asm volatile("flashattn_exp_p_start_%=:" ::);
|
asm volatile("flashattn_exp_p_start_%=:" ::);
|
||||||
|
|
||||||
|
thread_offset = first_thread_offset + tid_in_warp;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < exp_per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
float f0 = smem_S[thread_offset];
|
float f0 = smem_S[thread_offset];
|
||||||
|
|
||||||
f0 -= rowmax_new;
|
f0 -= rowmax_new;
|
||||||
@@ -292,8 +283,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
asm volatile("flashattn_rowsum_start_%=:" ::);
|
asm volatile("flashattn_rowsum_start_%=:" ::);
|
||||||
|
|
||||||
thread_offset = first_thread_offset + tid_in_warp;
|
|
||||||
float per_thread_sum = 0.0f;
|
float per_thread_sum = 0.0f;
|
||||||
|
|
||||||
|
thread_offset = first_thread_offset + tid_in_warp;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
per_thread_sum += smem_P[thread_offset];
|
per_thread_sum += smem_P[thread_offset];
|
||||||
@@ -317,7 +309,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const float mi_prev = rowmax_prev;
|
const float mi_prev = rowmax_prev;
|
||||||
// TODO: replace this with a register?
|
|
||||||
const float mi_this = rowmax_this;
|
const float mi_this = rowmax_this;
|
||||||
|
|
||||||
const float x = mi_prev - mi_this;
|
const float x = mi_prev - mi_this;
|
||||||
@@ -371,7 +362,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("thread_block_flashattn_finish_%=:" ::);
|
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||||
@@ -497,7 +488,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// GEMM I: S = Q*K
|
// GEMM I: S = Q*K
|
||||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
MemLayout::MN_major,
|
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||||
@@ -583,14 +574,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
// FIXME: support MN_major for A for ideal performance
|
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||||
MemLayout::MN_major,
|
MemLayout::MN_major,
|
||||||
|
B_ROW, HEADDIM, B_COL,
|
||||||
/*load_accum=*/true,
|
/*load_accum=*/true,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||||
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
|
// FIXME: wrong but fast
|
||||||
|
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
|
// MemLayout::MN_major,
|
||||||
|
// B_ROW, HEADDIM, B_COL,
|
||||||
|
// /*load_accum=*/true,
|
||||||
|
// /*write_to_smem=*/true>(
|
||||||
|
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||||
|
// threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
// threadblock_id_in_cluster);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|||||||
Reference in New Issue
Block a user