flash: Rename nowarpspec to default
This commit is contained in:
@@ -3,8 +3,8 @@ PROJECT = flash_attention
|
||||
SRCS = main.cpp common.h
|
||||
|
||||
# VX_SRCS = kernel.cpp
|
||||
# VX_SRCS = kernel.gemmini.cpp
|
||||
VX_SRCS = kernel.gemmini.nowarpspec.cpp
|
||||
# VX_SRCS = kernel.gemmini.warpspec.cpp
|
||||
VX_SRCS = kernel.gemmini.cpp
|
||||
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
// constexpr bool WARP_SPECIALIZED = true;
|
||||
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
// constexpr bool TENSOR_CORE = true;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool GEMMINI_WARP_SPECIALIZED = false;
|
||||
|
||||
@@ -342,16 +342,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
// }
|
||||
|
||||
constexpr uint32_t threads_per_warpgroup_simt =
|
||||
threads_per_warpgroup -
|
||||
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
|
||||
constexpr uint32_t warpgroup_id_simt = 1;
|
||||
constexpr uint32_t barrier_id_simt = 1;
|
||||
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
|
||||
const uint32_t tid_in_warpgroup_simt =
|
||||
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
|
||||
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
|
||||
|
||||
asm volatile ("tile_loop_start_%=:" :: );
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
@@ -411,8 +401,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile
|
||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||
|
||||
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||
if (tile_k >= 2) // delay by 2 iters for pipelining
|
||||
{
|
||||
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 2;
|
||||
|
||||
@@ -457,16 +447,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
// fence to GEMM II completion
|
||||
gemmini_fence();
|
||||
// FIXME: remove
|
||||
// // fence to GEMM II completion
|
||||
// gemmini_fence();
|
||||
|
||||
#ifdef FENCE_GEMM_II
|
||||
asm volatile("rescale_fence_write_start_%=:" ::);
|
||||
// signal that GEMM II is finished to O rescale step
|
||||
*smem_O_flag = 1;
|
||||
vx_fence();
|
||||
asm volatile("rescale_fence_write_end_%=:" ::);
|
||||
#endif
|
||||
// #ifdef FENCE_GEMM_II
|
||||
// asm volatile("rescale_fence_write_start_%=:" ::);
|
||||
// // signal that GEMM II is finished to O rescale step
|
||||
// *smem_O_flag = 1;
|
||||
// vx_fence();
|
||||
// asm volatile("rescale_fence_write_end_%=:" ::);
|
||||
// #endif
|
||||
|
||||
// Kick off GEMM I
|
||||
//
|
||||
@@ -499,14 +490,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const float *gmem_V_tile =
|
||||
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
|
||||
|
||||
#if 0
|
||||
// fence mvout S to SMEM
|
||||
gemmini_fence();
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||
(uint64_t)(gmem_V_tile),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
#endif
|
||||
|
||||
// do DMA
|
||||
if (tile_k == 0) {
|
||||
// // configure address strides for the DMA
|
||||
@@ -545,24 +528,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
#endif
|
||||
}
|
||||
|
||||
// fence everything before going to the next tile
|
||||
gemmini_fence();
|
||||
}
|
||||
|
||||
// threadblock_barrier(warpgroup_id_in_cluster,
|
||||
// warps_per_warpgroup_per_core);
|
||||
// reconverge from mmio divergence
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("move_k_v_finish_%=:" ::);
|
||||
|
||||
// FIXME: remove for nowarpspec
|
||||
//
|
||||
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
|
||||
// branch and call this barrier earlier than when thread 0 finishes.
|
||||
// Since tmask is not considered, that will be a barrier resolve done too
|
||||
// early
|
||||
// threadblock_barrier(0, 1);
|
||||
}
|
||||
|
||||
} else /* warp_id != 0 */ {
|
||||
|
||||
{
|
||||
if (tile_k >= 1) // delay online softmax by 1 iters
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 1;
|
||||
@@ -572,46 +555,49 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
// Online softmax
|
||||
//
|
||||
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
|
||||
smem_S_consume, smem_P_produce, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id, smem_scratchpad,
|
||||
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -619,7 +605,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("rescale_fence_read_start_%=:" ::);
|
||||
// check flag to make sure GEMM II finished and read-after-write
|
||||
// dependency on O tile is settled for rescale
|
||||
if (tid_in_warpgroup_simt == 0) {
|
||||
if (tid_in_warpgroup == 0) {
|
||||
while ((*smem_O_flag) != 1)
|
||||
;
|
||||
// set it back to 0 for the next tile iteration
|
||||
@@ -643,74 +629,66 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (warpgroup_id_in_cluster == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// O after PV
|
||||
if (tile_k_ == 1 /*wait until GEMM II finishes */) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 2) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
// Oi rescale
|
||||
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_O, smem_O /*in-place*/, smem_O_row_scale,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
|
||||
// rescale-to-PV-GEMM barrier
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (warpgroup_id_in_cluster == 0) {
|
||||
// O before PV
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// fence GEMM I after Oi rescale
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
}
|
||||
|
||||
// reconverge from mmio divergence
|
||||
// intra-warpgroup barrier
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
#endif
|
||||
|
||||
// intra-warpgroup barrier
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
// fence everything before going to the next tile
|
||||
gemmini_fence();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -342,6 +342,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
// }
|
||||
|
||||
constexpr uint32_t threads_per_warpgroup_simt =
|
||||
threads_per_warpgroup -
|
||||
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
|
||||
constexpr uint32_t warpgroup_id_simt = 1;
|
||||
constexpr uint32_t barrier_id_simt = 1;
|
||||
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
|
||||
const uint32_t tid_in_warpgroup_simt =
|
||||
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
|
||||
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
|
||||
|
||||
asm volatile ("tile_loop_start_%=:" :: );
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
@@ -401,8 +411,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile
|
||||
asm volatile ("dbuf_sel_end_%=:" :: );
|
||||
|
||||
{
|
||||
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
|
||||
if (vx_warp_id() == 0 /* warp 0 in every core */) {
|
||||
if (tile_k >= 2) // delay by 2 iters for pipelining
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 2;
|
||||
|
||||
@@ -447,17 +457,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
|
||||
if (tid_in_warpgroup == 0) {
|
||||
// FIXME: remove
|
||||
// // fence to GEMM II completion
|
||||
// gemmini_fence();
|
||||
// fence to GEMM II completion
|
||||
gemmini_fence();
|
||||
|
||||
// #ifdef FENCE_GEMM_II
|
||||
// asm volatile("rescale_fence_write_start_%=:" ::);
|
||||
// // signal that GEMM II is finished to O rescale step
|
||||
// *smem_O_flag = 1;
|
||||
// vx_fence();
|
||||
// asm volatile("rescale_fence_write_end_%=:" ::);
|
||||
// #endif
|
||||
#ifdef FENCE_GEMM_II
|
||||
asm volatile("rescale_fence_write_start_%=:" ::);
|
||||
// signal that GEMM II is finished to O rescale step
|
||||
*smem_O_flag = 1;
|
||||
vx_fence();
|
||||
asm volatile("rescale_fence_write_end_%=:" ::);
|
||||
#endif
|
||||
|
||||
// Kick off GEMM I
|
||||
//
|
||||
@@ -490,6 +499,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const float *gmem_V_tile =
|
||||
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
|
||||
|
||||
#if 0
|
||||
// fence mvout S to SMEM
|
||||
gemmini_fence();
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||
(uint64_t)(gmem_V_tile),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
#endif
|
||||
|
||||
// do DMA
|
||||
if (tile_k == 0) {
|
||||
// // configure address strides for the DMA
|
||||
@@ -528,24 +545,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||
#endif
|
||||
}
|
||||
|
||||
// fence everything before going to the next tile
|
||||
gemmini_fence();
|
||||
}
|
||||
|
||||
// reconverge from mmio divergence
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
// threadblock_barrier(warpgroup_id_in_cluster,
|
||||
// warps_per_warpgroup_per_core);
|
||||
|
||||
asm volatile("move_k_v_finish_%=:" ::);
|
||||
|
||||
// FIXME: remove for nowarpspec
|
||||
//
|
||||
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
|
||||
// branch and call this barrier earlier than when thread 0 finishes.
|
||||
// Since tmask is not considered, that will be a barrier resolve done too
|
||||
// early
|
||||
// threadblock_barrier(0, 1);
|
||||
}
|
||||
|
||||
{
|
||||
} else /* warp_id != 0 */ {
|
||||
|
||||
if (tile_k >= 1) // delay online softmax by 1 iters
|
||||
{
|
||||
const uint32_t tile_k_ = tile_k - 1;
|
||||
@@ -555,49 +572,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id);
|
||||
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id);
|
||||
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
// Online softmax
|
||||
//
|
||||
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_S_consume, smem_P_produce, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id, smem_scratchpad,
|
||||
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
|
||||
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id == 0) {
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
thread_block_copy_rowmax(
|
||||
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
|
||||
tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -605,7 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
asm volatile("rescale_fence_read_start_%=:" ::);
|
||||
// check flag to make sure GEMM II finished and read-after-write
|
||||
// dependency on O tile is settled for rescale
|
||||
if (tid_in_warpgroup == 0) {
|
||||
if (tid_in_warpgroup_simt == 0) {
|
||||
while ((*smem_O_flag) != 1)
|
||||
;
|
||||
// set it back to 0 for the next tile iteration
|
||||
@@ -629,66 +643,74 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
#endif
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id_in_cluster == 0) {
|
||||
if (warpgroup_id == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
|
||||
// O after PV
|
||||
if (tile_k_ == 1 /*wait until GEMM II finishes */) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
} else if (tile_k_ == 2) {
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
// Oi rescale
|
||||
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
||||
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
smem_O, smem_O /*in-place*/, smem_O_row_scale,
|
||||
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
|
||||
warpgroup_id_simt);
|
||||
|
||||
// rescale-to-PV-GEMM barrier
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
if (warpgroup_id_in_cluster == 0) {
|
||||
if (warpgroup_id == 0) {
|
||||
// O before PV
|
||||
if (tile_k_ == 0) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
} else if (tile_k_ == 1) {
|
||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
|
||||
warpgroup_id_in_cluster);
|
||||
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
|
||||
threads_per_warpgroup_simt, warpgroup_id_simt);
|
||||
}
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// intra-warpgroup barrier
|
||||
#if 0
|
||||
// fence GEMM I after Oi rescale
|
||||
if (tid_in_warpgroup == 0) {
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
gemmini_fence();
|
||||
}
|
||||
|
||||
// reconverge from mmio divergence
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
#endif
|
||||
|
||||
// fence everything before going to the next tile
|
||||
gemmini_fence();
|
||||
// intra-warpgroup barrier
|
||||
threadblock_barrier(barrier_id_simt, barrier_count_simt);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user