From 985c5fc0dcf868e1f4677ee02fd57d25ac121e40 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 18:50:31 -0700 Subject: [PATCH] sgemm_tcore: Remove uneffective register asm --- tests/regression/sgemm_tcore/kernel.cpp | 34 ++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6db7ae3d..42443de7 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -183,7 +183,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, } else { // transposed A // f8-f15 stores a single row of A - register volatile float *smem_addr asm("t5"); + volatile float *smem_addr; smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -220,7 +220,7 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B - register volatile float *smem_addr asm("t5"); + volatile float *smem_addr; smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -286,8 +286,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, // @perf: this likely causes a lot of gmem bank conflicts if (wm_iter == 0) { - register volatile float *gmem_addr asm("t5"); - register volatile float *gmem_addr_tmp asm("t6"); + volatile float *gmem_addr; + volatile float *gmem_addr_tmp; gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0))); asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1))); @@ -309,8 +309,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, // asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); // asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } else { - register volatile float *gmem_addr asm("t5"); - register volatile float *gmem_addr_tmp asm("t6"); + volatile float *gmem_addr; + volatile float *gmem_addr_tmp; gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; gmem_addr_tmp = gmem_addr + (2 * dim_n); asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0))); @@ -494,9 +494,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, static_assert( row_stride_b * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK % (row_stride_b * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); + static_assert( + (BK % (row_stride_b * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; @@ -618,11 +618,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // local_a_produce = (k_index % 2) ? local_a : local_a_buf; // local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a)) | - (mask_even & reinterpret_cast(local_a_buf))); + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); local_b_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b)) | - (mask_even & reinterpret_cast(local_b_buf))); + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); } else { local_a_produce = local_a; local_b_produce = local_b; @@ -666,11 +666,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t mask_odd = (k_index & 1) << 31 >> 31; const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; local_a_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a_buf)) | - (mask_even & reinterpret_cast(local_a))); + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); local_b_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b_buf)) | - (mask_even & reinterpret_cast(local_b))); + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); } else { local_a_consume = local_a; local_b_consume = local_b;