sgemm_tg: Use reg mapping functions
This commit is contained in:
1
tests/regression/sgemm_tcore/.gitignore
vendored
Normal file
1
tests/regression/sgemm_tcore/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
sgemm_tcore
|
||||
@@ -10,18 +10,85 @@
|
||||
#define BN 16
|
||||
#define BK 8
|
||||
|
||||
inline void vx_wmma() {
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
// A (row major)
|
||||
// Figure 7(a) in paper
|
||||
// row 0~ 3: threadgroups 0 and 2
|
||||
// row 4~ 7: threadgroups 4 and 6
|
||||
// row 8~11: threadgroups 1 and 3
|
||||
// row 12~15: threadgroups 5 and 7
|
||||
row = tid % 4;
|
||||
row += (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
// B (column major)
|
||||
// NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
|
||||
// corrected mapping:
|
||||
// col 0~ 3: threadgroups 0 and 1
|
||||
// col 4~ 7: threadgroups 4 and 5
|
||||
// col 8~11: threadgroups 2 and 3
|
||||
// col 12~15: threadgroups 6 and 7
|
||||
col = tid % 4;
|
||||
col += ((tg % 4) / 2) * 8;
|
||||
col += (tg / 4) * 4;
|
||||
}
|
||||
|
||||
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) {
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
// load A
|
||||
int row = tid % 4;
|
||||
row += (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
// A (row major)
|
||||
// row 0~ 3: threadgroup 0
|
||||
// row 4~ 7: threadgroup 1
|
||||
row = tid % 4;
|
||||
row += tg * 4;
|
||||
|
||||
// B (column major)
|
||||
// col 0~ 3: threadgroup 0
|
||||
// col 4~ 7: threadgroup 1
|
||||
col = tid % 4;
|
||||
col += tg * 4;
|
||||
}
|
||||
|
||||
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
// C
|
||||
// Figure 7(b), left
|
||||
col = ((tg % 4) / 2) * 8;
|
||||
row = (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
// Figure 7(b), right
|
||||
row += (tid % 4) % 2;
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
}
|
||||
|
||||
inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
// C
|
||||
col = 0;
|
||||
row = tg * 4;
|
||||
|
||||
// Figure 7(b), right
|
||||
row += (tid % 4) % 2;
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
}
|
||||
|
||||
inline void vx_wmma() {
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
|
||||
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
|
||||
int warp_y, int thread_in_warp) {
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
|
||||
int row = 0;
|
||||
int col = 0;
|
||||
map_operand_32lanes(tid, row, col);
|
||||
|
||||
int smem_A_m = 32;
|
||||
int smem_A_n = 8;
|
||||
@@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in
|
||||
|
||||
int A_offset = (row + BM * warp_y) * smem_A_n;
|
||||
|
||||
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0]));
|
||||
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1]));
|
||||
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2]));
|
||||
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3]));
|
||||
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4]));
|
||||
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5]));
|
||||
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6]));
|
||||
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7]));
|
||||
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0]));
|
||||
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1]));
|
||||
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2]));
|
||||
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3]));
|
||||
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4]));
|
||||
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5]));
|
||||
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6]));
|
||||
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7]));
|
||||
|
||||
// load B
|
||||
int col = tid % 4;
|
||||
col += ((tg % 4) / 2) * 8;
|
||||
col += (tg / 4) * 4;
|
||||
|
||||
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
|
||||
}
|
||||
|
||||
inline void initialize_C() {
|
||||
// initialize C to zeros
|
||||
asm volatile ("fmv.w.x f16, x0");
|
||||
asm volatile ("fmv.w.x f17, x0");
|
||||
asm volatile ("fmv.w.x f18, x0");
|
||||
asm volatile ("fmv.w.x f19, x0");
|
||||
asm volatile ("fmv.w.x f20, x0");
|
||||
asm volatile ("fmv.w.x f21, x0");
|
||||
asm volatile ("fmv.w.x f22, x0");
|
||||
asm volatile ("fmv.w.x f23, x0");
|
||||
asm volatile("fmv.w.x f16, x0");
|
||||
asm volatile("fmv.w.x f17, x0");
|
||||
asm volatile("fmv.w.x f18, x0");
|
||||
asm volatile("fmv.w.x f19, x0");
|
||||
asm volatile("fmv.w.x f20, x0");
|
||||
asm volatile("fmv.w.x f21, x0");
|
||||
asm volatile("fmv.w.x f22, x0");
|
||||
asm volatile("fmv.w.x f23, x0");
|
||||
}
|
||||
|
||||
inline void write_results(
|
||||
volatile float *local_warp_results,
|
||||
int thread_in_warp,
|
||||
int warp_x,
|
||||
int warp_y,
|
||||
int dim_m,
|
||||
int dim_n,
|
||||
float *C,
|
||||
int threadblock_id_x,
|
||||
int threadblock_id_y
|
||||
) {
|
||||
inline void write_results(volatile float *local_warp_results,
|
||||
int thread_in_warp, int warp_x, int warp_y, int dim_m,
|
||||
int dim_n, float *C, int threadblock_id_x,
|
||||
int threadblock_id_y) {
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
int tg = tid / 4;
|
||||
|
||||
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0]));
|
||||
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1]));
|
||||
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2]));
|
||||
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3]));
|
||||
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4]));
|
||||
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5]));
|
||||
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6]));
|
||||
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7]));
|
||||
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
|
||||
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
|
||||
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
|
||||
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
|
||||
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
|
||||
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
|
||||
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
|
||||
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
|
||||
|
||||
/*
|
||||
col = ((threadgroup % 4) // 2) * 8
|
||||
row = (threadgroup * 8) % 16
|
||||
row += (threadgroup // 4) * 4
|
||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||
offset = offsets[register-16]
|
||||
row += offset[0]
|
||||
col += offset[1]
|
||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||
thread_offset = thread_offsets[thread % 4]
|
||||
row += thread_offset[0]
|
||||
col += thread_offset[1]
|
||||
return (row, col)
|
||||
*/
|
||||
col = ((threadgroup % 4) // 2) * 8
|
||||
row = (threadgroup * 8) % 16
|
||||
row += (threadgroup // 4) * 4
|
||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||
offset = offsets[register-16]
|
||||
row += offset[0]
|
||||
col += offset[1]
|
||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||
thread_offset = thread_offsets[thread % 4]
|
||||
row += thread_offset[0]
|
||||
col += thread_offset[1]
|
||||
return (row, col)
|
||||
*/
|
||||
|
||||
int local_col = ((tg % 4) / 2) * 8;
|
||||
int local_row = (tg * 8) % 16;
|
||||
local_row += (tg / 4) * 4;
|
||||
int local_row = 0;
|
||||
int local_col = 0;
|
||||
map_c_32lanes(tid, local_row, local_col);
|
||||
|
||||
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2};
|
||||
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5};
|
||||
|
||||
// int thread_row_offsets[4] = {0, 1, 0, 1};
|
||||
// int thread_col_offsets[4] = {0, 0, 2, 2};
|
||||
int thread_row_offset = (tid % 4) % 2;
|
||||
int thread_col_offset = ((tid % 4) / 2) * 2;
|
||||
|
||||
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
|
||||
float *global_offset_C = C +
|
||||
(threadblock_id_y * BM * 2 + warp_y * BM) * dim_n +
|
||||
threadblock_id_x * BN * 2 + warp_x * BM;
|
||||
for (int i = 0; i < 8; i += 1) {
|
||||
int row_offset = ((i / 2) % 2) * 2;
|
||||
int col_offset = (i / 4) * 4 + i % 2;
|
||||
|
||||
int adjusted_local_row = local_row + thread_row_offset + row_offset;
|
||||
int adjusted_local_col = local_col + thread_col_offset + col_offset;
|
||||
int adjusted_local_row = local_row + row_offset;
|
||||
int adjusted_local_col = local_col + col_offset;
|
||||
|
||||
float v = local_warp_results[tid*8+i];
|
||||
float v = local_warp_results[tid * 8 + i];
|
||||
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
||||
}
|
||||
}
|
||||
@@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
|
||||
const uint32_t local_b_row = warp_in_threadblock;
|
||||
const uint32_t local_b_col = tid_in_warp;
|
||||
|
||||
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
const size_t local_a_elems = (threadblock_dim_y * BK);
|
||||
|
||||
Reference in New Issue
Block a user