From bebdd3353e4de29dda99625e188bf14b268f8770 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 19 Jun 2024 15:41:50 -0700 Subject: [PATCH] Use SWISH in activate_block for tcore and gemmini --- .../sgemm_gemmini_dma/kernel.activation.cpp | 230 +++++++++++------- .../sgemm_tcore/kernel.activation.cpp | 17 +- 2 files changed, 155 insertions(+), 92 deletions(-) diff --git a/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp b/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp index 704bb273..2855f009 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.activation.cpp @@ -57,7 +57,7 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) { vx_barrier(barrier_id, count); } -inline void activate_block(const uint32_t dim_n, const float *const C, +inline void activate_block(const uint32_t dim_n, float *const C, const uint32_t tile_i, const uint32_t tile_j, const uint32_t warp_row, const uint32_t warp_col, const uint32_t tid_in_threadblock) { @@ -71,16 +71,28 @@ inline void activate_block(const uint32_t dim_n, const float *const C, const uint32_t row_in_warptile = 0; const uint32_t C_row = (tile_i * TILE_M) + (warp_row * WM) + row_in_warptile; const uint32_t C_col = (tile_j * TILE_N) + (warp_col * WN) + col_in_warptile; - const float *const global_C = C + dim_n * C_row + C_col; + float *const global_C = C + dim_n * C_row + C_col; const float *global_C_curr = global_C; - // read in elements from GMEM to RF - // each thread works on ELEM_PER_THREAD elements, which can be larger than 1 - static_assert(ELEM_PER_THREAD == 16, "currently assumes ELEM_PER_THREAD == 16"); - + // ELEM_PER_THREAD macro does not take into account warp-specialization + constexpr uint32_t elem_per_thread = ELEM_PER_THREAD; constexpr uint32_t asm_unrolled = 8; // working with f0~f7 at a time + // each thread works on ELEM_PER_THREAD elements, which can be larger than 1 + static_assert((elem_per_thread % asm_unrolled) == 0, + "unmet manual unroll condition for elem_per_thread"); - for (int i = 0; i < ELEM_PER_THREAD; i += asm_unrolled) { +#if 1 + float elems[elem_per_thread]; + +#pragma GCC unroll asm_unrolled + for (int elem_i = 0; elem_i < elem_per_thread; elem_i++) { + elems[elem_i] = global_C[dim_n * elem_i]; + elems[elem_i] = SWISH(1.0f, elems[elem_i]); + global_C[dim_n * elem_i] = elems[elem_i]; + } +#else + for (int i = 0; i < elem_per_thread; i += asm_unrolled) { + // read in elements from GMEM to RF asm volatile("mv t6, %0" ::"r"(global_C_curr)); asm volatile("flw f0, (t6)"); asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); @@ -99,89 +111,126 @@ inline void activate_block(const uint32_t dim_n, const float *const C, asm volatile("flw f7, (t6)"); asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); - // do elem-wise e^x - // each register has 3 temporary registers: - // f0 has f8, f9, f10 - // f1 has f11, f12, f13 - asm volatile("fcvt.s.w f9, %0" ::"r"(1)); - asm volatile("fadd.s f8, f9, f0"); // acc = 1 + x - asm volatile("fcvt.s.w f9, %0" ::"r"(2)); - asm volatile("fdiv.s f10, f0, f9"); // x / 2 - asm volatile("fmadd.s f8, f10, f0, f8"); // acc += (x / 2) * x - asm volatile("fcvt.s.w f9, %0" ::"r"(3)); - asm volatile("fmul.s f10, f10, f0"); // (x * x) / 2 - asm volatile("fdiv.s f10, f10, f9"); // (x * x) / (2 * 3) - asm volatile("fmadd.s f0, f10, f0, f8"); // acc += (x * x) / (2 * 3) * x + if constexpr (true) { + // FIXME: this is likely incorrect; f0~f7 regs get overwritten by + // the compiler + register float x0 asm("f0"); + register float x1 asm("f1"); + register float x2 asm("f2"); + register float x3 asm("f3"); + register float x4 asm("f4"); + register float x5 asm("f5"); + register float x6 asm("f6"); + register float x7 asm("f7"); + asm volatile("fmv.s %0, f0" :"=f"(x0)); + x0 = SWISH(1.0f, x0); + asm volatile("fmv.s f0, %0" ::"f"(x0)); + asm volatile("fmv.s %0, f1" :"=f"(x1)); + x1 = SWISH(1.0f, x1); + asm volatile("fmv.s f1, %0" ::"f"(x1)); + asm volatile("fmv.s %0, f1" :"=f"(x2)); + x2 = SWISH(1.0f, x2); + asm volatile("fmv.s f1, %0" ::"f"(x2)); + asm volatile("fmv.s %0, f1" :"=f"(x3)); + x3 = SWISH(1.0f, x3); + asm volatile("fmv.s f1, %0" ::"f"(x3)); + asm volatile("fmv.s %0, f1" :"=f"(x4)); + x4 = SWISH(1.0f, x4); + asm volatile("fmv.s f1, %0" ::"f"(x4)); + asm volatile("fmv.s %0, f1" :"=f"(x5)); + x5 = SWISH(1.0f, x5); + asm volatile("fmv.s f1, %0" ::"f"(x5)); + asm volatile("fmv.s %0, f1" :"=f"(x6)); + x6 = SWISH(1.0f, x6); + asm volatile("fmv.s f1, %0" ::"f"(x6)); + asm volatile("fmv.s %0, f1" :"=f"(x7)); + x7 = SWISH(1.0f, x7); + asm volatile("fmv.s f1, %0" ::"f"(x7)); + } else { + // do elem-wise e^x + // each register has 3 temporary registers: + // f0 has f8, f9, f10 + // f1 has f11, f12, f13 + asm volatile("fcvt.s.w f9, %0" ::"r"(1)); + asm volatile("fadd.s f8, f9, f0"); // acc = 1 + x + asm volatile("fcvt.s.w f9, %0" ::"r"(2)); + asm volatile("fdiv.s f10, f0, f9"); // x / 2 + asm volatile("fmadd.s f8, f10, f0, f8"); // acc += (x / 2) * x + asm volatile("fcvt.s.w f9, %0" ::"r"(3)); + asm volatile("fmul.s f10, f10, f0"); // (x * x) / 2 + asm volatile("fdiv.s f10, f10, f9"); // (x * x) / (2 * 3) + asm volatile("fmadd.s f0, f10, f0, f8"); // acc += (x * x) / (2 * 3) * x - asm volatile("fcvt.s.w f12, %0" ::"r"(1)); - asm volatile("fadd.s f11, f12, f1"); - asm volatile("fcvt.s.w f12, %0" ::"r"(2)); - asm volatile("fdiv.s f13, f1, f12"); - asm volatile("fmadd.s f11, f13, f1, f11"); - asm volatile("fcvt.s.w f12, %0" ::"r"(3)); - asm volatile("fmul.s f13, f13, f1"); - asm volatile("fdiv.s f13, f13, f12"); - asm volatile("fmadd.s f1, f13, f1, f11"); + asm volatile("fcvt.s.w f12, %0" ::"r"(1)); + asm volatile("fadd.s f11, f12, f1"); + asm volatile("fcvt.s.w f12, %0" ::"r"(2)); + asm volatile("fdiv.s f13, f1, f12"); + asm volatile("fmadd.s f11, f13, f1, f11"); + asm volatile("fcvt.s.w f12, %0" ::"r"(3)); + asm volatile("fmul.s f13, f13, f1"); + asm volatile("fdiv.s f13, f13, f12"); + asm volatile("fmadd.s f1, f13, f1, f11"); - asm volatile("fcvt.s.w f15, %0" ::"r"(1)); - asm volatile("fadd.s f14, f15, f2"); - asm volatile("fcvt.s.w f15, %0" ::"r"(2)); - asm volatile("fdiv.s f16, f2, f15"); - asm volatile("fmadd.s f14, f16, f2, f14"); - asm volatile("fcvt.s.w f15, %0" ::"r"(3)); - asm volatile("fmul.s f16, f16, f2"); - asm volatile("fdiv.s f16, f16, f15"); - asm volatile("fmadd.s f2, f16, f2, f14"); + asm volatile("fcvt.s.w f15, %0" ::"r"(1)); + asm volatile("fadd.s f14, f15, f2"); + asm volatile("fcvt.s.w f15, %0" ::"r"(2)); + asm volatile("fdiv.s f16, f2, f15"); + asm volatile("fmadd.s f14, f16, f2, f14"); + asm volatile("fcvt.s.w f15, %0" ::"r"(3)); + asm volatile("fmul.s f16, f16, f2"); + asm volatile("fdiv.s f16, f16, f15"); + asm volatile("fmadd.s f2, f16, f2, f14"); - asm volatile("fcvt.s.w f18, %0" ::"r"(1)); - asm volatile("fadd.s f17, f18, f3"); - asm volatile("fcvt.s.w f18, %0" ::"r"(2)); - asm volatile("fdiv.s f19, f3, f18"); - asm volatile("fmadd.s f17, f19, f3, f17"); - asm volatile("fcvt.s.w f18, %0" ::"r"(3)); - asm volatile("fmul.s f19, f19, f3"); - asm volatile("fdiv.s f19, f19, f18"); - asm volatile("fmadd.s f3, f19, f3, f17"); + asm volatile("fcvt.s.w f18, %0" ::"r"(1)); + asm volatile("fadd.s f17, f18, f3"); + asm volatile("fcvt.s.w f18, %0" ::"r"(2)); + asm volatile("fdiv.s f19, f3, f18"); + asm volatile("fmadd.s f17, f19, f3, f17"); + asm volatile("fcvt.s.w f18, %0" ::"r"(3)); + asm volatile("fmul.s f19, f19, f3"); + asm volatile("fdiv.s f19, f19, f18"); + asm volatile("fmadd.s f3, f19, f3, f17"); - asm volatile("fcvt.s.w f21, %0" ::"r"(1)); - asm volatile("fadd.s f20, f21, f4"); - asm volatile("fcvt.s.w f21, %0" ::"r"(2)); - asm volatile("fdiv.s f22, f4, f21"); - asm volatile("fmadd.s f20, f22, f4, f20"); - asm volatile("fcvt.s.w f21, %0" ::"r"(3)); - asm volatile("fmul.s f22, f22, f4"); - asm volatile("fdiv.s f22, f22, f21"); - asm volatile("fmadd.s f4, f22, f4, f20"); + asm volatile("fcvt.s.w f21, %0" ::"r"(1)); + asm volatile("fadd.s f20, f21, f4"); + asm volatile("fcvt.s.w f21, %0" ::"r"(2)); + asm volatile("fdiv.s f22, f4, f21"); + asm volatile("fmadd.s f20, f22, f4, f20"); + asm volatile("fcvt.s.w f21, %0" ::"r"(3)); + asm volatile("fmul.s f22, f22, f4"); + asm volatile("fdiv.s f22, f22, f21"); + asm volatile("fmadd.s f4, f22, f4, f20"); - asm volatile("fcvt.s.w f24, %0" ::"r"(1)); - asm volatile("fadd.s f23, f24, f5"); - asm volatile("fcvt.s.w f24, %0" ::"r"(2)); - asm volatile("fdiv.s f25, f5, f24"); - asm volatile("fmadd.s f23, f25, f5, f23"); - asm volatile("fcvt.s.w f24, %0" ::"r"(3)); - asm volatile("fmul.s f25, f25, f5"); - asm volatile("fdiv.s f25, f25, f24"); - asm volatile("fmadd.s f5, f25, f5, f23"); + asm volatile("fcvt.s.w f24, %0" ::"r"(1)); + asm volatile("fadd.s f23, f24, f5"); + asm volatile("fcvt.s.w f24, %0" ::"r"(2)); + asm volatile("fdiv.s f25, f5, f24"); + asm volatile("fmadd.s f23, f25, f5, f23"); + asm volatile("fcvt.s.w f24, %0" ::"r"(3)); + asm volatile("fmul.s f25, f25, f5"); + asm volatile("fdiv.s f25, f25, f24"); + asm volatile("fmadd.s f5, f25, f5, f23"); - asm volatile("fcvt.s.w f27, %0" ::"r"(1)); - asm volatile("fadd.s f26, f27, f6"); - asm volatile("fcvt.s.w f27, %0" ::"r"(2)); - asm volatile("fdiv.s f28, f6, f27"); - asm volatile("fmadd.s f26, f28, f6, f26"); - asm volatile("fcvt.s.w f27, %0" ::"r"(3)); - asm volatile("fmul.s f28, f28, f6"); - asm volatile("fdiv.s f28, f28, f27"); - asm volatile("fmadd.s f6, f28, f6, f26"); + asm volatile("fcvt.s.w f27, %0" ::"r"(1)); + asm volatile("fadd.s f26, f27, f6"); + asm volatile("fcvt.s.w f27, %0" ::"r"(2)); + asm volatile("fdiv.s f28, f6, f27"); + asm volatile("fmadd.s f26, f28, f6, f26"); + asm volatile("fcvt.s.w f27, %0" ::"r"(3)); + asm volatile("fmul.s f28, f28, f6"); + asm volatile("fdiv.s f28, f28, f27"); + asm volatile("fmadd.s f6, f28, f6, f26"); - asm volatile("fcvt.s.w f30, %0" ::"r"(1)); - asm volatile("fadd.s f29, f30, f7"); - asm volatile("fcvt.s.w f30, %0" ::"r"(2)); - asm volatile("fdiv.s f31, f7, f30"); - asm volatile("fmadd.s f29, f31, f7, f29"); - asm volatile("fcvt.s.w f30, %0" ::"r"(3)); - asm volatile("fmul.s f31, f31, f7"); - asm volatile("fdiv.s f31, f31, f30"); - asm volatile("fmadd.s f7, f31, f7, f29"); + asm volatile("fcvt.s.w f30, %0" ::"r"(1)); + asm volatile("fadd.s f29, f30, f7"); + asm volatile("fcvt.s.w f30, %0" ::"r"(2)); + asm volatile("fdiv.s f31, f7, f30"); + asm volatile("fmadd.s f29, f31, f7, f29"); + asm volatile("fcvt.s.w f30, %0" ::"r"(3)); + asm volatile("fmul.s f31, f31, f7"); + asm volatile("fdiv.s f31, f31, f30"); + asm volatile("fmadd.s f7, f31, f7, f29"); + } // move back from RF to gmem asm volatile("mv t6, %0" ::"r"(global_C_curr)); @@ -203,6 +252,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C, asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); asm volatile("mv %0, t6" :"=r"(global_C_curr)); } +#endif } void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, @@ -287,13 +337,13 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, tid_in_threadblock); - // for warp 1, do warp 0's worth of work as well - if (vx_warp_id() == 1) { - const uint32_t warp_row = (warp_id_in_threadblock - 1) / (TILE_N / WN); - const uint32_t warp_col = (warp_id_in_threadblock - 1) % (TILE_N / WN); - activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, - tid_in_threadblock); - } + // // for warp 1, do warp 0's worth of work as well + // if (vx_warp_id() == 1) { + // const uint32_t warp_row = (warp_id_in_threadblock - 1) / (TILE_N / WN); + // const uint32_t warp_col = (warp_id_in_threadblock - 1) % (TILE_N / WN); + // activate_block(dim_n, C, tile_i, tile_j, warp_row, warp_col, + // tid_in_threadblock); + // } } if (HW_TID() == 0) { diff --git a/tests/regression/sgemm_tcore/kernel.activation.cpp b/tests/regression/sgemm_tcore/kernel.activation.cpp index db089ee0..05ef904e 100644 --- a/tests/regression/sgemm_tcore/kernel.activation.cpp +++ b/tests/regression/sgemm_tcore/kernel.activation.cpp @@ -253,7 +253,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, } } -inline void activate_block(const uint32_t dim_n, const float *const C, +inline void activate_block(const uint32_t dim_n, float *const C, const uint32_t tile_i, const uint32_t tile_j, const uint32_t warp_row, const uint32_t warp_col, const uint32_t tid_in_threadblock) { @@ -267,7 +267,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C, const uint32_t row_in_warptile = 0; const uint32_t C_row = (tile_i * BM) + (warp_row * WM) + row_in_warptile; const uint32_t C_col = (tile_j * BN) + (warp_col * WN) + col_in_warptile; - const float *const global_C = C + dim_n * C_row + C_col; + float *const global_C = C + dim_n * C_row + C_col; const float *global_C_curr = global_C; // ELEM_PER_THREAD macro does not take into account warp-specialization @@ -278,6 +278,16 @@ inline void activate_block(const uint32_t dim_n, const float *const C, static_assert((elem_per_thread % asm_unrolled) == 0, "unmet manual unroll condition for elem_per_thread"); +#if 1 + float elems[elem_per_thread]; + +#pragma GCC unroll asm_unrolled + for (int elem_i = 0; elem_i < elem_per_thread; elem_i++) { + elems[elem_i] = global_C[dim_n * elem_i]; + elems[elem_i] = SWISH(1.0f, elems[elem_i]); + global_C[dim_n * elem_i] = elems[elem_i]; + } +#else for (int i = 0; i < elem_per_thread; i += asm_unrolled) { // read in elements from GMEM to RF asm volatile("mv t6, %0" ::"r"(global_C_curr)); @@ -299,6 +309,8 @@ inline void activate_block(const uint32_t dim_n, const float *const C, asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); if constexpr (true) { + // FIXME: this is likely incorrect; f0~f7 regs get overwritten by + // the compiler register float x0 asm("f0"); register float x1 asm("f1"); register float x2 asm("f2"); @@ -437,6 +449,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C, asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float))); asm volatile("mv %0, t6" :"=r"(global_C_curr)); } +#endif } inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,