sgemm_impl: Rename initialize_C

This commit is contained in:
Hansung Kim
2024-08-19 16:12:35 -07:00
parent 4aba018733
commit 7ac038fadf

View File

@@ -320,9 +320,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile ("wmma_load_b_finish_%=:" :: ); asm volatile ("wmma_load_b_finish_%=:" :: );
} }
inline void initialize_C(const int dest_reg) { // Initialize the accumulator registers to zero before starting FMA operations
// initialize C to zeros // with the tensor cores.
if (dest_reg == 0) { template <int accum_reg_set> inline void initialize_accum_regs() {
if constexpr (accum_reg_set == 0) {
asm volatile("fmv.w.x f16, x0"); asm volatile("fmv.w.x f16, x0");
asm volatile("fmv.w.x f17, x0"); asm volatile("fmv.w.x f17, x0");
asm volatile("fmv.w.x f18, x0"); asm volatile("fmv.w.x f18, x0");
@@ -650,13 +651,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t threadblocks_per_cluster, const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster, const uint32_t threadblock_id_in_cluster,
uint8_t *sharedmem_per_threadblock) { uint8_t *sharedmem_per_threadblock) {
const uint32_t local_a_row = tid_in_threadblock / BK;
const uint32_t local_a_col = tid_in_threadblock % BK;
const uint32_t local_as_row = tid_in_threadblock / BM;
const uint32_t local_as_col = tid_in_threadblock % BM;
const uint32_t local_b_row = tid_in_threadblock / BN;
const uint32_t local_b_col = tid_in_threadblock % BN;
// no double-buffering // no double-buffering
const uint32_t threads_per_warpgroup = threads_per_threadblock; const uint32_t threads_per_warpgroup = threads_per_threadblock;
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS; const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
@@ -703,9 +697,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) { for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
#pragma GCC unroll 1 #pragma GCC unroll 1
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
// clear out C // clear out accumulators
initialize_C(0); initialize_accum_regs<0>();
initialize_C(1); initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
// pipeline initiation // pipeline initiation