sgemm_tg: Use reg mapping functions

This commit is contained in:
Hansung Kim
2024-05-12 22:22:54 -07:00
parent 8a521a1de8
commit 5c298c81df
2 changed files with 136 additions and 87 deletions

View File

@@ -0,0 +1 @@
sgemm_tcore

View File

@@ -10,18 +10,85 @@
#define BN 16
#define BK 8
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// A (row major)
// Figure 7(a) in paper
// row 0~ 3: threadgroups 0 and 2
// row 4~ 7: threadgroups 4 and 6
// row 8~11: threadgroups 1 and 3
// row 12~15: threadgroups 5 and 7
row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
// B (column major)
// NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
// corrected mapping:
// col 0~ 3: threadgroups 0 and 1
// col 4~ 7: threadgroups 4 and 5
// col 8~11: threadgroups 2 and 3
// col 12~15: threadgroups 6 and 7
col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
}
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) {
int tid = thread_in_warp;
int tg = tid / 4;
inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// load A
int row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
// A (row major)
// row 0~ 3: threadgroup 0
// row 4~ 7: threadgroup 1
row = tid % 4;
row += tg * 4;
// B (column major)
// col 0~ 3: threadgroup 0
// col 4~ 7: threadgroup 1
col = tid % 4;
col += tg * 4;
}
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// C
// Figure 7(b), left
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
// Figure 7(b), right
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
}
inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// C
col = 0;
row = tg * 4;
// Figure 7(b), right
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
}
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
int warp_y, int thread_in_warp) {
int tid = thread_in_warp;
int tg = tid / 4;
int row = 0;
int col = 0;
map_operand_32lanes(tid, row, col);
int smem_A_m = 32;
int smem_A_n = 8;
@@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in
int A_offset = (row + BM * warp_y) * smem_A_n;
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0]));
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1]));
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2]));
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3]));
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4]));
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5]));
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6]));
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7]));
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0]));
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1]));
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2]));
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3]));
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4]));
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5]));
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6]));
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7]));
// load B
int col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
}
inline void initialize_C() {
// initialize C to zeros
asm volatile ("fmv.w.x f16, x0");
asm volatile ("fmv.w.x f17, x0");
asm volatile ("fmv.w.x f18, x0");
asm volatile ("fmv.w.x f19, x0");
asm volatile ("fmv.w.x f20, x0");
asm volatile ("fmv.w.x f21, x0");
asm volatile ("fmv.w.x f22, x0");
asm volatile ("fmv.w.x f23, x0");
asm volatile("fmv.w.x f16, x0");
asm volatile("fmv.w.x f17, x0");
asm volatile("fmv.w.x f18, x0");
asm volatile("fmv.w.x f19, x0");
asm volatile("fmv.w.x f20, x0");
asm volatile("fmv.w.x f21, x0");
asm volatile("fmv.w.x f22, x0");
asm volatile("fmv.w.x f23, x0");
}
inline void write_results(
volatile float *local_warp_results,
int thread_in_warp,
int warp_x,
int warp_y,
int dim_m,
int dim_n,
float *C,
int threadblock_id_x,
int threadblock_id_y
) {
inline void write_results(volatile float *local_warp_results,
int thread_in_warp, int warp_x, int warp_y, int dim_m,
int dim_n, float *C, int threadblock_id_x,
int threadblock_id_y) {
int tid = thread_in_warp;
int tg = tid / 4;
int tg = tid / 4;
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0]));
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1]));
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2]));
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3]));
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4]));
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5]));
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6]));
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7]));
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
/*
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
return (row, col)
*/
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
return (row, col)
*/
int local_col = ((tg % 4) / 2) * 8;
int local_row = (tg * 8) % 16;
local_row += (tg / 4) * 4;
int local_row = 0;
int local_col = 0;
map_c_32lanes(tid, local_row, local_col);
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2};
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5};
// int thread_row_offsets[4] = {0, 1, 0, 1};
// int thread_col_offsets[4] = {0, 0, 2, 2};
int thread_row_offset = (tid % 4) % 2;
int thread_col_offset = ((tid % 4) / 2) * 2;
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
float *global_offset_C = C +
(threadblock_id_y * BM * 2 + warp_y * BM) * dim_n +
threadblock_id_x * BN * 2 + warp_x * BM;
for (int i = 0; i < 8; i += 1) {
int row_offset = ((i / 2) % 2) * 2;
int col_offset = (i / 4) * 4 + i % 2;
int adjusted_local_row = local_row + thread_row_offset + row_offset;
int adjusted_local_col = local_col + thread_col_offset + col_offset;
int adjusted_local_row = local_row + row_offset;
int adjusted_local_col = local_col + col_offset;
float v = local_warp_results[tid*8+i];
float v = local_warp_results[tid * 8 + i];
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
}
}
@@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
const uint32_t local_b_row = warp_in_threadblock;
const uint32_t local_b_col = tid_in_warp;
volatile float *local_a = sharedmem_per_threadblock;
const size_t local_a_elems = (threadblock_dim_y * BK);