diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 51993b21..a583feb7 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -108,8 +108,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_K1_offset = smem_quart3; constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); - constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float); - constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); + // put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank + // conflicts + constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float); constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); // reversed! @@ -177,14 +179,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + static_assert(warps_per_threadblock_per_core == NUM_WARPS); + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1, smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); - constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - static_assert(warps_per_threadblock_per_core == NUM_WARPS); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -209,22 +213,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); - if constexpr (GEMMINI_DMA) { - if (tid_in_warpgroup == 0) { - gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); + if (tid_in_warpgroup == 0) { + gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA with GMEM address strides - // Q matrix - gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, - false, 0); - // K matrix - gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, - false, 1); - // configure DMA for Q*K store - gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, - MVIN_SCALE_IDENTITY); - gemmini_fence(); - } + // configure DMA with GMEM address strides + // Q matrix + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), + MVIN_SCALE_IDENTITY, false, 1); + // configure DMA for Q*K store + gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY); + gemmini_fence(); } // NOTE about barriers: Placing barriers around thread-divergent branches may @@ -319,8 +320,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - constexpr uint32_t threads_per_warpgroup_simt = threads_per_warpgroup - CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; @@ -337,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < - (1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + (4 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -456,28 +455,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - } - // // reconverge after mmio - // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + asm volatile("gemm_qk_finish_%=:" ::); - asm volatile("gemm_qk_finish_%=:" ::); + // data move for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); - // TODO: put synchronization here with online softmax - - // data move for K and V - // - // Q stays in SMEM for the entire loop - asm volatile("move_k_v_start_%=:" ::); - - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. - if (tid_in_warpgroup == 0) { // configure GMEM addresses for K and V tiles // load K for the next iteration const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); @@ -497,7 +485,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); #endif - gemmini_fence(); + // gemmini_fence(); // do DMA if (tile_k == 0) { @@ -518,6 +506,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); } + + // fence everything before going to the next tile gemmini_fence(); gemmini_fence(); gemmini_fence();