sgemm_tcore: Fix DMA smem addresses, add markers
Take into account that DMA writes B tiles starting from the end of the quartile.
This commit is contained in:
@@ -7,8 +7,12 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
|
// Begin-of-region marker: emits the single instruction `slti x0, x1, -1047`.
// The destination register is x0, so on RISC-V the comparison result is
// discarded and the instruction only serves as a recognizable pattern in the
// instruction stream (presumably matched by a trace/simulation tool via the
// -1047 immediate — TODO confirm). kernel_body invokes it right before the
// thread_block_gemm call.
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
|
||||||
|
// End-of-region marker: emits the single instruction `slti x0, x1, -499`.
// Counterpart of MARK_BEG(); it also targets x0, so the result is discarded,
// and it differs from MARK_BEG() only in the immediate (-499). kernel_body
// invokes it after the call that appears to be the thread_block_gemm
// invocation — NOTE(review): the call site is split across diff hunks, confirm.
#define MARK_END() asm volatile ("slti x0, x1, -499")
|
||||||
|
|
||||||
constexpr bool DEBUG = false;
|
constexpr bool DEBUG = false;
|
||||||
|
|
||||||
|
// FIXME: element type is hardcoded to float; doesn't take FLOAT_SIZE into account
|
||||||
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
|
template <uint32_t tile_dim_row, uint32_t tile_dim_col>
|
||||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
@@ -87,15 +91,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
DEV_SMEM_START_ADDR +
|
DEV_SMEM_START_ADDR +
|
||||||
sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster);
|
sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster);
|
||||||
|
|
||||||
|
MARK_BEG();
|
||||||
|
|
||||||
|
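// NOTE: hardcoded for a 128KB shared memory split into four quartiles; the static_assert below checks this against SMEM_SIZE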
|
||||||
|
constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4
|
||||||
|
static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant");
|
||||||
|
|
||||||
|
constexpr uint32_t smem_a_offset = 0;
|
||||||
|
constexpr uint32_t smem_a_dbuf_offset = 1 * quartile;
|
||||||
|
constexpr uint32_t smem_b_offset =
|
||||||
|
3 * quartile - BN * BK * sizeof(float_type);
|
||||||
|
constexpr uint32_t smem_b_dbuf_offset =
|
||||||
|
4 * quartile - BN * BK * sizeof(float_type);
|
||||||
thread_block_gemm<float_type, threads_per_threadblock,
|
thread_block_gemm<float_type, threads_per_threadblock,
|
||||||
/*write_to_gmem=*/true,
|
/*write_to_gmem=*/true,
|
||||||
/*smem_a_offset=*/0,
|
/*smem_a_offset=*/smem_a_offset,
|
||||||
#ifdef GEMMINI_DMA
|
#ifdef GEMMINI_DMA
|
||||||
/*smem_a_dbuf_offset=*/1 * 128 * 128 * 2/*fp16*/,
|
/*smem_a_dbuf_offset=*/smem_a_dbuf_offset,
|
||||||
/*smem_b_offset=*/2 * 128 * 128 * 2/*fp16*/,
|
/*smem_b_offset=*/smem_b_offset,
|
||||||
/*smem_b_dbuf_offset=*/3 * 128 * 128 * 2/*fp16*/
|
/*smem_b_dbuf_offset=*/smem_b_dbuf_offset
|
||||||
// FIXME: above offsets are hardcoded to agree with CISC
|
|
||||||
// spadQuartile
|
|
||||||
#else
|
#else
|
||||||
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
|
/*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type),
|
||||||
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
/*smem_b_offset=*/2 * BM * BK * sizeof(float_type),
|
||||||
@@ -107,21 +121,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
sharedmem_per_threadblock);
|
sharedmem_per_threadblock);
|
||||||
|
|
||||||
|
MARK_END();
|
||||||
|
|
||||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||||
|
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||||
|
|
||||||
const float *smem_A = reinterpret_cast<float *>(sharedmem_per_threadblock);
|
const float *smem_A0 =
|
||||||
const float *smem_B = reinterpret_cast<float *>(
|
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_offset);
|
||||||
sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
|
const float *smem_A1 =
|
||||||
|
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_a_dbuf_offset);
|
||||||
|
const float *smem_B0 =
|
||||||
|
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_offset);
|
||||||
|
const float *smem_B1 =
|
||||||
|
reinterpret_cast<float *>(sharedmem_per_threadblock + smem_b_dbuf_offset);
|
||||||
|
// const float *smem_B = reinterpret_cast<float *>(
|
||||||
|
// sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type));
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
thread_block_copy_tile<BM, BK>(smem_A, gmem_tmp_d0, tid_in_threadblock,
|
thread_block_copy_tile<BM, BK>(smem_A0, gmem_tmp_d0, tid_in_threadblock,
|
||||||
threads_per_threadblock,
|
threads_per_threadblock,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
thread_block_copy_tile<BK, BN>(smem_B, gmem_tmp_d1, tid_in_threadblock,
|
thread_block_copy_tile<BM, BK>(smem_A1, gmem_tmp_d1, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_tile<BK, BN>(smem_B0, gmem_tmp_d2, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_tile<BK, BN>(smem_B1, gmem_tmp_d3, tid_in_threadblock,
|
||||||
threads_per_threadblock,
|
threads_per_threadblock,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user