diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 5737b00a..b1141875 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -53,6 +53,26 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, } } +template +inline float exponential_taylor_term(const float x) { + asm volatile("exponential_taylor_term_start_%=:" ::); + + float res = 1.0f; + + if constexpr (order == 1) { + res = x; + } else if constexpr (order == 2) { + res = x * x; + res /= 2.0f; + } else if constexpr (order == 3) { + res = x * x * x; + res /= 6.0f; + } + + asm volatile("exponential_taylor_term_end_%=:" ::); + return res; +} + inline void thread_block_online_softmax( const float *smem_S, float *smem_O, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -174,24 +194,33 @@ inline void thread_block_online_softmax( thread_offset = first_thread_offset + (elem_per_thread * tid_in_warp); constexpr uint32_t exp_per_row_iter = B_COL / (elem_per_thread * NUM_THREADS); + + asm volatile("flashattn_exp_p_start_%=:" ::); + #pragma GCC unroll for (int i = 0; i < exp_per_row_iter; i++) { float f0 = smem_S[thread_offset]; // check Q*K result - gmem_tmp0[thread_offset] = f0;; + gmem_tmp0[thread_offset] = f0; // FIXME: placeholder for proper exp f0 -= rowmax_new; + float exp = 1.0f; + exp += exponential_taylor_term<1>(f0); + exp += exponential_taylor_term<2>(f0); // Store S transposed to the shared memory - smem_P[thread_offset] = f0; - gmem_tmp1[thread_offset] = f0; + smem_P[thread_offset] = exp; + gmem_tmp1[thread_offset] = exp; thread_offset += NUM_THREADS; } + asm volatile("flashattn_exp_p_end_%=:" ::); + + threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);