flash: Fix hardcoded barrier for tcore; move tcore-specific flags
This commit is contained in:
@@ -11,11 +11,8 @@
|
|||||||
#define ROW_REMAINDER_LOGIC
|
#define ROW_REMAINDER_LOGIC
|
||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool WARP_SPECIALIZED = false;
|
constexpr bool WARP_SPECIALIZED = true;
|
||||||
|
constexpr bool TENSOR_CORE = true;
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
|
||||||
|
|
||||||
constexpr bool Q_IS_K_MAJOR = true;
|
|
||||||
|
|
||||||
// temporary safety stop for wrong configs
|
// temporary safety stop for wrong configs
|
||||||
static_assert(NUM_CORES == 4);
|
static_assert(NUM_CORES == 4);
|
||||||
@@ -99,9 +96,12 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
|||||||
dest[offset] = src[offset];
|
dest[offset] = src[offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
@@ -128,7 +128,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
|||||||
if (row >= B_ROW) {
|
if (row >= B_ROW) {
|
||||||
// WARNING: the number of barrier calls has to exactly match the number
|
// WARNING: the number of barrier calls has to exactly match the number
|
||||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||||
threadblock_barrier(1, 7);
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -146,9 +151,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
|||||||
dest[gmem_offset] = src[smem_offset];
|
dest[gmem_offset] = src[smem_offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||||
@@ -200,12 +208,28 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
if (row >= B_ROW) {
|
if (row >= B_ROW) {
|
||||||
// WARNING: the number of barrier calls has to exactly match the number
|
// WARNING: the number of barrier calls has to exactly match the number
|
||||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||||
threadblock_barrier(1, 7);
|
if constexpr (!TENSOR_CORE) {
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -271,9 +295,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
warp_smem[tid_in_warp] = per_thread_max;
|
warp_smem[tid_in_warp] = per_thread_max;
|
||||||
|
|
||||||
// sync writes to warp_smem
|
// sync writes to warp_smem
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
// #define PARALLEL_ROWMAX
|
// #define PARALLEL_ROWMAX
|
||||||
#ifndef PARALLEL_ROWMAX
|
#ifndef PARALLEL_ROWMAX
|
||||||
@@ -323,9 +350,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
#endif // PARALLEL_ROWMAX
|
#endif // PARALLEL_ROWMAX
|
||||||
#endif // DUMB_ROWMAX
|
#endif // DUMB_ROWMAX
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// broadcast prev rowmax to all threads in the warp
|
// broadcast prev rowmax to all threads in the warp
|
||||||
// NOTE: memory consistency is a little sketchy here
|
// NOTE: memory consistency is a little sketchy here
|
||||||
@@ -367,9 +398,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
// rowsum
|
// rowsum
|
||||||
//
|
//
|
||||||
@@ -395,9 +429,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
warp_smem[tid_in_warp] = per_thread_sum;
|
warp_smem[tid_in_warp] = per_thread_sum;
|
||||||
|
|
||||||
// sync writes to warp_smem
|
// sync writes to warp_smem
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
// 0-th thread collects all other thread's values in the warp
|
// 0-th thread collects all other thread's values in the warp
|
||||||
if (tid_in_warp == 0) {
|
if (tid_in_warp == 0) {
|
||||||
@@ -425,9 +462,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
asm volatile("flashattn_rowsum_end_%=:" ::);
|
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
// compute Oi rescale factor
|
// compute Oi rescale factor
|
||||||
// FIXME: parallelize this across threads
|
// FIXME: parallelize this across threads
|
||||||
@@ -451,9 +491,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||||
|
|
||||||
// threadblock_barrier(threadblock_id_in_cluster,
|
if constexpr (!TENSOR_CORE) {
|
||||||
// warps_per_threadblock_per_core);
|
threadblock_barrier(1, 7);
|
||||||
threadblock_barrier(1, 7);
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||||
@@ -503,7 +546,12 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reconverge after warp divergence
|
// reconverge after warp divergence
|
||||||
threadblock_barrier(1, 7);
|
if constexpr (!TENSOR_CORE) {
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
} else {
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,9 @@
|
|||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
#include "flash_impl.hpp"
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
|
constexpr bool DEBUG = false;
|
||||||
|
constexpr bool Q_IS_K_MAJOR = true;
|
||||||
|
|
||||||
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||||
// @perf: All threads are running these compute whose result is mostly same
|
// @perf: All threads are running these compute whose result is mostly same
|
||||||
// across the threadblock
|
// across the threadblock
|
||||||
@@ -88,6 +91,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
DEV_SMEM_START_ADDR);
|
DEV_SMEM_START_ADDR);
|
||||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||||
|
// constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||||
float *smem_Q0 = smem_cursor;
|
float *smem_Q0 = smem_cursor;
|
||||||
smem_cursor += smem_Q_size;
|
smem_cursor += smem_Q_size;
|
||||||
@@ -310,7 +314,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
for (uint32_t tile_k = 0; tile_k < (4 /* for perf measurement */ * k_tiles);
|
||||||
|
tile_k++) {
|
||||||
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
||||||
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
||||||
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
||||||
|
|||||||
Reference in New Issue
Block a user