flash: Fix hardcoded barrier for tcore; move tcore-specific flags
This commit is contained in:
@@ -8,6 +8,9 @@
|
||||
#include "gemmini_mmio.h"
|
||||
#include "flash_impl.hpp"
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
constexpr bool Q_IS_K_MAJOR = true;
|
||||
|
||||
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// @perf: All threads are running these compute whose result is mostly same
|
||||
// across the threadblock
|
||||
@@ -88,6 +91,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||
DEV_SMEM_START_ADDR);
|
||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
// constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||
float *smem_Q0 = smem_cursor;
|
||||
smem_cursor += smem_Q_size;
|
||||
@@ -310,7 +314,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||
for (uint32_t tile_k = 0; tile_k < (4 /* for perf measurement */ * k_tiles);
|
||||
tile_k++) {
|
||||
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
||||
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
||||
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
||||
|
||||
Reference in New Issue
Block a user