flash: Conditionally enable GEMM II fence code, fix tile_k for DEBUG
This commit is contained in:
@@ -8,7 +8,7 @@
|
|||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
#include "flash_impl.hpp"
|
#include "flash_impl.hpp"
|
||||||
|
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = false;
|
||||||
|
|
||||||
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
static_assert(GEMMINI_DMA && !WARP_SPECIALIZED,
|
||||||
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
"GEMMINI_DMA should be set and WARP_SPECIALIZED unset");
|
||||||
@@ -438,9 +438,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
|
#ifdef FENCE_GEMM_II
|
||||||
// signal that GEMM II is finished to O rescale step
|
// signal that GEMM II is finished to O rescale step
|
||||||
*smem_O_flag = 1;
|
*smem_O_flag = 1;
|
||||||
vx_fence();
|
vx_fence();
|
||||||
|
#endif
|
||||||
|
|
||||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||||
@@ -540,8 +542,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t tile_k_ = tile_k - 1;
|
const uint32_t tile_k_ = tile_k - 1;
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
// verify S = Q*K
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// verify S = Q*K
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k_ == 0) {
|
if (tile_k_ == 0) {
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
@@ -588,6 +592,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef FENCE_GEMM_II
|
||||||
// check flag to make sure GEMM II finished and read-after-write
|
// check flag to make sure GEMM II finished and read-after-write
|
||||||
// dependency on O tile is settled for rescale
|
// dependency on O tile is settled for rescale
|
||||||
if (tid_in_warpgroup_simt == 0) {
|
if (tid_in_warpgroup_simt == 0) {
|
||||||
@@ -597,6 +602,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
*smem_O_flag = 0;
|
*smem_O_flag = 0;
|
||||||
vx_fence();
|
vx_fence();
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
if (tid_in_warpgroup == 0) {
|
if (tid_in_warpgroup == 0) {
|
||||||
@@ -612,15 +618,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
// gemmini_fence();
|
|
||||||
|
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
// O after PV
|
// O after PV
|
||||||
if (tile_k_ == 0) {
|
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
|
||||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
warpgroup_id_simt);
|
warpgroup_id_simt);
|
||||||
} else if (tile_k_ == 1) {
|
} else if (tile_k_ == 2) {
|
||||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||||
warpgroup_id_simt);
|
warpgroup_id_simt);
|
||||||
|
|||||||
Reference in New Issue
Block a user