flash: Fix online softmax for DMA layout
This commit is contained in:
@@ -3,7 +3,7 @@ PROJECT = flash_attention
|
||||
SRCS = main.cpp common.h
|
||||
|
||||
VX_SRCS = kernel.gemmini.cpp
|
||||
VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp
|
||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
|
||||
@@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <bool block_row_major = false>
|
||||
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
// one warp handles one row in tile
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||
|
||||
@@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float next = smem_S[thread_offset];
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
const float next = smem_S[offset];
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(per_thread_max)
|
||||
: "f"(per_thread_max), "f"(next));
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
// stage per-thread max value in smem
|
||||
warp_smem[tid_in_warp] = per_thread_max;
|
||||
@@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_exp_p_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
float f0 = smem_S[thread_offset];
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
float f0 = smem_S[offset];
|
||||
|
||||
f0 -= rowmax_new;
|
||||
|
||||
@@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
// Store S transposed to the shared memory
|
||||
|
||||
smem_P[thread_offset] = exp;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
smem_P[offset] = exp;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
@@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
float per_thread_sum = 0.0f;
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
per_thread_sum += smem_P[thread_offset];
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
per_thread_sum += smem_P[offset];
|
||||
}
|
||||
// stage per-thread sum value in smem
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
@@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
//
|
||||
asm volatile("flashattn_rescale_factor_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float mi_prev = rowmax_prev;
|
||||
@@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
// @perf: div vs. expansion on e(-x)?
|
||||
smem_O_row_scale[row] = 1.0f / exp;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||
|
||||
@@ -319,8 +319,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
#if 0
|
||||
// weight-stationary matmul loop
|
||||
#if 0 // TODO
|
||||
// loop_ws variant that skips configuring strides
|
||||
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
|
||||
{ \
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
|
||||
@@ -373,17 +373,40 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// inter-warpgroup barrier before online softmax
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
#if 0
|
||||
// Online softmax
|
||||
//
|
||||
thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster,
|
||||
smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||
smem_O_row_scale);
|
||||
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||
smem_O_row_scale);
|
||||
|
||||
// FIXME: unnecessary?
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// data movement for K and V
|
||||
//
|
||||
// Q stays in SMEM for the entire loop
|
||||
@@ -434,32 +457,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
asm volatile("move_k_v_finish_%=:" ::);
|
||||
|
||||
// protect write to SMEM
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k == 0) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup,
|
||||
threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
// inter-warpgroup barrier before GEMM II
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user