sgemm_tcore: Add compile-time write_to_gmem param to thread_block_gemm

This commit is contained in:
Hansung Kim
2024-08-14 17:48:31 -07:00
parent ee6339a35f
commit 1b1264207b

View File

@@ -627,7 +627,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
asm volatile ("global_dmem_load_finish_%=:" :: );
}
template <typename T>
template <typename T, bool write_to_gmem = true>
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -875,12 +875,14 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
}
if constexpr (write_to_gmem) {
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
dim_n, C, block_n, block_m);
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
dim_n, C, block_n, block_m);
}
}
}
}