diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index f76e7a24..d3bda941 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,8 +7,12 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" +#define MARK_BEG() asm volatile ("slti x0, x1, -1047") +#define MARK_END() asm volatile ("slti x0, x1, -499") + constexpr bool DEBUG = false; +// FIXME: doesn't take FLOAT_SIZE into account template inline void thread_block_copy_tile(const float *src, float *dest, const uint32_t tid_in_threadblock, @@ -87,15 +91,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { DEV_SMEM_START_ADDR + sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster); + MARK_BEG(); + + // NOTE: hardcoded + constexpr uint32_t quartile = (128 << 10) >> 2; // 128KB / 4 + static_assert((quartile * 4) == SMEM_SIZE, "wrong quartile constant"); + + constexpr uint32_t smem_a_offset = 0; + constexpr uint32_t smem_a_dbuf_offset = 1 * quartile; + constexpr uint32_t smem_b_offset = + 3 * quartile - BN * BK * sizeof(float_type); + constexpr uint32_t smem_b_dbuf_offset = + 4 * quartile - BN * BK * sizeof(float_type); thread_block_gemm(0xd0000000UL); float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); + float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); + float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); - const float *smem_A = reinterpret_cast(sharedmem_per_threadblock); - const float *smem_B = reinterpret_cast( - sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type)); + const float *smem_A0 = + reinterpret_cast(sharedmem_per_threadblock + smem_a_offset); + const float *smem_A1 = + reinterpret_cast(sharedmem_per_threadblock + smem_a_dbuf_offset); + const float *smem_B0 = + reinterpret_cast(sharedmem_per_threadblock + smem_b_offset); + const float *smem_B1 = + reinterpret_cast(sharedmem_per_threadblock + smem_b_dbuf_offset); + // const float *smem_B = reinterpret_cast( + // sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type)); if constexpr (DEBUG) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - thread_block_copy_tile(smem_A, gmem_tmp_d0, tid_in_threadblock, + thread_block_copy_tile(smem_A0, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster); - thread_block_copy_tile(smem_B, gmem_tmp_d1, tid_in_threadblock, + thread_block_copy_tile(smem_A1, gmem_tmp_d1, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); + thread_block_copy_tile(smem_B0, gmem_tmp_d2, tid_in_threadblock, + threads_per_threadblock, + threadblock_id_in_cluster); + thread_block_copy_tile(smem_B1, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster); }