diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index fbb34f8b..b0a7e9e2 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -37,6 +37,9 @@ #error "threadblock size too big for cluster" #endif +// using float_type = float; +using float_type = float16_t; + template inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const T *A, const T *B, @@ -391,7 +394,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } #pragma GCC unroll 1 - for (uint32_t block_k = 0; (block_k * BK) < (dim_k); block_k++) { + for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) { // producer code: GMEM->SMEM memory movement // --------------------------------------------------------------------- @@ -572,8 +575,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD); const uint32_t num_threadblocks = problem_size / threads_per_threadblock; - using float_type = float16_t; - // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster uint8_t *sharedmem_per_threadblock = reinterpret_cast( diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index 1e4c1cf6..e1db7bb8 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -25,7 +25,7 @@ #define WN 8 #define TCM 8 #define TCN 8 -#define TCK 8 +#define TCK 16 #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS) @@ -40,9 +40,9 @@ // // For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 // generates the NN kernel where both A and B are stored row-major in GMEM. -// To model the case where the A matrix is already stored transposed in GMEM -// ("TN" kernel), set both to 0. 
-#define TRANSPOSE_AT_PRODUCE 1 +// To model the case where the A matrix is already stored column-major in GMEM, +// set both to 0. +#define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 // GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at // GMEM->SMEM), determines whether we do bank-conflict-free accesses for @@ -156,7 +156,8 @@ inline void vx_wmma(const int dest_reg) { // `local_k` is assumed to be multiple of TCK template inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, - const int warp_row, const int wm_iter, const int thread_in_warp) { + const int warp_row, const int wm_iter, + const int thread_in_warp) { const int tid = thread_in_warp; const int tg = tid / 4; @@ -171,13 +172,17 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, // neighboring columns; therefore, it essentially becomes equivalent to // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed // by a factor of two. - constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); - constexpr uint32_t BK_adjusted = BK / packed_factor; + constexpr int packed_factor = (std::is_same_v ? 2 : 1); + constexpr int BK_adjusted = BK / packed_factor; + constexpr int BM_adjusted = BM / packed_factor; + const int local_k_adjusted = local_k / packed_factor; constexpr int smem_A_rows = BM; constexpr int smem_A_cols = BK_adjusted; constexpr int smem_AS_rows = BK_adjusted; constexpr int smem_AS_cols = BM; + // constexpr int smem_AS_rows = BK; + // constexpr int smem_AS_cols = BM_adjusted; if constexpr (TRANSPOSE_AT_CONSUME) { // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; @@ -188,7 +193,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, smem_addr = reinterpret_cast( &reinterpret_cast( smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + - local_k]); + local_k /* FIXME: adjust for fp16? 
*/]); // step to the next column // threads read from different rows; bank conflicts asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr)); @@ -206,7 +211,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( - smem_A)[((local_k + 0) * smem_AS_cols) + + smem_A)[((local_k_adjusted + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]); // step to the next row // threads read from different columns; no bank conflicts @@ -234,17 +239,21 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, map_operand(tid, row, col); // see comment in vx_wmma_load_a - constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); - constexpr uint32_t BN_adjusted = BN / packed_factor; + constexpr int packed_factor = (std::is_same_v ? 2 : 1); + constexpr int BK_adjusted = BK / packed_factor; // B-tile rows span K: scale BK, not BN + constexpr int BN_adjusted = BN / packed_factor; + const int local_k_adjusted = local_k / packed_factor; - constexpr int smem_B_rows = BK; - constexpr int smem_B_cols = BN_adjusted; + // constexpr int smem_B_rows = BK; + // constexpr int smem_B_cols = BN_adjusted; + constexpr int smem_B_rows = BK_adjusted; + constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( - smem_B)[((local_k + 0) * smem_B_cols) + + smem_B)[((local_k_adjusted + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]); // step to the next row // threads read from different columns; no bank conflicts