flash: Do exponential approx to rowsum and Oi as well
This commit is contained in:
@@ -204,8 +204,9 @@ inline void thread_block_online_softmax(
|
|||||||
// check Q*K result
|
// check Q*K result
|
||||||
gmem_tmp0[thread_offset] = f0;
|
gmem_tmp0[thread_offset] = f0;
|
||||||
|
|
||||||
// FIXME: placeholder for proper exp
|
|
||||||
f0 -= rowmax_new;
|
f0 -= rowmax_new;
|
||||||
|
|
||||||
|
// 2nd-order Taylor approximation
|
||||||
float exp = 1.0f;
|
float exp = 1.0f;
|
||||||
exp += exponential_taylor_term<1>(f0);
|
exp += exponential_taylor_term<1>(f0);
|
||||||
exp += exponential_taylor_term<2>(f0);
|
exp += exponential_taylor_term<2>(f0);
|
||||||
@@ -228,6 +229,8 @@ inline void thread_block_online_softmax(
|
|||||||
//
|
//
|
||||||
// two-level tree reduction, similar to rowmax
|
// two-level tree reduction, similar to rowmax
|
||||||
|
|
||||||
|
asm volatile("flashattn_rowsum_start_%=:" ::);
|
||||||
|
|
||||||
thread_offset = first_thread_offset + tid_in_warp;
|
thread_offset = first_thread_offset + tid_in_warp;
|
||||||
float per_thread_sum = 0.0f;
|
float per_thread_sum = 0.0f;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
@@ -254,38 +257,54 @@ inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
const float mi_prev = smem_rowmax_prev[row];
|
const float mi_prev = smem_rowmax_prev[row];
|
||||||
const float mi_this = smem_rowmax_this[row];
|
const float mi_this = smem_rowmax_this[row];
|
||||||
const float exp = mi_prev - mi_this;
|
|
||||||
|
const float x = mi_prev - mi_this;
|
||||||
|
// 2nd-order Taylor approximation
|
||||||
|
float exp = 1.0f;
|
||||||
|
exp += exponential_taylor_term<1>(x);
|
||||||
|
exp += exponential_taylor_term<2>(x);
|
||||||
|
|
||||||
// update rowsum
|
// update rowsum
|
||||||
const float rowsum_prev = smem_rowsum[row];
|
const float rowsum_prev = smem_rowsum[row];
|
||||||
// FIXME: placeholder for exponential
|
|
||||||
float rowsum_new = exp * rowsum_prev + rowsum;
|
float rowsum_new = exp * rowsum_prev + rowsum;
|
||||||
|
|
||||||
smem_rowsum[row] = rowsum_new;
|
smem_rowsum[row] = rowsum_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
// Oi rescale
|
// Oi rescale
|
||||||
//
|
//
|
||||||
|
asm volatile("flashattn_o_rescale_start_%=:" ::);
|
||||||
|
|
||||||
thread_offset = first_thread_offset + tid_in_warp;
|
thread_offset = first_thread_offset + tid_in_warp;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
float fval = smem_O[thread_offset];
|
float o = smem_O[thread_offset];
|
||||||
|
|
||||||
const float mi_prev = smem_rowmax_prev[row];
|
const float mi_prev = smem_rowmax_prev[row];
|
||||||
const float mi_new = smem_rowmax_new[row];
|
const float mi_new = smem_rowmax_new[row];
|
||||||
const float exp = mi_prev - mi_new;
|
|
||||||
|
|
||||||
// FIXME: placeholder for proper exp
|
const float x = mi_prev - mi_new;
|
||||||
fval *= exp;
|
// 2nd-order Taylor approximation
|
||||||
|
float exp = 1.0f;
|
||||||
|
exp += exponential_taylor_term<1>(x);
|
||||||
|
exp += exponential_taylor_term<2>(x);
|
||||||
|
|
||||||
|
// @perf: divide by the approximated exp(x), or multiply by a direct expansion of exp(-x)?
|
||||||
|
o /= exp;
|
||||||
|
|
||||||
// update Oi in-place
|
// update Oi in-place
|
||||||
smem_O[thread_offset] = fval;
|
smem_O[thread_offset] = o;
|
||||||
|
|
||||||
thread_offset += NUM_THREADS;
|
thread_offset += NUM_THREADS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
asm volatile("flashattn_o_rescale_end_%=:" ::);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user