sgemm_wg: Cleanup & proper unroll

This commit is contained in:
Hansung Kim
2024-02-28 21:17:42 -08:00
parent 46f242e520
commit a06b2dd20e

View File

@@ -40,30 +40,30 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
// each thread generates one output element
// each thread generates TM output elements
float reg_c[TM] = { 0.0f };
for (uint32_t k = 0; k < dim_k; k += BK) {
float *local_a = sharedmem_per_threadblock;
size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
float *local_b = sharedmem_per_threadblock + local_a_elems;
volatile float *local_a = sharedmem_per_threadblock;
const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
for (uint32_t k = 0; k < dim_k; k += BK) {
uint32_t global_a_offset = dim_k * global_a_row + (k + local_a_col);
uint32_t global_b_offset = dim_n * (k + local_b_row) + global_b_col;
// NOTE: local_b is transposed to column-major to facilitate better memory
// access.
local_a[BK * local_a_row + local_a_col] = A[global_a_offset];
local_b[BN * local_b_row + local_b_col] = B[global_b_offset];
vx_barrier(threadblock_id_in_core, threadblock_dim_y);
vx_fence();
#pragma GCC unroll TM
for (uint32_t local_k = 0; local_k < BK; local_k++) {
// Compute multiple result elements (TM) per thread
const float local_b_tmp = local_b[BN * local_k + local_b_col];
#pragma GCC unroll 4
#pragma GCC unroll TM
for (uint32_t result_idx = 0; result_idx < TM; result_idx++) {
// NOTE use of local_b_row
reg_c[result_idx] +=
local_a[BK * (TM * local_b_row + result_idx) + local_k] *
local_b_tmp;
@@ -74,8 +74,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
vx_fence();
}
#pragma GCC unroll 4
#pragma GCC unroll TM
for (uint32_t result_idx = 0; result_idx < TM; result_idx++) {
// NOTE use of local_b_row and global_b_col here
C[dim_n * (BM * threadblock_id_y + TM * local_b_row + result_idx) +
global_b_col] = reg_c[result_idx];
}