flash: Conditionally enable GEMM II fence code, fix tile_k for DEBUG
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
#include "gemmini_mmio.h"
|
||||
#include "flash_impl.hpp"
|
||||
|
||||
constexpr bool DEBUG = true;
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||
@@ -438,9 +438,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
#ifdef FENCE_GEMM_II
|
||||
// signal that GEMM II is finished to O rescale step
|
||||
*smem_O_flag = 1;
|
||||
vx_fence();
|
||||
#endif
|
||||
|
||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||
@@ -540,8 +542,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t tile_k_ = tile_k - 1;
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
// verify S = Q*K
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// verify S = Q*K
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
@@ -588,6 +592,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef FENCE_GEMM_II
|
||||
// check flag to make sure GEMM II finished and read-after-write
|
||||
// dependency on O tile is settled for rescale
|
||||
if (tid_in_warpgroup_simt == 0) {
|
||||
@@ -597,6 +602,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
*smem_O_flag = 0;
|
||||
vx_fence();
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
if (tid_in_warpgroup == 0) {
|
||||
@@ -612,15 +618,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
// gemmini_fence();
|
||||
|
||||
if (warpgroup_id == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// O after PV
|
||||
if (tile_k_ == 0) {
|
||||
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
} else if (tile_k_ == 1) {
|
||||
} else if (tile_k_ == 2) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
|
||||
Reference in New Issue
Block a user