sgemm_tcore: Deconstruct smem addr calc to reduce reg alloc
This commit is contained in:
@@ -180,18 +180,32 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
|||||||
} else {
|
} else {
|
||||||
// transposed A
|
// transposed A
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
register volatile float *smem_addr asm("t5");
|
||||||
asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
||||||
asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
asm volatile("flw f0, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
smem_addr += smem_AS_cols;
|
||||||
asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
asm volatile("flw f1, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
smem_addr += smem_AS_cols;
|
||||||
asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
asm volatile("flw f2, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
smem_addr += smem_AS_cols;
|
||||||
// #pragma GCC unroll 8
|
asm volatile("flw f3, %0" ::"m"(*smem_addr));
|
||||||
// for (int i = 0; i < 8; i++) {
|
smem_addr += smem_AS_cols;
|
||||||
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row]));
|
asm volatile("flw f4, %0" ::"m"(*smem_addr));
|
||||||
// }
|
smem_addr += smem_AS_cols;
|
||||||
|
asm volatile("flw f5, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_AS_cols;
|
||||||
|
asm volatile("flw f6, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_AS_cols;
|
||||||
|
asm volatile("flw f7, %0" ::"m"(*smem_addr));
|
||||||
|
|
||||||
|
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
|
// asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,14 +224,31 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k,
|
|||||||
constexpr int smem_B_cols = BN;
|
constexpr int smem_B_cols = BN;
|
||||||
|
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
register volatile float *smem_addr asm("t5");
|
||||||
asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col];
|
||||||
asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
asm volatile("flw f8, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
smem_addr += smem_B_cols;
|
||||||
asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
asm volatile("flw f9, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
smem_addr += smem_B_cols;
|
||||||
asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
asm volatile("flw f10, %0" ::"m"(*smem_addr));
|
||||||
asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
smem_addr += smem_B_cols;
|
||||||
|
asm volatile("flw f11, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_B_cols;
|
||||||
|
asm volatile("flw f12, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_B_cols;
|
||||||
|
asm volatile("flw f13, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_B_cols;
|
||||||
|
asm volatile("flw f14, %0" ::"m"(*smem_addr));
|
||||||
|
smem_addr += smem_B_cols;
|
||||||
|
asm volatile("flw f15, %0" ::"m"(*smem_addr));
|
||||||
|
// asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
|
// asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void initialize_C(const int dest_reg) {
|
inline void initialize_C(const int dest_reg) {
|
||||||
@@ -243,8 +274,7 @@ inline void initialize_C(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void write_results(volatile float *local_warp_results,
|
inline void write_results(const int thread_in_warp, const int warp_col,
|
||||||
const int thread_in_warp, const int warp_col,
|
|
||||||
const int warp_row, const int wn_iter,
|
const int warp_row, const int wn_iter,
|
||||||
const int wm_iter, const int dim_m, const int dim_n,
|
const int wm_iter, const int dim_m, const int dim_n,
|
||||||
float *C, const int threadblock_id_x,
|
float *C, const int threadblock_id_x,
|
||||||
@@ -266,28 +296,47 @@ inline void write_results(volatile float *local_warp_results,
|
|||||||
|
|
||||||
// @perf: this likely causes a lot of gmem bank conflicts
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
if (wm_iter == 0) {
|
if (wm_iter == 0) {
|
||||||
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
register volatile float *gmem_addr asm("t5");
|
||||||
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
register volatile float *gmem_addr_tmp asm("t6");
|
||||||
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
||||||
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0)));
|
||||||
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1)));
|
||||||
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
||||||
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
asm volatile ("fsw f18, %0" :: "m"(*(gmem_addr_tmp + 0)));
|
||||||
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
asm volatile ("fsw f19, %0" :: "m"(*(gmem_addr_tmp + 1)));
|
||||||
|
gmem_addr += 4;
|
||||||
|
asm volatile ("fsw f20, %0" :: "m"(*(gmem_addr + 0)));
|
||||||
|
asm volatile ("fsw f21, %0" :: "m"(*(gmem_addr + 1)));
|
||||||
|
gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
||||||
|
asm volatile ("fsw f22, %0" :: "m"(*(gmem_addr_tmp + 0)));
|
||||||
|
asm volatile ("fsw f23, %0" :: "m"(*(gmem_addr_tmp + 1)));
|
||||||
|
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||||
|
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||||
|
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||||
|
// asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
||||||
|
// asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
||||||
|
// asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
||||||
|
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||||
|
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||||
} else {
|
} else {
|
||||||
asm volatile ("fsw f24, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
register volatile float *gmem_addr asm("t5");
|
||||||
asm volatile ("fsw f25, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
register volatile float *gmem_addr_tmp asm("t6");
|
||||||
asm volatile ("fsw f26, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
||||||
asm volatile ("fsw f27, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
||||||
asm volatile ("fsw f28, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0)));
|
||||||
asm volatile ("fsw f29, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
asm volatile ("fsw f25, %0" :: "m"(*(gmem_addr + 1)));
|
||||||
asm volatile ("fsw f30, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
asm volatile ("fsw f26, %0" :: "m"(*(gmem_addr_tmp + 0)));
|
||||||
asm volatile ("fsw f31, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
asm volatile ("fsw f27, %0" :: "m"(*(gmem_addr_tmp + 1)));
|
||||||
|
gmem_addr += 4;
|
||||||
|
gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
||||||
|
asm volatile ("fsw f28, %0" :: "m"(*(gmem_addr + 0)));
|
||||||
|
asm volatile ("fsw f29, %0" :: "m"(*(gmem_addr + 1)));
|
||||||
|
asm volatile ("fsw f30, %0" :: "m"(*(gmem_addr_tmp + 0)));
|
||||||
|
asm volatile ("fsw f31, %0" :: "m"(*(gmem_addr_tmp + 1)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void threadblock_barrier(unsigned int tid_in_threadblock,
|
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||||
unsigned int barrier_id, unsigned int count) {
|
|
||||||
vx_fence();
|
vx_fence();
|
||||||
vx_barrier(barrier_id, count);
|
vx_barrier(barrier_id, count);
|
||||||
}
|
}
|
||||||
@@ -406,16 +455,13 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
|
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||||
const size_t local_a_elems = (BM * BK);
|
constexpr size_t local_a_elems = (BM * BK);
|
||||||
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
|
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
|
||||||
const size_t local_b_elems = (BK * BN);
|
constexpr size_t local_b_elems = (BK * BN);
|
||||||
|
|
||||||
volatile float *local_a_buf = local_b + local_b_elems;
|
volatile float *local_a_buf = local_b + local_b_elems;
|
||||||
volatile float *local_b_buf = local_a_buf + local_a_elems;
|
volatile float *local_b_buf = local_a_buf + local_a_elems;
|
||||||
|
|
||||||
volatile float *local_warp_results =
|
|
||||||
local_b_buf + local_b_elems + (warp_in_warpgroup * TCM * TCN);
|
|
||||||
|
|
||||||
// clear out C
|
// clear out C
|
||||||
initialize_C(0);
|
initialize_C(0);
|
||||||
initialize_C(1);
|
initialize_C(1);
|
||||||
@@ -427,8 +473,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
tid_in_warpgroup, threadblock_id_x, threadblock_id_y);
|
tid_in_warpgroup, threadblock_id_x, threadblock_id_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||||
threadblock_dim_y);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t k_index = 0;
|
uint32_t k_index = 0;
|
||||||
@@ -459,8 +504,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
threadblock_id_x, threadblock_id_y);
|
threadblock_id_x, threadblock_id_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||||
threadblock_dim_y);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
else {
|
else {
|
||||||
@@ -509,8 +553,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||||
threadblock_dim_y);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
@@ -559,9 +602,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
if (warp_in_warpgroup == 0) {
|
if (warp_in_warpgroup == 0) {
|
||||||
#endif
|
#endif
|
||||||
if (warpgroup_id == 1) {
|
if (warpgroup_id == 1) {
|
||||||
write_results(local_warp_results, tid_in_warp, warp_col, warp_row,
|
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||||
wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x,
|
dim_m, dim_n, C, threadblock_id_x, threadblock_id_y);
|
||||||
threadblock_id_y);
|
|
||||||
}
|
}
|
||||||
#if TC_SINGLE_WARP
|
#if TC_SINGLE_WARP
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user