sgemm_tcore: Fix DMA smem addresses, add markers

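Take into account that the DMA writes B tiles starting from the end of the
quartile: the B offsets are now computed as the end of the quartile minus the
B tile size (BN * BK * sizeof(float_type)).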
This commit is contained in:
Hansung Kim
2024-10-28 17:26:07 -07:00
parent b4dadfaf61
commit ae98ae6e93

View File

@@ -7,8 +7,12 @@
#include "include/gemmini.h"
#include "gemmini_mmio.h"
// Marker instructions emitted around a region of interest. Each is an `slti`
// whose destination is x0, so (on RISC-V, where x0 is hardwired to zero) the
// result is discarded and the instruction has no architectural effect.
// NOTE(review): the distinct immediates (-1047 / -499) presumably let tooling
// tell the begin marker from the end marker in an instruction trace — confirm
// against whatever consumes these markers before changing them.
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
#define MARK_END() asm volatile ("slti x0, x1, -499")
// Compile-time switch for the `if constexpr (DEBUG)` block in kernel_body,
// which copies the shared-memory tiles out to fixed global-memory addresses.
constexpr bool DEBUG = false;
// FIXME: doesn't take FLOAT_SIZE into account
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
@@ -87,15 +91,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
DEV_SMEM_START_ADDR +
sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster);
MARK_BEG();
// NOTE: hardcoded
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
constexpr uint32_t smem_a_offset = 0;
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
constexpr uint32_t smem_b_offset =
3 * quartile - BN * BK * sizeof(float_type);
constexpr uint32_t smem_b_dbuf_offset =
4 * quartile - BN * BK * sizeof(float_type);
thread_block_gemm<float_type, threads_per_threadblock,
/*write_to_gmem=*/true,
/*smem_a_offset=*/0,
/*smem_a_offset=*/smem_a_offset,
#ifdef GEMMINI_DMA
/*smem_a_dbuf_offset=*/1 * 128 * 128 * 2/*fp16*/,
/*smem_b_offset=*/2 * 128 * 128 * 2/*fp16*/,
/*smem_b_dbuf_offset=*/3 * 128 * 128 * 2/*fp16*/
// FIXME: above offsets are hardcoded to agree with CISC
// spadQuartile
/*smem_a_dbuf_offset=*/smem_a_dbuf_offset,
/*smem_b_offset=*/smem_b_offset,
/*smem_b_dbuf_offset=*/smem_b_dbuf_offset
#else
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
@@ -107,21 +121,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
MARK_END();
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
const float *smem_B = reinterpret_cast<float *>(
sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
const float *smem_A0 =
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_offset);
const float *smem_A1 =
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
const float *smem_B0 =
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_offset);
const float *smem_B1 =
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
// const float *smem_B = reinterpret_cast<float *>(
// sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
if constexpr (DEBUG) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_copy_tile<BM, BK>(smem_A, gmem_tmp_d0, tid_in_threadblock,
thread_block_copy_tile<BM, BK>(smem_A0, gmem_tmp_d0, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_tile<BK, BN>(smem_B, gmem_tmp_d1, tid_in_threadblock,
thread_block_copy_tile<BM, BK>(smem_A1, gmem_tmp_d1, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_tile<BK, BN>(smem_B0, gmem_tmp_d2, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
thread_block_copy_tile<BK, BN>(smem_B1, gmem_tmp_d3, tid_in_threadblock,
threads_per_threadblock,
threadblock_id_in_cluster);
}