From dc7bd6b2480e02bcb9f08aec48c8691b308240e8 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Mon, 10 Jun 2024 19:52:37 -0700
Subject: [PATCH] sgemm_tcore: Fix warp_row/col calculation bug

---
 tests/regression/sgemm_tcore/kernel.cpp | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp
index cb10c14d..a655f20c 100644
--- a/tests/regression/sgemm_tcore/kernel.cpp
+++ b/tests/regression/sgemm_tcore/kernel.cpp
@@ -281,10 +281,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
   int tid_col = 0;
   map_c(tid, tid_row, tid_col);
 
-  // int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
-  // int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
-  int local_row = (WM * warp_row);
-  int local_col = (WN * warp_col);
+  int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
+  int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
 
   float *global_offset_C = C +
                            (BM * threadblock_id_y) * dim_n +
@@ -563,10 +561,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
   // no double-buffering
   const uint32_t threads_per_warpgroup = threads_per_threadblock;
 
-  const uint32_t warp_in_warpgroup = threads_per_warpgroup / NUM_LANES;
-
-  const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
-  const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
+  const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_LANES;
+  const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
+  const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
   const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
 
   volatile float *local_a = sharedmem_per_threadblock;