sgemm_tcore: Swap out mul with bitwise ops for addr ping-pong
This commit is contained in:
@@ -546,20 +546,22 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
if (warpgroup_id == 0) {
|
||||
// TODO: bring initiation pipeline here
|
||||
uint32_t k_index = 0;
|
||||
int32_t k_index = 0;
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t k = 0; k < dim_k - BK; k += BK) {
|
||||
volatile float *local_a_produce;
|
||||
volatile float *local_b_produce;
|
||||
if constexpr (DOUBLE_BUFFER) {
|
||||
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||
const uint32_t mask_odd = (k_index & 1) << 31 >> 31;
|
||||
const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31;
|
||||
// local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||
// local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||
local_a_produce = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a_buf));
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_a)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_a_buf)));
|
||||
local_b_produce = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b_buf));
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_b)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_b_buf)));
|
||||
} else {
|
||||
local_a_produce = local_a;
|
||||
local_b_produce = local_b;
|
||||
@@ -575,7 +577,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
} else {
|
||||
uint32_t k_index = 0;
|
||||
int32_t k_index = 0;
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
volatile float *local_a_consume;
|
||||
@@ -584,12 +586,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||
// FIXME: swap multiply with bitshifts
|
||||
const uint32_t mask_odd = (k_index & 1) << 31 >> 31;
|
||||
const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31;
|
||||
local_a_consume = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a_buf) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a));
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_a_buf)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_a)));
|
||||
local_b_consume = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b_buf) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b));
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_b_buf)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_b)));
|
||||
} else {
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
@@ -601,7 +605,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// vx_wmma_load
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 10
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
||||
// perform wmma
|
||||
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
|
||||
|
||||
Reference in New Issue
Block a user