sgemm_tg: Use reg mapping functions

This commit is contained in:
Hansung Kim
2024-05-12 22:22:54 -07:00
parent 8a521a1de8
commit 5c298c81df
2 changed files with 136 additions and 87 deletions

View File

@@ -0,0 +1 @@
sgemm_tcore

View File

@@ -10,18 +10,85 @@
#define BN 16 #define BN 16
#define BK 8 #define BK 8
inline void vx_wmma() { inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); const int tg = tid / 4;
// A (row major)
// Figure 7(a) in paper
// row 0~ 3: threadgroups 0 and 2
// row 4~ 7: threadgroups 4 and 6
// row 8~11: threadgroups 1 and 3
// row 12~15: threadgroups 5 and 7
row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
// B (column major)
// NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
// corrected mapping:
// col 0~ 3: threadgroups 0 and 1
// col 4~ 7: threadgroups 4 and 5
// col 8~11: threadgroups 2 and 3
// col 12~15: threadgroups 6 and 7
col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
} }
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) { inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
int tid = thread_in_warp; const int tg = tid / 4;
int tg = tid / 4;
// load A // A (row major)
int row = tid % 4; // row 0~ 3: threadgroup 0
row += (tg * 8) % 16; // row 4~ 7: threadgroup 1
row += (tg / 4) * 4; row = tid % 4;
row += tg * 4;
// B (column major)
// col 0~ 3: threadgroup 0
// col 4~ 7: threadgroup 1
col = tid % 4;
col += tg * 4;
}
// Computes the (row, col) position inside the C tile that corresponds to
// thread `tid`, for the 32-lane layout (eight threadgroups of four threads).
//
//   tid : thread index; tid / 4 selects the threadgroup, tid % 4 the lane
//         within that threadgroup.
//   row : output, overwritten with the row index.
//   col : output, overwritten with the column index.
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
    const int group = tid / 4;
    const int lane = tid % 4;

    // Threadgroup origin (Figure 7(b), left): groups whose (group % 4) is
    // 2 or 3 start at column 8; odd groups start 8 rows down, and groups
    // 4..7 are shifted by a further 4 rows.
    const int group_col = ((group % 4) / 2) * 8;
    const int group_row = (group * 8) % 16 + (group / 4) * 4;

    col = group_col;
    row = group_row;

    // Per-lane offset inside the threadgroup (Figure 7(b), right): odd lanes
    // move down one row, lanes 2 and 3 move right two columns.
    row += lane % 2;
    col += (lane / 2) * 2;
}
// Computes the (row, col) position inside the C tile that corresponds to
// thread `tid`, for the 8-lane layout.
//
//   tid : thread index; tid / 4 selects the threadgroup, tid % 4 the lane
//         within that threadgroup.
//   row : output, overwritten with the row index.
//   col : output, overwritten with the column index.
inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
    const int group = tid / 4;
    const int lane = tid % 4;

    // Every threadgroup starts at column 0; each one owns a band of 4 rows.
    col = 0;
    row = group * 4;

    // Per-lane offset inside the threadgroup (Figure 7(b), right): odd lanes
    // move down one row, lanes 2 and 3 move right two columns.
    row += lane % 2;
    col += (lane / 2) * 2;
}
// Emits one R-type instruction through the assembler's `.insn` directive:
// major opcode RISCV_CUSTOM3, funct3 = 0, funct7 = 0, with x0 used for rd,
// rs1 and rs2. The asm statement declares no outputs, inputs (other than the
// immediate opcode) or clobbers, so the compiler is told nothing about any
// registers the instruction may read or write.
// NOTE(review): presumably this is the matrix multiply-accumulate trigger that
// consumes the f-registers loaded elsewhere in this file — confirm against the
// hardware/ISA documentation.
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
int warp_y, int thread_in_warp) {
int tid = thread_in_warp;
int tg = tid / 4;
int row = 0;
int col = 0;
map_operand_32lanes(tid, row, col);
int smem_A_m = 32; int smem_A_m = 32;
int smem_A_n = 8; int smem_A_n = 8;
@@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in
int A_offset = (row + BM * warp_y) * smem_A_n; int A_offset = (row + BM * warp_y) * smem_A_n;
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0])); asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0]));
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1])); asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1]));
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2])); asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2]));
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3])); asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3]));
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4])); asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4]));
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5])); asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5]));
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6])); asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6]));
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7])); asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7]));
// load B asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
int col = tid % 4; asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
col += ((tg % 4) / 2) * 8; asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
col += (tg / 4) * 4; asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
} }
inline void initialize_C() { inline void initialize_C() {
// initialize C to zeros // initialize C to zeros
asm volatile ("fmv.w.x f16, x0"); asm volatile("fmv.w.x f16, x0");
asm volatile ("fmv.w.x f17, x0"); asm volatile("fmv.w.x f17, x0");
asm volatile ("fmv.w.x f18, x0"); asm volatile("fmv.w.x f18, x0");
asm volatile ("fmv.w.x f19, x0"); asm volatile("fmv.w.x f19, x0");
asm volatile ("fmv.w.x f20, x0"); asm volatile("fmv.w.x f20, x0");
asm volatile ("fmv.w.x f21, x0"); asm volatile("fmv.w.x f21, x0");
asm volatile ("fmv.w.x f22, x0"); asm volatile("fmv.w.x f22, x0");
asm volatile ("fmv.w.x f23, x0"); asm volatile("fmv.w.x f23, x0");
} }
inline void write_results( inline void write_results(volatile float *local_warp_results,
volatile float *local_warp_results, int thread_in_warp, int warp_x, int warp_y, int dim_m,
int thread_in_warp, int dim_n, float *C, int threadblock_id_x,
int warp_x, int threadblock_id_y) {
int warp_y,
int dim_m,
int dim_n,
float *C,
int threadblock_id_x,
int threadblock_id_y
) {
int tid = thread_in_warp; int tid = thread_in_warp;
int tg = tid / 4; int tg = tid / 4;
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0])); asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1])); asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2])); asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3])); asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4])); asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5])); asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6])); asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7])); asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
/* /*
col = ((threadgroup % 4) // 2) * 8 col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16 row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4 row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16] offset = offsets[register-16]
row += offset[0] row += offset[0]
col += offset[1] col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4] thread_offset = thread_offsets[thread % 4]
row += thread_offset[0] row += thread_offset[0]
col += thread_offset[1] col += thread_offset[1]
return (row, col) return (row, col)
*/ */
int local_col = ((tg % 4) / 2) * 8; int local_row = 0;
int local_row = (tg * 8) % 16; int local_col = 0;
local_row += (tg / 4) * 4; map_c_32lanes(tid, local_row, local_col);
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2}; float *global_offset_C = C +
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5}; (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n +
threadblock_id_x * BN * 2 + warp_x * BM;
// int thread_row_offsets[4] = {0, 1, 0, 1};
// int thread_col_offsets[4] = {0, 0, 2, 2};
int thread_row_offset = (tid % 4) % 2;
int thread_col_offset = ((tid % 4) / 2) * 2;
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
for (int i = 0; i < 8; i += 1) { for (int i = 0; i < 8; i += 1) {
int row_offset = ((i / 2) % 2) * 2; int row_offset = ((i / 2) % 2) * 2;
int col_offset = (i / 4) * 4 + i % 2; int col_offset = (i / 4) * 4 + i % 2;
int adjusted_local_row = local_row + thread_row_offset + row_offset; int adjusted_local_row = local_row + row_offset;
int adjusted_local_col = local_col + thread_col_offset + col_offset; int adjusted_local_col = local_col + col_offset;
float v = local_warp_results[tid*8+i]; float v = local_warp_results[tid * 8 + i];
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
} }
} }
@@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
const uint32_t local_b_row = warp_in_threadblock; const uint32_t local_b_row = warp_in_threadblock;
const uint32_t local_b_col = tid_in_warp; const uint32_t local_b_col = tid_in_warp;
volatile float *local_a = sharedmem_per_threadblock; volatile float *local_a = sharedmem_per_threadblock;
const size_t local_a_elems = (threadblock_dim_y * BK); const size_t local_a_elems = (threadblock_dim_y * BK);