Use SWISH in activate_block for tcore and gemmini

This commit is contained in:
Hansung Kim
2024-06-19 15:41:50 -07:00
parent ae9e707280
commit bebdd3353e
2 changed files with 155 additions and 92 deletions

View File

@@ -253,7 +253,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
}
}
inline void activate_block(const uint32_t dim_n, const float *const C,
inline void activate_block(const uint32_t dim_n, float *const C,
const uint32_t tile_i, const uint32_t tile_j,
const uint32_t warp_row, const uint32_t warp_col,
const uint32_t tid_in_threadblock) {
@@ -267,7 +267,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
const uint32_t row_in_warptile = 0;
const uint32_t C_row = (tile_i * BM) + (warp_row * WM) + row_in_warptile;
const uint32_t C_col = (tile_j * BN) + (warp_col * WN) + col_in_warptile;
const float *const global_C = C + dim_n * C_row + C_col;
float *const global_C = C + dim_n * C_row + C_col;
const float *global_C_curr = global_C;
// ELEM_PER_THREAD macro does not take into account warp-specialization
@@ -278,6 +278,16 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
static_assert((elem_per_thread % asm_unrolled) == 0,
"unmet manual unroll condition for elem_per_thread");
#if 1
float elems[elem_per_thread];
#pragma GCC unroll asm_unrolled
for (int elem_i = 0; elem_i < elem_per_thread; elem_i++) {
elems[elem_i] = global_C[dim_n * elem_i];
elems[elem_i] = SWISH(1.0f, elems[elem_i]);
global_C[dim_n * elem_i] = elems[elem_i];
}
#else
for (int i = 0; i < elem_per_thread; i += asm_unrolled) {
// read in elements from GMEM to RF
asm volatile("mv t6, %0" ::"r"(global_C_curr));
@@ -299,6 +309,8 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float)));
if constexpr (true) {
// FIXME: this is likely incorrect; f0~f7 regs get overwritten by
// the compiler
register float x0 asm("f0");
register float x1 asm("f1");
register float x2 asm("f2");
@@ -437,6 +449,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float)));
asm volatile("mv %0, t6" :"=r"(global_C_curr));
}
#endif
}
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,