sgemm_impl: Add missing reconvergence barrier after mmio

This commit is contained in:
Hansung Kim
2024-09-10 18:05:01 -07:00
parent ccddd0bcc9
commit 2152c80ffd

View File

@@ -1022,6 +1022,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
#endif #endif
} }
// reconverge after mmio divergence
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#else #else
// move A // move A
if constexpr (!TRANSPOSE_AT_PRODUCE) { if constexpr (!TRANSPOSE_AT_PRODUCE) {
@@ -1038,9 +1042,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK, load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK,
threads_per_threadblock>(dim_n, block_n, block_k, B, threads_per_threadblock>(dim_n, block_n, block_k, B,
local_b, tid_in_threadblock); local_b, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
#endif #endif
// consumer code: SMEM->RF and compute // consumer code: SMEM->RF and compute