sgemm_tcore: Add compile-time write_to_gmem param to thread_block_gemm
This commit is contained in:
@@ -627,7 +627,7 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, bool write_to_gmem = true>
|
||||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threads_per_threadblock,
|
||||||
@@ -875,12 +875,14 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constexpr (write_to_gmem) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||||
dim_n, C, block_n, block_m);
|
dim_n, C, block_n, block_m);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user