sgemm_tcore: Fix DMA smem addresses, add markers
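Take into account that the DMA places each B tile at the end of its quartile, so the B tile's smem offset is the quartile's end address minus the tile size (BN * BK * sizeof(float_type)).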
This commit is contained in:
@@ -7,8 +7,12 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
// Emits a single `slti` instruction whose destination is x0, so the result is
// discarded and architectural state is unchanged. kernel_body places it just
// before the thread_block_gemm call to mark the start of the measured region.
// The immediate (-1047) differs from MARK_END's so the two can be told apart.
// NOTE(review): presumably the simulator/trace tooling matches on this exact
// encoding -- confirm before changing the register or immediate.
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
|
||||
// Counterpart of MARK_BEG: emits a `slti` into x0 (result discarded) with a
// different immediate (-499). kernel_body places it right after the
// thread_block_gemm call to mark the end of the measured region.
// NOTE(review): presumably matched by external trace tooling -- confirm before
// changing the encoding.
#define MARK_END() asm volatile ("slti x0, x1, -499")
|
||||
|
||||
// Compile-time switch for the debug path in kernel_body: when true, the
// `if constexpr (DEBUG)` section runs a threadblock barrier and then calls
// thread_block_copy_tile on the shared-memory A/B tile pointers with the fixed
// gmem_tmp_d* addresses as destinations. When false that section is not
// compiled in.
constexpr bool DEBUG = false;
|
||||
|
||||
// FIXME: doesn't take FLOAT_SIZE into account
|
||||
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
@@ -87,15 +91,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
DEV_SMEM_START_ADDR +
|
||||
sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster);
|
||||
|
||||
MARK_BEG();
|
||||
|
||||
// NOTE: hardcoded
|
||||
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
|
||||
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
|
||||
|
||||
constexpr uint32_t smem_a_offset = 0;
|
||||
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
|
||||
constexpr uint32_t smem_b_offset =
|
||||
3 * quartile - BN * BK * sizeof(float_type);
|
||||
constexpr uint32_t smem_b_dbuf_offset =
|
||||
4 * quartile - BN * BK * sizeof(float_type);
|
||||
thread_block_gemm<float_type, threads_per_threadblock,
|
||||
/*write_to_gmem=*/true,
|
||||
/*smem_a_offset=*/0,
|
||||
/*smem_a_offset=*/smem_a_offset,
|
||||
#ifdef GEMMINI_DMA
|
||||
/*smem_a_dbuf_offset=*/1 * 128 * 128 * 2/*fp16*/,
|
||||
/*smem_b_offset=*/2 * 128 * 128 * 2/*fp16*/,
|
||||
/*smem_b_dbuf_offset=*/3 * 128 * 128 * 2/*fp16*/
|
||||
// FIXME: above offsets are hardcoded to agree with CISC
|
||||
// spadQuartile
|
||||
/*smem_a_dbuf_offset=*/smem_a_dbuf_offset,
|
||||
/*smem_b_offset=*/smem_b_offset,
|
||||
/*smem_b_dbuf_offset=*/smem_b_dbuf_offset
|
||||
#else
|
||||
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
|
||||
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
||||
@@ -107,21 +121,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||
sharedmem_per_threadblock);
|
||||
|
||||
MARK_END();
|
||||
|
||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||
|
||||
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
|
||||
const float *smem_B = reinterpret_cast<float *>(
|
||||
sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
|
||||
const float *smem_A0 =
|
||||
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_offset);
|
||||
const float *smem_A1 =
|
||||
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
|
||||
const float *smem_B0 =
|
||||
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_offset);
|
||||
const float *smem_B1 =
|
||||
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
|
||||
// const float *smem_B = reinterpret_cast<float *>(
|
||||
// sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_copy_tile<BM, BK>(smem_A, gmem_tmp_d0, tid_in_threadblock,
|
||||
thread_block_copy_tile<BM, BK>(smem_A0, gmem_tmp_d0, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_tile<BK, BN>(smem_B, gmem_tmp_d1, tid_in_threadblock,
|
||||
thread_block_copy_tile<BM, BK>(smem_A1, gmem_tmp_d1, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_tile<BK, BN>(smem_B0, gmem_tmp_d2, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
thread_block_copy_tile<BK, BN>(smem_B1, gmem_tmp_d3, tid_in_threadblock,
|
||||
threads_per_threadblock,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user