flash: Fix online softmax for DMA layout

This commit is contained in:
Hansung Kim
2024-09-07 23:21:28 -07:00
parent 2e1485877d
commit c51dc4902d
3 changed files with 54 additions and 48 deletions

View File

@@ -3,7 +3,7 @@ PROJECT = flash_attention
SRCS = main.cpp common.h
VX_SRCS = kernel.gemmini.cpp
VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
OPTS ?= -n16

View File

@@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) {
return res;
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_online_softmax(
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// one warp handles one row in tile
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
// FIXME: threadblock_id needs to be in here too
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
@@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float next = smem_S[thread_offset];
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
const float next = smem_S[offset];
asm volatile("fmax.s %0, %1, %2"
: "=f"(per_thread_max)
: "f"(per_thread_max), "f"(next));
thread_offset += NUM_THREADS;
}
// stage per-thread max value in smem
warp_smem[tid_in_warp] = per_thread_max;
@@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
float f0 = smem_S[thread_offset];
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
float f0 = smem_S[offset];
f0 -= rowmax_new;
@@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// Store S transposed to the shared memory
smem_P[thread_offset] = exp;
thread_offset += NUM_THREADS;
smem_P[offset] = exp;
}
asm volatile("flashattn_exp_p_end_%=:" ::);
@@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
float per_thread_sum = 0.0f;
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
per_thread_sum += smem_P[thread_offset];
thread_offset += NUM_THREADS;
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
per_thread_sum += smem_P[offset];
}
// stage per-thread sum value in smem
// FIXME: threadblock_id needs to be in here too
@@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
//
asm volatile("flashattn_rescale_factor_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float mi_prev = rowmax_prev;
@@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// @perf: div vs. expansion on e(-x)?
smem_O_row_scale[row] = 1.0f / exp;
thread_offset += NUM_THREADS;
}
asm volatile("flashattn_rescale_factor_end_%=:" ::);

View File

@@ -319,8 +319,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
gemmini_fence();
gemmini_fence();
#if 0
// weight-stationary matmul loop
#if 0 // TODO
// loop_ws variant that skips configuring strides
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
{ \
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
@@ -373,17 +373,40 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// inter-warpgroup barrier before online softmax
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
#if 0
// Online softmax
//
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster,
smem_scratchpad, smem_rowmax, smem_rowsum,
smem_O_row_scale);
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum,
smem_O_row_scale);
// FIXME: unnecessary?
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
#if 0
// data movement for K and V
//
// Q stays in SMEM for the entire loop
@@ -434,32 +457,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
asm volatile("move_k_v_finish_%=:" ::);
// protect write to SMEM
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k == 0) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
// inter-warpgroup barrier before GEMM II
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);