From a42fa6a1131b0288796b5c4ed52f97d2cfcbe372 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 19:01:59 -0700 Subject: [PATCH] sgemm_tcore: Swap out mul with bitwise ops for addr ping-pong --- tests/regression/sgemm_tcore/kernel.cpp | 30 ++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 8f73e0f4..fb06966a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -546,20 +546,22 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if (warpgroup_id == 0) { // TODO: bring initiation pipeline here - uint32_t k_index = 0; + int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + // local_a_produce = (k_index % 2) ? local_a : local_a_buf; + // local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_produce = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_a) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_a_buf)); + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); local_b_produce = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_b) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_b_buf)); + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); } else { local_a_produce = local_a; local_b_produce = local_b; @@ -575,7 +577,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { - uint32_t k_index = 0; + int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { volatile float *local_a_consume; @@ -584,12 +586,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // local_a_consume = (k_index % 2) ? local_a_buf : local_a; // local_b_consume = (k_index % 2) ? local_b_buf : local_b; // FIXME: swap multiply with bitshifts + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; local_a_consume = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_a_buf) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_a)); + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); local_b_consume = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_b_buf) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_b)); + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); } else { local_a_consume = local_a; local_b_consume = local_b; @@ -601,7 +605,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 10 +#pragma GCC unroll 4 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,