sgemm_tcore: Use arithmetic instead of branch for double-buffered addr

This commit is contained in:
Hansung Kim
2024-06-05 18:03:08 -07:00
parent c7a6ed03de
commit 65c653afde

View File

@@ -385,7 +385,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
row_stride_as * 8 <= BK, row_stride_as * 8 <= BK,
"manual loop unrolling condition not met; consider increasing BK"); "manual loop unrolling condition not met; consider increasing BK");
#pragma GCC ivdep #pragma GCC unroll 2
for (uint32_t local_row_offset = 0; local_row_offset < BK; for (uint32_t local_row_offset = 0; local_row_offset < BK;
local_row_offset += row_stride_as * 8) { local_row_offset += row_stride_as * 8) {
// @perf: bank conflicts here // @perf: bank conflicts here
@@ -436,7 +436,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
row_stride_b * 8 <= BK, row_stride_b * 8 <= BK,
"manual loop unrolling condition not met; consider increasing BK"); "manual loop unrolling condition not met; consider increasing BK");
#pragma GCC ivdep #pragma GCC unroll 2
for (uint32_t load_offset = 0; load_offset < BK; for (uint32_t load_offset = 0; load_offset < BK;
load_offset += row_stride_b * 8) { load_offset += row_stride_b * 8) {
// const uint32_t global_b_offset = // const uint32_t global_b_offset =
@@ -551,18 +551,18 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
for (uint32_t k = 0; k < dim_k - BK; k += BK) { for (uint32_t k = 0; k < dim_k - BK; k += BK) {
volatile float *local_a_produce; volatile float *local_a_produce;
volatile float *local_b_produce; volatile float *local_b_produce;
volatile float *local_a_consume;
volatile float *local_b_consume;
if constexpr (DOUBLE_BUFFER) { if constexpr (DOUBLE_BUFFER) {
local_a_produce = (k_index % 2) ? local_a : local_a_buf; local_a_produce = (k_index % 2) ? local_a : local_a_buf;
local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_b_produce = (k_index % 2) ? local_b : local_b_buf;
local_a_consume = (k_index % 2) ? local_a_buf : local_a; local_a_produce = reinterpret_cast<volatile float *>(
local_b_consume = (k_index % 2) ? local_b_buf : local_b; ((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a_buf));
local_b_produce = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b_buf));
} else { } else {
local_a_produce = local_a; local_a_produce = local_a;
local_b_produce = local_b; local_b_produce = local_b;
local_a_consume = local_a;
local_b_consume = local_b;
} }
k_index++; k_index++;
@@ -578,18 +578,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
uint32_t k_index = 0; uint32_t k_index = 0;
#pragma GCC unroll 1 #pragma GCC unroll 1
for (uint32_t k = 0; k < dim_k; k += BK) { for (uint32_t k = 0; k < dim_k; k += BK) {
volatile float *local_a_produce;
volatile float *local_b_produce;
volatile float *local_a_consume; volatile float *local_a_consume;
volatile float *local_b_consume; volatile float *local_b_consume;
if constexpr (DOUBLE_BUFFER) { if constexpr (DOUBLE_BUFFER) {
local_a_produce = (k_index % 2) ? local_a : local_a_buf; // local_a_consume = (k_index % 2) ? local_a_buf : local_a;
local_b_produce = (k_index % 2) ? local_b : local_b_buf; // local_b_consume = (k_index % 2) ? local_b_buf : local_b;
local_a_consume = (k_index % 2) ? local_a_buf : local_a; // FIXME: swap multiply with bitshifts
local_b_consume = (k_index % 2) ? local_b_buf : local_b; local_a_consume = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a_buf) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a));
local_b_consume = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b_buf) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b));
} else { } else {
local_a_produce = local_a;
local_b_produce = local_b;
local_a_consume = local_a; local_a_consume = local_a;
local_b_consume = local_b; local_b_consume = local_b;
} }
@@ -600,7 +601,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// vx_wmma_load // vx_wmma_load
#pragma GCC unroll 1 #pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) { for (int i = 0; i < BK_LOOP; i++) {
#pragma GCC unroll 1 #pragma GCC unroll 10
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
// perform wmma // perform wmma
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
@@ -612,7 +613,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter,
tid_in_warp); tid_in_warp);
// vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp);
#pragma GCC unroll 1 #pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#if TC_SINGLE_WARP #if TC_SINGLE_WARP
if (warp_in_warpgroup == 0) { if (warp_in_warpgroup == 0) {