sgemm_tcore: Fix warp_row/col calculation bug

This commit is contained in:
Hansung Kim
2024-06-10 19:52:37 -07:00
parent 3b2f5a31de
commit dc7bd6b248

View File

@@ -281,10 +281,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
int tid_col = 0; int tid_col = 0;
map_c(tid, tid_row, tid_col); map_c(tid, tid_row, tid_col);
// int local_row = (WM * warp_row + TCM * wm_iter) + tid_row; int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
// int local_col = (WN * warp_col + TCN * wn_iter) + tid_col; int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
int local_row = (WM * warp_row);
int local_col = (WN * warp_col);
float *global_offset_C = C + float *global_offset_C = C +
(BM * threadblock_id_y) * dim_n + (BM * threadblock_id_y) * dim_n +
@@ -563,10 +561,9 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// no double-buffering // no double-buffering
const uint32_t threads_per_warpgroup = threads_per_threadblock; const uint32_t threads_per_warpgroup = threads_per_threadblock;
const uint32_t warp_in_warpgroup = threads_per_warpgroup / NUM_LANES; const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_LANES;
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
const uint32_t warp_row = warp_in_warpgroup / (BN / WN); const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
volatile float *local_a = sharedmem_per_threadblock; volatile float *local_a = sharedmem_per_threadblock;