sgemm_impl: Rename initialize_C to initialize_accum_regs and make the register set a template parameter

This commit is contained in:
Hansung Kim
2024-08-19 16:12:35 -07:00
parent 4aba018733
commit 7ac038fadf

View File

@@ -320,9 +320,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile ("wmma_load_b_finish_%=:" :: );
}
inline void initialize_C(const int dest_reg) {
// initialize C to zeros
if (dest_reg == 0) {
// Initialize the accumulator registers to zero before starting FMA operations
// with the tensor cores.
template <int accum_reg_set> inline void initialize_accum_regs() {
if constexpr (accum_reg_set == 0) {
asm volatile("fmv.w.x f16, x0");
asm volatile("fmv.w.x f17, x0");
asm volatile("fmv.w.x f18, x0");
@@ -650,13 +651,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster,
uint8_t *sharedmem_per_threadblock) {
const uint32_t local_a_row = tid_in_threadblock / BK;
const uint32_t local_a_col = tid_in_threadblock % BK;
const uint32_t local_as_row = tid_in_threadblock / BM;
const uint32_t local_as_col = tid_in_threadblock % BM;
const uint32_t local_b_row = tid_in_threadblock / BN;
const uint32_t local_b_col = tid_in_threadblock % BN;
// no double-buffering
const uint32_t threads_per_warpgroup = threads_per_threadblock;
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
@@ -703,9 +697,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
#pragma GCC unroll 1
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
// clear out C
initialize_C(0);
initialize_C(1);
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) {
// pipeline initiation