flash: Add non-warp-specialized gemmini flash kernel

This commit is contained in:
Hansung Kim
2024-11-09 19:08:39 -08:00
parent ac42f2dbba
commit 673e07ed43
3 changed files with 731 additions and 12 deletions

View File

@@ -17,6 +17,7 @@ constexpr uint32_t ROWMAX_SETS = 3;
// Compile-time kernel configuration flags. The commented-out lines keep the
// alternative (enabled) settings around for quick switching.
// constexpr bool WARP_SPECIALIZED = true;
// constexpr bool TENSOR_CORE = true;
constexpr bool WARP_SPECIALIZED = false;
// Together with !TENSOR_CORE, gates the `if constexpr` barrier selection in the
// thread_block_* helpers below: when true they call threadblock_barrier(1, 7);
// otherwise they take the threadblock_id_in_cluster-based barrier branch.
// NOTE(review): presumably `false` selects the non-warp-specialized Gemmini
// kernel path -- confirm against the kernel code.
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
constexpr bool TENSOR_CORE = false;
// temporary safety stop for wrong configs
@@ -101,7 +102,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
dest[offset] = src[offset];
}
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -133,7 +134,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
if (row >= B_ROW) {
// WARNING: the number of barrier calls here has to exactly match the number
// outside of the branch to prevent stalls!! FIXME: find a more robust way to guarantee this.
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -156,7 +157,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
dest[gmem_offset] = src[smem_offset];
}
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -213,7 +214,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
if (row >= B_ROW) {
// WARNING: the number of barrier calls here has to exactly match the number
// outside of the branch to prevent stalls!! FIXME: find a more robust way to guarantee this.
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
@@ -300,7 +301,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_max;
// sync writes to warp_smem
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -355,7 +356,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
#endif // PARALLEL_ROWMAX
#endif // DUMB_ROWMAX
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -403,7 +404,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -434,7 +435,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_sum;
// sync writes to warp_smem
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -467,7 +468,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rowsum_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -496,7 +497,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rescale_factor_end_%=:" ::);
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
@@ -551,7 +552,7 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
}
// reconverge after warp divergence
if constexpr (!TENSOR_CORE) {
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,