sgemm_tcore: Use arithmetic instead of branch for double-buffered addr
This commit is contained in:
@@ -385,7 +385,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
|||||||
row_stride_as * 8 <= BK,
|
row_stride_as * 8 <= BK,
|
||||||
"manual loop unrolling condition not met; consider increasing BK");
|
"manual loop unrolling condition not met; consider increasing BK");
|
||||||
|
|
||||||
#pragma GCC ivdep
|
#pragma GCC unroll 2
|
||||||
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
||||||
local_row_offset += row_stride_as * 8) {
|
local_row_offset += row_stride_as * 8) {
|
||||||
// @perf: bank conflicts here
|
// @perf: bank conflicts here
|
||||||
@@ -436,7 +436,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
|||||||
row_stride_b * 8 <= BK,
|
row_stride_b * 8 <= BK,
|
||||||
"manual loop unrolling condition not met; consider increasing BK");
|
"manual loop unrolling condition not met; consider increasing BK");
|
||||||
|
|
||||||
#pragma GCC ivdep
|
#pragma GCC unroll 2
|
||||||
for (uint32_t load_offset = 0; load_offset < BK;
|
for (uint32_t load_offset = 0; load_offset < BK;
|
||||||
load_offset += row_stride_b * 8) {
|
load_offset += row_stride_b * 8) {
|
||||||
// const uint32_t global_b_offset =
|
// const uint32_t global_b_offset =
|
||||||
@@ -551,18 +551,18 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
for (uint32_t k = 0; k < dim_k - BK; k += BK) {
|
for (uint32_t k = 0; k < dim_k - BK; k += BK) {
|
||||||
volatile float *local_a_produce;
|
volatile float *local_a_produce;
|
||||||
volatile float *local_b_produce;
|
volatile float *local_b_produce;
|
||||||
volatile float *local_a_consume;
|
|
||||||
volatile float *local_b_consume;
|
|
||||||
if constexpr (DOUBLE_BUFFER) {
|
if constexpr (DOUBLE_BUFFER) {
|
||||||
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||||
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||||
local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
local_a_produce = reinterpret_cast<volatile float *>(
|
||||||
local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a) +
|
||||||
|
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a_buf));
|
||||||
|
local_b_produce = reinterpret_cast<volatile float *>(
|
||||||
|
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b) +
|
||||||
|
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b_buf));
|
||||||
} else {
|
} else {
|
||||||
local_a_produce = local_a;
|
local_a_produce = local_a;
|
||||||
local_b_produce = local_b;
|
local_b_produce = local_b;
|
||||||
local_a_consume = local_a;
|
|
||||||
local_b_consume = local_b;
|
|
||||||
}
|
}
|
||||||
k_index++;
|
k_index++;
|
||||||
|
|
||||||
@@ -578,18 +578,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
uint32_t k_index = 0;
|
uint32_t k_index = 0;
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||||
volatile float *local_a_produce;
|
|
||||||
volatile float *local_b_produce;
|
|
||||||
volatile float *local_a_consume;
|
volatile float *local_a_consume;
|
||||||
volatile float *local_b_consume;
|
volatile float *local_b_consume;
|
||||||
if constexpr (DOUBLE_BUFFER) {
|
if constexpr (DOUBLE_BUFFER) {
|
||||||
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||||
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||||
local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
// FIXME: swap multiply with bitshifts
|
||||||
local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
local_a_consume = reinterpret_cast<volatile float *>(
|
||||||
|
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a_buf) +
|
||||||
|
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a));
|
||||||
|
local_b_consume = reinterpret_cast<volatile float *>(
|
||||||
|
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b_buf) +
|
||||||
|
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b));
|
||||||
} else {
|
} else {
|
||||||
local_a_produce = local_a;
|
|
||||||
local_b_produce = local_b;
|
|
||||||
local_a_consume = local_a;
|
local_a_consume = local_a;
|
||||||
local_b_consume = local_b;
|
local_b_consume = local_b;
|
||||||
}
|
}
|
||||||
@@ -600,7 +601,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
// vx_wmma_load
|
// vx_wmma_load
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (int i = 0; i < BK_LOOP; i++) {
|
for (int i = 0; i < BK_LOOP; i++) {
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 10
|
||||||
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
||||||
// perform wmma
|
// perform wmma
|
||||||
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
|
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
|
||||||
@@ -612,7 +613,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter,
|
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter,
|
||||||
tid_in_warp);
|
tid_in_warp);
|
||||||
// vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp);
|
// vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp);
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 2
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#if TC_SINGLE_WARP
|
#if TC_SINGLE_WARP
|
||||||
if (warp_in_warpgroup == 0) {
|
if (warp_in_warpgroup == 0) {
|
||||||
|
|||||||
Reference in New Issue
Block a user