flash: Fix online softmax for DMA layout

This commit is contained in:
Hansung Kim
2024-09-07 23:21:28 -07:00
parent 2e1485877d
commit c51dc4902d
3 changed files with 54 additions and 48 deletions

View File

@@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) {
return res;
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_online_softmax(
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// one warp handles one row in tile
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
// FIXME: threadblock_id needs to be in here too
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
@@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float next = smem_S[thread_offset];
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
const float next = smem_S[offset];
asm volatile("fmax.s %0, %1, %2"
: "=f"(per_thread_max)
: "f"(per_thread_max), "f"(next));
thread_offset += NUM_THREADS;
}
// stage per-thread max value in smem
warp_smem[tid_in_warp] = per_thread_max;
@@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
float f0 = smem_S[thread_offset];
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
float f0 = smem_S[offset];
f0 -= rowmax_new;
@@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// Store S transposed to the shared memory
smem_P[thread_offset] = exp;
thread_offset += NUM_THREADS;
smem_P[offset] = exp;
}
asm volatile("flashattn_exp_p_end_%=:" ::);
@@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
float per_thread_sum = 0.0f;
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
per_thread_sum += smem_P[thread_offset];
thread_offset += NUM_THREADS;
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t offset = B_COL * smem_row + smem_col;
per_thread_sum += smem_P[offset];
}
// stage per-thread sum value in smem
// FIXME: threadblock_id needs to be in here too
@@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
//
asm volatile("flashattn_rescale_factor_start_%=:" ::);
thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float mi_prev = rowmax_prev;
@@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
// @perf: div vs. expansion on e(-x)?
smem_O_row_scale[row] = 1.0f / exp;
thread_offset += NUM_THREADS;
}
asm volatile("flashattn_rescale_factor_end_%=:" ::);