sgemm_impl: Fix wrong barrier count; add barrier for write_to_smem
This commit is contained in:
@@ -568,7 +568,9 @@ template <typename T,
|
|||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock) {
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblocks_per_cluster,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
// no double-buffering
|
// no double-buffering
|
||||||
// FIXME: duplicated from thread_block_gemm
|
// FIXME: duplicated from thread_block_gemm
|
||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||||
@@ -576,6 +578,8 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
|||||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (int i = 0; i < BK_LOOP; i++) {
|
for (int i = 0; i < BK_LOOP; i++) {
|
||||||
@@ -608,6 +612,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_smem) {
|
if constexpr (write_to_smem) {
|
||||||
|
// need to protect smem reads in the earlier step from the writes below,
|
||||||
|
// especially when the destination smem address overlaps with the input
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
@@ -655,7 +664,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threads_per_threadblock;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock + smem_a_offset);
|
||||||
T *local_a_buf =
|
T *local_a_buf =
|
||||||
@@ -858,7 +867,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
/*write_to_smem=*/false>(
|
/*write_to_smem=*/false>(
|
||||||
local_a_consume, local_b_consume,
|
local_a_consume, local_b_consume,
|
||||||
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
||||||
threads_per_threadblock);
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||||
|
|||||||
Reference in New Issue
Block a user