sgemm_tcore: Swap out mul with bitwise ops for addr ping-pong

This commit is contained in:
Hansung Kim
2024-06-05 19:01:59 -07:00
parent 65c653afde
commit a42fa6a113

View File

@@ -546,20 +546,22 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
if (warpgroup_id == 0) {
// TODO: bring initiation pipeline here
uint32_t k_index = 0;
int32_t k_index = 0;
#pragma GCC unroll 1
for (uint32_t k = 0; k < dim_k - BK; k += BK) {
volatile float *local_a_produce;
volatile float *local_b_produce;
if constexpr (DOUBLE_BUFFER) {
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
const uint32_t mask_odd = (k_index & 1) << 31 >> 31;
const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31;
// local_a_produce = (k_index % 2) ? local_a : local_a_buf;
// local_b_produce = (k_index % 2) ? local_b : local_b_buf;
local_a_produce = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a_buf));
(mask_odd & reinterpret_cast<uint32_t>(local_a)) |
(mask_even & reinterpret_cast<uint32_t>(local_a_buf)));
local_b_produce = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b_buf));
(mask_odd & reinterpret_cast<uint32_t>(local_b)) |
(mask_even & reinterpret_cast<uint32_t>(local_b_buf)));
} else {
local_a_produce = local_a;
local_b_produce = local_b;
@@ -575,7 +577,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
} else {
uint32_t k_index = 0;
int32_t k_index = 0;
#pragma GCC unroll 1
for (uint32_t k = 0; k < dim_k; k += BK) {
volatile float *local_a_consume;
@@ -584,12 +586,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
// FIXME: swap multiply with bitshifts
const uint32_t mask_odd = (k_index & 1) << 31 >> 31;
const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31;
local_a_consume = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a_buf) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a));
(mask_odd & reinterpret_cast<uint32_t>(local_a_buf)) |
(mask_even & reinterpret_cast<uint32_t>(local_a)));
local_b_consume = reinterpret_cast<volatile float *>(
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b_buf) +
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b));
(mask_odd & reinterpret_cast<uint32_t>(local_b_buf)) |
(mask_even & reinterpret_cast<uint32_t>(local_b)));
} else {
local_a_consume = local_a;
local_b_consume = local_b;
@@ -601,7 +605,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// vx_wmma_load
#pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) {
#pragma GCC unroll 10
#pragma GCC unroll 4
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
// perform wmma
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,