From 5c298c81df683ad891a09433ab2a3b4cc89be448 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 12 May 2024 22:22:54 -0700 Subject: [PATCH] sgemm_tg: Use reg mapping functions --- /* NOTE(review): Summary - this patch adds tests/regression/sgemm_tcore/.gitignore and, in kernel.cpp, introduces four helpers (map_operand_32lanes, map_operand_8lanes, map_c_32lanes, map_c_8lanes) that each write a (row, col) pair computed from tid. vx_wmma_load now calls map_operand_32lanes and write_results now calls map_c_32lanes in place of the previously inlined arithmetic. map_c_32lanes already folds in the per-thread (tid % 4) offsets, which is why the write_results loop now adds only row_offset and col_offset. The remaining removed/added pairs on the asm statements appear to be whitespace-only reformatting. Review points: (1) int tg = tid / 4 is still declared in vx_wmma_load and write_results although the added code shown here no longer reads it - looks unused, confirm against the full file and drop it. (2) global_offset_C in write_results uses warp_x * BM for the column term while the neighbouring term uses BN and vx_wmma_load indexes smem_B with warp_x * BN; this only matters if BM differs from BN (the BM definition is outside this patch) - presumably BN was intended, TODO confirm. (3) map_operand_8lanes and map_c_8lanes have no callers in the hunks shown - confirm they are used elsewhere or are planned. (4) This copy of the patch has its line breaks collapsed, so presumably it will not apply as-is; regenerate it with git format-patch rather than hand-editing, because hunk counts and context-line whitespace cannot be recovered from this text. */ tests/regression/sgemm_tcore/.gitignore | 1 + tests/regression/sgemm_tcore/kernel.cpp | 222 ++++++++++++++---------- 2 files changed, 136 insertions(+), 87 deletions(-) create mode 100644 tests/regression/sgemm_tcore/.gitignore diff --git a/tests/regression/sgemm_tcore/.gitignore b/tests/regression/sgemm_tcore/.gitignore new file mode 100644 index 00000000..6ef379cc --- /dev/null +++ b/tests/regression/sgemm_tcore/.gitignore @@ -0,0 +1 @@ +sgemm_tcore diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 11a795df..f498f57b 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -10,18 +10,85 @@ #define BN 16 #define BK 8 -inline void vx_wmma() { - asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; } -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) { - int tid = thread_in_warp; - int tg = tid / 4; +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; - // load A - int row = tid % 4; - 
row += (tg * 8) % 16; - row += (tg / 4) * 4; + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, + int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand_32lanes(tid, row, col); int smem_A_m = 32; int smem_A_n = 8; @@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in int A_offset = (row + BM * warp_y) * smem_A_n; - asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0])); - asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1])); - asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2])); - asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3])); - asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4])); - asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5])); - asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6])); - asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7])); + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); + asm volatile("flw f3, %0" 
::"m"(smem_A[A_offset + 3])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); - // load B - int col = tid % 4; - col += ((tg % 4) / 2) * 8; - col += (tg / 4) * 4; - - asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); } inline void initialize_C() { // initialize C to zeros - asm volatile ("fmv.w.x f16, x0"); - asm volatile ("fmv.w.x f17, x0"); - asm volatile ("fmv.w.x f18, x0"); - asm volatile ("fmv.w.x f19, x0"); - asm volatile ("fmv.w.x f20, x0"); - asm volatile ("fmv.w.x f21, x0"); - asm volatile ("fmv.w.x f22, x0"); - asm volatile ("fmv.w.x f23, 
x0"); + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); } -inline void write_results( - volatile float *local_warp_results, - int thread_in_warp, - int warp_x, - int warp_y, - int dim_m, - int dim_n, - float *C, - int threadblock_id_x, - int threadblock_id_y -) { +inline void write_results(volatile float *local_warp_results, + int thread_in_warp, int warp_x, int warp_y, int dim_m, + int dim_n, float *C, int threadblock_id_x, + int threadblock_id_y) { int tid = thread_in_warp; - int tg = tid / 4; + int tg = tid / 4; - asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0])); - asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1])); - asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2])); - asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3])); - asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4])); - asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5])); - asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6])); - asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7])); + asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); + asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); + asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); + asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); + asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); + asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); + asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); + asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); /* - col = ((threadgroup % 4) // 2) * 8 - row = (threadgroup * 8) % 16 - row += (threadgroup // 4) * 4 - offsets = [(0, 0), (0, 1), (2, 0), (2, 
1), (0, 4), (0, 5), (2, 4), (2, 5)] - offset = offsets[register-16] - row += offset[0] - col += offset[1] - thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] - thread_offset = thread_offsets[thread % 4] - row += thread_offset[0] - col += thread_offset[1] - return (row, col) - */ + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ - int local_col = ((tg % 4) / 2) * 8; - int local_row = (tg * 8) % 16; - local_row += (tg / 4) * 4; + int local_row = 0; + int local_col = 0; + map_c_32lanes(tid, local_row, local_col); - // int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2}; - // int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5}; - - // int thread_row_offsets[4] = {0, 1, 0, 1}; - // int thread_col_offsets[4] = {0, 0, 2, 2}; - int thread_row_offset = (tid % 4) % 2; - int thread_col_offset = ((tid % 4) / 2) * 2; - - float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM; + float *global_offset_C = C + + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + + threadblock_id_x * BN * 2 + warp_x * BM; for (int i = 0; i < 8; i += 1) { int row_offset = ((i / 2) % 2) * 2; int col_offset = (i / 4) * 4 + i % 2; - int adjusted_local_row = local_row + thread_row_offset + row_offset; - int adjusted_local_col = local_col + thread_col_offset + col_offset; + int adjusted_local_row = local_row + row_offset; + int adjusted_local_col = local_col + col_offset; - float v = local_warp_results[tid*8+i]; + float v = local_warp_results[tid * 8 + i]; global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; } } @@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t 
*__UNIFORM__ arg, const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; const uint32_t local_b_row = warp_in_threadblock; const uint32_t local_b_col = tid_in_warp; - volatile float *local_a = sharedmem_per_threadblock; const size_t local_a_elems = (threadblock_dim_y * BK);