flash: Conditionally enable GEMM II fence code, fix tile_k for DEBUG

This commit is contained in:
Hansung Kim
2024-09-10 22:53:35 -07:00
parent 28b2eaec8f
commit dc746272fb

View File

@@ -8,7 +8,7 @@
#include "gemmini_mmio.h"
#include "flash_impl.hpp"
constexpr bool DEBUG = true;
constexpr bool DEBUG = false;
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
@@ -438,9 +438,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
gemmini_fence();
gemmini_fence();
#ifdef FENCE_GEMM_II
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
#endif
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
@@ -540,8 +542,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const uint32_t tile_k_ = tile_k - 1;
if constexpr (DEBUG) {
// verify S = Q*K
gemmini_fence();
gemmini_fence();
// verify S = Q*K
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
@@ -588,6 +592,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
#ifdef FENCE_GEMM_II
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) {
@@ -597,6 +602,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
*smem_O_flag = 0;
vx_fence();
}
#endif
#if 0
if (tid_in_warpgroup == 0) {
@@ -612,15 +618,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#endif
if constexpr (DEBUG) {
// gemmini_fence();
if (warpgroup_id == 0) {
gemmini_fence();
gemmini_fence();
// O after PV
if (tile_k_ == 0) {
if (tile_k_ == 1 /*wait until GEMM II finishes */) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
} else if (tile_k_ == 1) {
} else if (tile_k_ == 2) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);