sgemm_impl: Rename initialize_C to initialize_accum_regs
This commit is contained in:
@@ -320,9 +320,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
asm volatile ("wmma_load_b_finish_%=:" :: );
|
asm volatile ("wmma_load_b_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void initialize_C(const int dest_reg) {
|
// Initialize the accumulator registers to zero before starting FMA operations
|
||||||
// initialize C to zeros
|
// with the tensor cores.
|
||||||
if (dest_reg == 0) {
|
template <int accum_reg_set> inline void initialize_accum_regs() {
|
||||||
|
if constexpr (accum_reg_set == 0) {
|
||||||
asm volatile("fmv.w.x f16, x0");
|
asm volatile("fmv.w.x f16, x0");
|
||||||
asm volatile("fmv.w.x f17, x0");
|
asm volatile("fmv.w.x f17, x0");
|
||||||
asm volatile("fmv.w.x f18, x0");
|
asm volatile("fmv.w.x f18, x0");
|
||||||
@@ -650,13 +651,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
const uint32_t threadblocks_per_cluster,
|
const uint32_t threadblocks_per_cluster,
|
||||||
const uint32_t threadblock_id_in_cluster,
|
const uint32_t threadblock_id_in_cluster,
|
||||||
uint8_t *sharedmem_per_threadblock) {
|
uint8_t *sharedmem_per_threadblock) {
|
||||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
|
||||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
|
||||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
|
||||||
const uint32_t local_as_col = tid_in_threadblock % BM;
|
|
||||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
|
||||||
const uint32_t local_b_col = tid_in_threadblock % BN;
|
|
||||||
|
|
||||||
// no double-buffering
|
// no double-buffering
|
||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||||
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
|
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
|
||||||
@@ -703,9 +697,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
||||||
// clear out C
|
// clear out accumulators
|
||||||
initialize_C(0);
|
initialize_accum_regs<0>();
|
||||||
initialize_C(1);
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// pipeline initiation
|
// pipeline initiation
|
||||||
|
|||||||
Reference in New Issue
Block a user