flash: Change smem alloc for fewer bank conflicts; do not skip stc in matmul
This commit is contained in:
@@ -8,7 +8,7 @@
|
|||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
#include "flash_impl.hpp"
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = false;
|
||||||
|
|
||||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||||
@@ -90,69 +90,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
static_assert(
|
static_assert(
|
||||||
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
|
threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER,
|
||||||
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(DEV_SMEM_START_ADDR);
|
||||||
DEV_SMEM_START_ADDR);
|
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
|
||||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4);
|
||||||
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4);
|
||||||
float *smem_Q0 = smem_cursor;
|
constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4);
|
||||||
smem_cursor += smem_Q_size;
|
constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4);
|
||||||
float *smem_Q1 = smem_cursor;
|
|
||||||
smem_cursor += smem_Q_size;
|
|
||||||
float *smem_K0 = smem_cursor;
|
|
||||||
smem_cursor += smem_K_size;
|
|
||||||
float *smem_K1 = smem_cursor;
|
|
||||||
smem_cursor += smem_K_size;
|
|
||||||
float *smem_V0 = smem_cursor;
|
|
||||||
smem_cursor += smem_V_size;
|
|
||||||
float *smem_V1 = smem_cursor;
|
|
||||||
smem_cursor += smem_V_size;
|
|
||||||
float *smem_S0 = smem_cursor;
|
|
||||||
smem_cursor += smem_QK_size;
|
|
||||||
float *smem_S1 = smem_cursor;
|
|
||||||
smem_cursor += smem_QK_size;
|
|
||||||
float *smem_P0 = smem_cursor;
|
|
||||||
smem_cursor += smem_QK_size;
|
|
||||||
float *smem_P1 = smem_cursor;
|
|
||||||
smem_cursor += smem_QK_size;
|
|
||||||
float *smem_O0 = smem_cursor;
|
|
||||||
smem_cursor += smem_O_size;
|
|
||||||
float *smem_O1 = smem_cursor;
|
|
||||||
smem_cursor += smem_O_size;
|
|
||||||
|
|
||||||
// NOTE: this has to match with smem_*
|
// Q/V/S in quart0/1, K/P/O in quart2/3
|
||||||
static_assert(sizeof(elem_t) == sizeof(float));
|
constexpr uint32_t smem_Q0_offset = smem_quart0;
|
||||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
constexpr uint32_t smem_Q1_offset = smem_quart1;
|
||||||
constexpr uint32_t spad_addr_Q0 = 0;
|
constexpr uint32_t smem_K0_offset = smem_quart2;
|
||||||
constexpr uint32_t spad_addr_Q1 =
|
constexpr uint32_t smem_K1_offset = smem_quart3;
|
||||||
spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor);
|
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||||
constexpr uint32_t spad_addr_K0 =
|
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||||
spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor);
|
constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float);
|
||||||
constexpr uint32_t spad_addr_K1 =
|
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||||
spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor);
|
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||||
constexpr uint32_t spad_addr_V0 =
|
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||||
spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor);
|
constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float);
|
||||||
constexpr uint32_t spad_addr_V1 =
|
constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused
|
||||||
spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor);
|
|
||||||
constexpr uint32_t spad_addr_S0 =
|
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
|
||||||
spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor);
|
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
|
||||||
constexpr uint32_t spad_addr_S1 =
|
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
|
||||||
spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
|
||||||
constexpr uint32_t spad_addr_P0 =
|
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
|
||||||
spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
|
||||||
constexpr uint32_t spad_addr_P1 =
|
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
|
||||||
spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
|
||||||
constexpr uint32_t spad_addr_O0 =
|
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
|
||||||
spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
|
||||||
constexpr uint32_t spad_addr_O1 =
|
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
|
||||||
spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor);
|
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
|
||||||
|
|
||||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||||
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
||||||
// FIXME: dangerous
|
|
||||||
smem_cursor = reinterpret_cast<float *>(0xff038000);
|
|
||||||
|
|
||||||
|
float *smem_cursor = smem_O1 + smem_O_size;
|
||||||
|
// // FIXME: dangerous
|
||||||
|
// smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||||
float *smem_rowmax_0 = smem_cursor;
|
float *smem_rowmax_0 = smem_cursor;
|
||||||
smem_cursor += smem_rowmax_size;
|
smem_cursor += smem_rowmax_size;
|
||||||
float *smem_rowmax_1 = smem_cursor;
|
float *smem_rowmax_1 = smem_cursor;
|
||||||
@@ -176,6 +155,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *smem_scratchpad_1 = smem_cursor;
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
smem_cursor += smem_scratchpad_size;
|
smem_cursor += smem_scratchpad_size;
|
||||||
|
|
||||||
|
static_assert(sizeof(elem_t) == sizeof(float));
|
||||||
|
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||||
|
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||||
|
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
||||||
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
||||||
@@ -184,11 +178,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
|
||||||
// // delay warpgroup 0 by 1 iteration to do ping-pong scheduling
|
|
||||||
// if (WARP_SPECIALIZED && warpgroup_id == 1) {
|
|
||||||
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
|
||||||
// }
|
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
||||||
"DMA code assumes Q matrix is stored K-major");
|
"DMA code assumes Q matrix is stored K-major");
|
||||||
|
|
||||||
@@ -207,7 +196,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*skip_ex=*/1, /*skip_stc=*/0);
|
/*skip_ex=*/1, /*skip_stc=*/0);
|
||||||
constexpr uint32_t skips_matmul =
|
constexpr uint32_t skips_matmul =
|
||||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
/*skip_ex=*/0, /*skip_stc=*/0);
|
||||||
constexpr uint32_t skips_matmul_preload =
|
constexpr uint32_t skips_matmul_preload =
|
||||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
||||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||||
@@ -327,9 +316,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
|
for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/;
|
||||||
tile_k++) {
|
tile_k++) {
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG || true) {
|
||||||
// barrier for debugging
|
// barrier for debugging
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
// select the correct double buffer by tile iteration
|
// select the correct double buffer by tile iteration
|
||||||
@@ -394,6 +383,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// DMA knows the full matrix dimensions
|
// DMA knows the full matrix dimensions
|
||||||
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_P_consume, spad_addr_V_consume,
|
spad_addr_P_consume, spad_addr_V_consume,
|
||||||
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
||||||
@@ -449,11 +439,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fence GEMM II to make sure dependency on O tile is settled
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
// fence GEMM-II to make sure dependency on O tile is settled
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
|
#if 1
|
||||||
// mvout to SMEM
|
// mvout to SMEM
|
||||||
// GEMMINI_CISC_CMD_I(9);
|
// GEMMINI_CISC_CMD_I(9);
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
@@ -463,6 +456,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconverge from mmio divergence
|
// reconverge from mmio divergence
|
||||||
@@ -497,6 +491,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||||
const uint32_t opcode = 3 * (tile_k & 1);
|
const uint32_t opcode = 3 * (tile_k & 1);
|
||||||
@@ -574,6 +571,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
|
#if 1
|
||||||
// mvout to SMEM
|
// mvout to SMEM
|
||||||
// GEMMINI_CISC_CMD_I(9);
|
// GEMMINI_CISC_CMD_I(9);
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
@@ -584,7 +582,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconverge from mmio divergence
|
// reconverge from mmio divergence
|
||||||
@@ -668,12 +666,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
asm volatile ("tile_loop_finish_%=:" :: );
|
asm volatile ("tile_loop_finish_%=:" :: );
|
||||||
|
|
||||||
// // wait for warpgroup 1 to finish, which called the global barrier before
|
|
||||||
// // entering the loop
|
|
||||||
// if (warpgroup_id == 0) {
|
|
||||||
// threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|||||||
Reference in New Issue
Block a user