flash: Enable GEMM II fence; Pull 1st KV move out of the loop
This commit is contained in:
@@ -8,7 +8,9 @@
|
|||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
#include "flash_impl.hpp"
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
constexpr bool DEBUG = false;
|
#define FENCE_GEMM_II
|
||||||
|
|
||||||
|
constexpr bool DEBUG = true;
|
||||||
|
|
||||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||||
@@ -290,11 +292,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
// re-configure DMA for K and V load that will later happen in the loop
|
// re-configure DMA for K and V load that will later happen in the loop
|
||||||
// GMEM addr stride for K
|
// GMEM addr stride for K
|
||||||
@@ -480,27 +484,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||||
(uint64_t)(gmem_V_tile),
|
(uint64_t)(gmem_V_tile),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
#endif
|
||||||
// configure address strides for the DMA
|
// configure address strides for the DMA
|
||||||
// FIXME: unnecessary?
|
// FIXME: unnecessary?
|
||||||
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
#endif
|
|
||||||
// gemmini_fence();
|
// gemmini_fence();
|
||||||
|
|
||||||
// do DMA
|
// do DMA
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
// we load (k-1)th tile for V; skip V for the 1st iteration,
|
// we load (k-1)th tile for V; skip V for the 1st iteration,
|
||||||
sp_tiled_matmul_full_spad_ws(
|
// sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_K_produce, spad_addr_V_produce,
|
// spad_addr_K_produce, spad_addr_V_produce,
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
|
// /*spad_D=*/0, /*spad_C=*/0,
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
// /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
// /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
// /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a);
|
||||||
} else {
|
} else {
|
||||||
sp_tiled_matmul_full_spad_ws(
|
sp_tiled_matmul_full_spad_ws(
|
||||||
spad_addr_K_produce, spad_addr_V_produce,
|
spad_addr_K_produce, spad_addr_V_produce,
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/,
|
/*spad_D=*/0, /*spad_C=*/0,
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
@@ -532,18 +536,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t tile_k_ = tile_k - 1;
|
const uint32_t tile_k_ = tile_k - 1;
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
gemmini_fence();
|
// verify S = Q*K before softmax
|
||||||
gemmini_fence();
|
|
||||||
|
|
||||||
// verify S = Q*K
|
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k_ == 0) {
|
if (tile_k_ == 0) {
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt,
|
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
} else if (tile_k_ == 1) {
|
} else if (tile_k_ == 1) {
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt,
|
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user