sgemm_tcore: Add compile-time write_to_gmem template parameter (default true) to thread_block_gemm
This commit is contained in:
@@ -627,7 +627,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, bool write_to_gmem = true>
|
||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -875,12 +875,14 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
}
|
||||
|
||||
if constexpr (write_to_gmem) {
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
#pragma GCC unroll 2
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||
dim_n, C, block_n, block_m);
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||
dim_n, C, block_n, block_m);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user