flash: Fix online softmax for DMA layout
This commit is contained in:
@@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <bool block_row_major = false>
|
||||
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
// one warp handles one row in tile
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS);
|
||||
|
||||
@@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float next = smem_S[thread_offset];
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
const float next = smem_S[offset];
|
||||
asm volatile("fmax.s %0, %1, %2"
|
||||
: "=f"(per_thread_max)
|
||||
: "f"(per_thread_max), "f"(next));
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
// stage per-thread max value in smem
|
||||
warp_smem[tid_in_warp] = per_thread_max;
|
||||
@@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_exp_p_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
float f0 = smem_S[thread_offset];
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
float f0 = smem_S[offset];
|
||||
|
||||
f0 -= rowmax_new;
|
||||
|
||||
@@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
// Store S transposed to the shared memory
|
||||
|
||||
smem_P[thread_offset] = exp;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
smem_P[offset] = exp;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
@@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
float per_thread_sum = 0.0f;
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
per_thread_sum += smem_P[thread_offset];
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t offset = B_COL * smem_row + smem_col;
|
||||
|
||||
per_thread_sum += smem_P[offset];
|
||||
}
|
||||
// stage per-thread sum value in smem
|
||||
// FIXME: threadblock_id needs to be in here too
|
||||
@@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
//
|
||||
asm volatile("flashattn_rescale_factor_start_%=:" ::);
|
||||
|
||||
thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float mi_prev = rowmax_prev;
|
||||
@@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
// @perf: div vs. expansion on e(-x)?
|
||||
smem_O_row_scale[row] = 1.0f / exp;
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
}
|
||||
|
||||
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||
|
||||
Reference in New Issue
Block a user