Use SWISH in activate_block for tcore and gemmini
This commit is contained in:
@@ -253,7 +253,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
}
|
||||
}
|
||||
|
||||
inline void activate_block(const uint32_t dim_n, const float *const C,
|
||||
inline void activate_block(const uint32_t dim_n, float *const C,
|
||||
const uint32_t tile_i, const uint32_t tile_j,
|
||||
const uint32_t warp_row, const uint32_t warp_col,
|
||||
const uint32_t tid_in_threadblock) {
|
||||
@@ -267,7 +267,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
|
||||
const uint32_t row_in_warptile = 0;
|
||||
const uint32_t C_row = (tile_i * BM) + (warp_row * WM) + row_in_warptile;
|
||||
const uint32_t C_col = (tile_j * BN) + (warp_col * WN) + col_in_warptile;
|
||||
const float *const global_C = C + dim_n * C_row + C_col;
|
||||
float *const global_C = C + dim_n * C_row + C_col;
|
||||
const float *global_C_curr = global_C;
|
||||
|
||||
// ELEM_PER_THREAD macro does not take into account warp-specialization
|
||||
@@ -278,6 +278,16 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
|
||||
static_assert((elem_per_thread % asm_unrolled) == 0,
|
||||
"unmet manual unroll condition for elem_per_thread");
|
||||
|
||||
#if 1
|
||||
float elems[elem_per_thread];
|
||||
|
||||
#pragma GCC unroll asm_unrolled
|
||||
for (int elem_i = 0; elem_i < elem_per_thread; elem_i++) {
|
||||
elems[elem_i] = global_C[dim_n * elem_i];
|
||||
elems[elem_i] = SWISH(1.0f, elems[elem_i]);
|
||||
global_C[dim_n * elem_i] = elems[elem_i];
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < elem_per_thread; i += asm_unrolled) {
|
||||
// read in elements from GMEM to RF
|
||||
asm volatile("mv t6, %0" ::"r"(global_C_curr));
|
||||
@@ -299,6 +309,8 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
|
||||
asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float)));
|
||||
|
||||
if constexpr (true) {
|
||||
// FIXME: this is likely incorrect; f0~f7 regs get overwritten by
|
||||
// the compiler
|
||||
register float x0 asm("f0");
|
||||
register float x1 asm("f1");
|
||||
register float x2 asm("f2");
|
||||
@@ -437,6 +449,7 @@ inline void activate_block(const uint32_t dim_n, const float *const C,
|
||||
asm volatile("add t6, t6, %0" ::"r"(dim_n * sizeof(float)));
|
||||
asm volatile("mv %0, t6" :"=r"(global_C_curr));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
Reference in New Issue
Block a user