sgemm_tcore: Remove uneffective register asm
This commit is contained in:
@@ -183,7 +183,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||
} else {
|
||||
// transposed A
|
||||
// f8-f15 stores a single row of A
|
||||
register volatile float *smem_addr asm("t5");
|
||||
volatile float *smem_addr;
|
||||
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -220,7 +220,7 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k,
|
||||
constexpr int smem_B_cols = BN;
|
||||
|
||||
// f8-f15 stores a single column of B
|
||||
register volatile float *smem_addr asm("t5");
|
||||
volatile float *smem_addr;
|
||||
smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col];
|
||||
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -286,8 +286,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
||||
|
||||
// @perf: this likely causes a lot of gmem bank conflicts
|
||||
if (wm_iter == 0) {
|
||||
register volatile float *gmem_addr asm("t5");
|
||||
register volatile float *gmem_addr_tmp asm("t6");
|
||||
volatile float *gmem_addr;
|
||||
volatile float *gmem_addr_tmp;
|
||||
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
||||
asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0)));
|
||||
asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1)));
|
||||
@@ -309,8 +309,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
||||
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||
} else {
|
||||
register volatile float *gmem_addr asm("t5");
|
||||
register volatile float *gmem_addr_tmp asm("t6");
|
||||
volatile float *gmem_addr;
|
||||
volatile float *gmem_addr_tmp;
|
||||
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
||||
gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
||||
asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0)));
|
||||
@@ -494,9 +494,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
static_assert(
|
||||
row_stride_b * 8 <= BK,
|
||||
"manual loop unrolling condition not met; consider increasing BK");
|
||||
static_assert(
|
||||
(BK % (row_stride_b * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||
static_assert(
|
||||
(BK % (row_stride_b * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t load_offset = 0; load_offset < BK;
|
||||
@@ -618,11 +618,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||
// local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||
local_a_produce = reinterpret_cast<volatile float *>(
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_a)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_a_buf)));
|
||||
(mask_odd & reinterpret_cast<uintmax_t>(local_a)) |
|
||||
(mask_even & reinterpret_cast<uintmax_t>(local_a_buf)));
|
||||
local_b_produce = reinterpret_cast<volatile float *>(
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_b)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_b_buf)));
|
||||
(mask_odd & reinterpret_cast<uintmax_t>(local_b)) |
|
||||
(mask_even & reinterpret_cast<uintmax_t>(local_b_buf)));
|
||||
} else {
|
||||
local_a_produce = local_a;
|
||||
local_b_produce = local_b;
|
||||
@@ -666,11 +666,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t mask_odd = (k_index & 1) << 31 >> 31;
|
||||
const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31;
|
||||
local_a_consume = reinterpret_cast<volatile float *>(
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_a_buf)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_a)));
|
||||
(mask_odd & reinterpret_cast<uintmax_t>(local_a_buf)) |
|
||||
(mask_even & reinterpret_cast<uintmax_t>(local_a)));
|
||||
local_b_consume = reinterpret_cast<volatile float *>(
|
||||
(mask_odd & reinterpret_cast<uint32_t>(local_b_buf)) |
|
||||
(mask_even & reinterpret_cast<uint32_t>(local_b)));
|
||||
(mask_odd & reinterpret_cast<uintmax_t>(local_b_buf)) |
|
||||
(mask_even & reinterpret_cast<uintmax_t>(local_b)));
|
||||
} else {
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
|
||||
Reference in New Issue
Block a user