flash: Fix loop iteration for gemmini

Kernel is software-pipelined around 2 GEMMs and softmax; it requires two
iterations to fully complete a tile.
This commit is contained in:
Hansung Kim
2024-11-08 16:43:08 -08:00
parent 4055255018
commit 4e087a8aab

View File

@@ -336,7 +336,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "inner loop" along the columns of K^T
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0;
tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
tile_k < (4 /*for perf measurement*/ *
// virgo kernel is fully pipelined around (2 GEMMs | softmax);
// requires two loop iterations to finish one tile compute
(2 * k_tiles)) +
2 /*pipeline latency*/;
tile_k++) {
if constexpr (DEBUG || true) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);