sgemm_tg: Use reg mapping functions
This commit is contained in:
1
tests/regression/sgemm_tcore/.gitignore
vendored
Normal file
1
tests/regression/sgemm_tcore/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
sgemm_tcore
|
||||||
@@ -10,18 +10,85 @@
|
|||||||
#define BN 16
|
#define BN 16
|
||||||
#define BK 8
|
#define BK 8
|
||||||
|
|
||||||
inline void vx_wmma() {
|
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// A (row major)
|
||||||
|
// Figure 7(a) in paper
|
||||||
|
// row 0~ 3: threadgroups 0 and 2
|
||||||
|
// row 4~ 7: threadgroups 4 and 6
|
||||||
|
// row 8~11: threadgroups 1 and 3
|
||||||
|
// row 12~15: threadgroups 5 and 7
|
||||||
|
row = tid % 4;
|
||||||
|
row += (tg * 8) % 16;
|
||||||
|
row += (tg / 4) * 4;
|
||||||
|
|
||||||
|
// B (column major)
|
||||||
|
// NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
|
||||||
|
// corrected mapping:
|
||||||
|
// col 0~ 3: threadgroups 0 and 1
|
||||||
|
// col 4~ 7: threadgroups 4 and 5
|
||||||
|
// col 8~11: threadgroups 2 and 3
|
||||||
|
// col 12~15: threadgroups 6 and 7
|
||||||
|
col = tid % 4;
|
||||||
|
col += ((tg % 4) / 2) * 8;
|
||||||
|
col += (tg / 4) * 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) {
|
inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
|
||||||
int tid = thread_in_warp;
|
const int tg = tid / 4;
|
||||||
int tg = tid / 4;
|
|
||||||
|
|
||||||
// load A
|
// A (row major)
|
||||||
int row = tid % 4;
|
// row 0~ 3: threadgroup 0
|
||||||
row += (tg * 8) % 16;
|
// row 4~ 7: threadgroup 1
|
||||||
row += (tg / 4) * 4;
|
row = tid % 4;
|
||||||
|
row += tg * 4;
|
||||||
|
|
||||||
|
// B (column major)
|
||||||
|
// col 0~ 3: threadgroup 0
|
||||||
|
// col 4~ 7: threadgroup 1
|
||||||
|
col = tid % 4;
|
||||||
|
col += tg * 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
|
||||||
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// C
|
||||||
|
// Figure 7(b), left
|
||||||
|
col = ((tg % 4) / 2) * 8;
|
||||||
|
row = (tg * 8) % 16;
|
||||||
|
row += (tg / 4) * 4;
|
||||||
|
|
||||||
|
// Figure 7(b), right
|
||||||
|
row += (tid % 4) % 2;
|
||||||
|
col += ((tid % 4) / 2) * 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
||||||
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// C
|
||||||
|
col = 0;
|
||||||
|
row = tg * 4;
|
||||||
|
|
||||||
|
// Figure 7(b), right
|
||||||
|
row += (tid % 4) % 2;
|
||||||
|
col += ((tid % 4) / 2) * 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void vx_wmma() {
|
||||||
|
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
|
|
||||||
|
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
|
||||||
|
int warp_y, int thread_in_warp) {
|
||||||
|
int tid = thread_in_warp;
|
||||||
|
int tg = tid / 4;
|
||||||
|
|
||||||
|
int row = 0;
|
||||||
|
int col = 0;
|
||||||
|
map_operand_32lanes(tid, row, col);
|
||||||
|
|
||||||
int smem_A_m = 32;
|
int smem_A_m = 32;
|
||||||
int smem_A_n = 8;
|
int smem_A_n = 8;
|
||||||
@@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in
|
|||||||
|
|
||||||
int A_offset = (row + BM * warp_y) * smem_A_n;
|
int A_offset = (row + BM * warp_y) * smem_A_n;
|
||||||
|
|
||||||
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0]));
|
asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0]));
|
||||||
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1]));
|
asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1]));
|
||||||
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2]));
|
asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2]));
|
||||||
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3]));
|
asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3]));
|
||||||
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4]));
|
asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4]));
|
||||||
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5]));
|
asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5]));
|
||||||
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6]));
|
asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6]));
|
||||||
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7]));
|
asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7]));
|
||||||
|
|
||||||
// load B
|
asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
|
||||||
int col = tid % 4;
|
asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
|
||||||
col += ((tg % 4) / 2) * 8;
|
asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
|
||||||
col += (tg / 4) * 4;
|
asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
|
||||||
|
asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
|
||||||
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
|
asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
|
||||||
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
|
asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
|
||||||
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
|
asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
|
||||||
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
|
|
||||||
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
|
|
||||||
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
|
|
||||||
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
|
|
||||||
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void initialize_C() {
|
inline void initialize_C() {
|
||||||
// initialize C to zeros
|
// initialize C to zeros
|
||||||
asm volatile ("fmv.w.x f16, x0");
|
asm volatile("fmv.w.x f16, x0");
|
||||||
asm volatile ("fmv.w.x f17, x0");
|
asm volatile("fmv.w.x f17, x0");
|
||||||
asm volatile ("fmv.w.x f18, x0");
|
asm volatile("fmv.w.x f18, x0");
|
||||||
asm volatile ("fmv.w.x f19, x0");
|
asm volatile("fmv.w.x f19, x0");
|
||||||
asm volatile ("fmv.w.x f20, x0");
|
asm volatile("fmv.w.x f20, x0");
|
||||||
asm volatile ("fmv.w.x f21, x0");
|
asm volatile("fmv.w.x f21, x0");
|
||||||
asm volatile ("fmv.w.x f22, x0");
|
asm volatile("fmv.w.x f22, x0");
|
||||||
asm volatile ("fmv.w.x f23, x0");
|
asm volatile("fmv.w.x f23, x0");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void write_results(
|
inline void write_results(volatile float *local_warp_results,
|
||||||
volatile float *local_warp_results,
|
int thread_in_warp, int warp_x, int warp_y, int dim_m,
|
||||||
int thread_in_warp,
|
int dim_n, float *C, int threadblock_id_x,
|
||||||
int warp_x,
|
int threadblock_id_y) {
|
||||||
int warp_y,
|
|
||||||
int dim_m,
|
|
||||||
int dim_n,
|
|
||||||
float *C,
|
|
||||||
int threadblock_id_x,
|
|
||||||
int threadblock_id_y
|
|
||||||
) {
|
|
||||||
int tid = thread_in_warp;
|
int tid = thread_in_warp;
|
||||||
int tg = tid / 4;
|
int tg = tid / 4;
|
||||||
|
|
||||||
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0]));
|
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
|
||||||
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1]));
|
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
|
||||||
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2]));
|
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
|
||||||
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3]));
|
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
|
||||||
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4]));
|
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
|
||||||
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5]));
|
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
|
||||||
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6]));
|
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
|
||||||
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7]));
|
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
|
||||||
|
|
||||||
/*
|
/*
|
||||||
col = ((threadgroup % 4) // 2) * 8
|
col = ((threadgroup % 4) // 2) * 8
|
||||||
row = (threadgroup * 8) % 16
|
row = (threadgroup * 8) % 16
|
||||||
row += (threadgroup // 4) * 4
|
row += (threadgroup // 4) * 4
|
||||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||||
offset = offsets[register-16]
|
offset = offsets[register-16]
|
||||||
row += offset[0]
|
row += offset[0]
|
||||||
col += offset[1]
|
col += offset[1]
|
||||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||||
thread_offset = thread_offsets[thread % 4]
|
thread_offset = thread_offsets[thread % 4]
|
||||||
row += thread_offset[0]
|
row += thread_offset[0]
|
||||||
col += thread_offset[1]
|
col += thread_offset[1]
|
||||||
return (row, col)
|
return (row, col)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
int local_col = ((tg % 4) / 2) * 8;
|
int local_row = 0;
|
||||||
int local_row = (tg * 8) % 16;
|
int local_col = 0;
|
||||||
local_row += (tg / 4) * 4;
|
map_c_32lanes(tid, local_row, local_col);
|
||||||
|
|
||||||
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2};
|
float *global_offset_C = C +
|
||||||
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5};
|
(threadblock_id_y * BM * 2 + warp_y * BM) * dim_n +
|
||||||
|
threadblock_id_x * BN * 2 + warp_x * BM;
|
||||||
// int thread_row_offsets[4] = {0, 1, 0, 1};
|
|
||||||
// int thread_col_offsets[4] = {0, 0, 2, 2};
|
|
||||||
int thread_row_offset = (tid % 4) % 2;
|
|
||||||
int thread_col_offset = ((tid % 4) / 2) * 2;
|
|
||||||
|
|
||||||
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
|
|
||||||
for (int i = 0; i < 8; i += 1) {
|
for (int i = 0; i < 8; i += 1) {
|
||||||
int row_offset = ((i / 2) % 2) * 2;
|
int row_offset = ((i / 2) % 2) * 2;
|
||||||
int col_offset = (i / 4) * 4 + i % 2;
|
int col_offset = (i / 4) * 4 + i % 2;
|
||||||
|
|
||||||
int adjusted_local_row = local_row + thread_row_offset + row_offset;
|
int adjusted_local_row = local_row + row_offset;
|
||||||
int adjusted_local_col = local_col + thread_col_offset + col_offset;
|
int adjusted_local_col = local_col + col_offset;
|
||||||
|
|
||||||
float v = local_warp_results[tid*8+i];
|
float v = local_warp_results[tid * 8 + i];
|
||||||
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
|
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
|
||||||
const uint32_t local_b_row = warp_in_threadblock;
|
const uint32_t local_b_row = warp_in_threadblock;
|
||||||
const uint32_t local_b_col = tid_in_warp;
|
const uint32_t local_b_col = tid_in_warp;
|
||||||
|
|
||||||
|
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
const size_t local_a_elems = (threadblock_dim_y * BK);
|
const size_t local_a_elems = (threadblock_dim_y * BK);
|
||||||
|
|||||||
Reference in New Issue
Block a user