flash: Add non-warp-specialized gemmini flash kernel
This commit is contained in:
@@ -17,6 +17,7 @@ constexpr uint32_t ROWMAX_SETS = 3;
|
||||
// constexpr bool WARP_SPECIALIZED = true;
|
||||
// constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
constexpr bool TENSOR_CORE = false;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
@@ -101,7 +102,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
dest[offset] = src[offset];
|
||||
}
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -133,7 +134,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls here has to exactly match the number
|
||||
// outside of the branch to prevent stalls! FIXME: prove this more rigorously.
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -156,7 +157,7 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
dest[gmem_offset] = src[smem_offset];
|
||||
}
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -213,7 +214,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls here has to exactly match the number
|
||||
// outside of the branch to prevent stalls! FIXME: prove this more rigorously.
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
@@ -300,7 +301,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
warp_smem[tid_in_warp] = per_thread_max;
|
||||
|
||||
// sync writes to warp_smem
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -355,7 +356,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
#endif // PARALLEL_ROWMAX
|
||||
#endif // DUMB_ROWMAX
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -403,7 +404,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -434,7 +435,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
warp_smem[tid_in_warp] = per_thread_sum;
|
||||
|
||||
// sync writes to warp_smem
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -467,7 +468,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -496,7 +497,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -551,7 +552,7 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
}
|
||||
|
||||
// reconverge after warp divergence
|
||||
if constexpr (!TENSOR_CORE) {
|
||||
if constexpr (!TENSOR_CORE && GEMMINI_WARP_SPECIALIZED) {
|
||||
threadblock_barrier(1, 7);
|
||||
} else {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
|
||||
Reference in New Issue
Block a user