diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 3118a368..d748a909 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -28,7 +28,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, // neighboring threads to ensure GMEM coalescing. // // TODO: Sharedmem swizzling is important here - if constexpr (!TRANSPOSE_AS) { + if constexpr (!TRANSPOSE_AT_PRODUCE) { // FIXME: !TRANSPOSE_AS code is old const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; diff --git a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp index 94c98569..d8764bb1 100644 --- a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp +++ b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp @@ -33,8 +33,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, // neighboring threads to ensure GMEM coalescing. // // TODO: Sharedmem swizzling is important here - if constexpr (!TRANSPOSE_AS) { - // FIXME: !TRANSPOSE_AS code is old + if constexpr (!TRANSPOSE_AT_PRODUCE) { + // if !TRANSPOSE_AT_PRODUCE, we only support coalesced GMEM loads + static_assert(TRANSPOSE_AT_PRODUCE || GMEM_COALESCED_A); const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time @@ -42,26 +43,60 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const float *global_a = A + dim_k * global_a_row + (k + local_a_col); volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; + static_assert( + row_stride_a * 8 <= BM, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); + #pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; - local_row_offset += row_stride_a) { + local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // local_a[BK * (local_a_row + local_row_offset) + local_a_col] = // A[global_a_offset]; - *local_a_tmp = *global_a; + // + // *local_a_tmp = *global_a; + // global_a += dim_k * row_stride_a; + // local_a_tmp += BK * row_stride_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); global_a += dim_k * row_stride_a; - local_a_tmp += BK * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + // bank conflicts + asm volatile ("fsw ft0, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BK * row_stride_a * 8; } } else { if constexpr (!GMEM_COALESCED_A) { constexpr uint32_t row_stride_as = threads_in_warpgroup / BM; const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; + // NOTE that GMEM reads are transposed const float *global_a = A + dim_k * global_a_row + (k + local_as_row); - // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; - // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col; static_assert( @@ -152,6 +187,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, global_a += dim_k * row_stride_a; // stride along columns + // bank conflicts asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index 34c9b168..a601d22c 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -35,7 +35,15 @@ // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario #define BK_LOOP 1 -#define TRANSPOSE_AS 1 +// whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF +// (consume). This is because the tensor core expects the A tile to be stored +// in column-major order in SMEM. +// +// For correctness, only one of either should be 1. To model the case where +// the entire A matrix is already stored transposed in GMEM ("TN" kernel), set +// both to 0. +#define TRANSPOSE_AT_PRODUCE 0 +#define TRANSPOSE_AT_CONSUME 0 // GMEM_COALESCED sets bank conflict-free accesses for // 1: GMEM loads of A matrix // 0: SMEM stores of A matrix @@ -171,7 +179,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, constexpr int smem_AS_rows = BK; constexpr int smem_AS_cols = BM; - if constexpr (!TRANSPOSE_AS) { + if constexpr (TRANSPOSE_AT_CONSUME) { // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; // @perf: bank conflicts @@ -195,7 +203,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, // asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); // asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); } else { - // transposed A + // read smem A tile as-is; bank-conflict-free AS load // f8-f15 stores a single row of A volatile float *smem_addr; smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];