diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index c269bc9c..8f658d67 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -37,10 +37,6 @@ #error "threadblock size too big for cluster" #endif -// "fake" fp16 type that only has the correct word size. Proper conversion to -// fp32 need to be done in a custom function. -using float16_t = uint16_t; - template inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const T *A, const T *B, @@ -48,13 +44,27 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t tid_in_threadblock, const uint32_t threadblock_id_x, const uint32_t threadblock_id_y) { - const uint32_t local_a_row = tid_in_threadblock / BK; - const uint32_t local_a_col = tid_in_threadblock % BK; + // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do + // data movement at the fp32 granularity. Assuming that the matrix is stored + // row-major in GMEM, the packed fp16 pairs belong to the same row, + // neighboring columns; therefore, it essentially becomes equivalent to + // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed + // by a factor of two. + constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); + constexpr uint32_t BK_adjusted = BK / packed_factor; + const uint32_t dim_k_adjusted = dim_k / packed_factor; + constexpr uint32_t BN_adjusted = BN / packed_factor; + const uint32_t dim_n_adjusted = dim_n / packed_factor; + const uint32_t k_adjusted = k / packed_factor; + + const uint32_t local_a_row = tid_in_threadblock / BK_adjusted; + const uint32_t local_a_col = tid_in_threadblock % BK_adjusted; const uint32_t local_as_row = tid_in_threadblock / BM; const uint32_t local_as_col = tid_in_threadblock % BM; - const uint32_t local_b_row = tid_in_threadblock / BN; - const uint32_t local_b_col = tid_in_threadblock % BN; + const uint32_t local_b_row = tid_in_threadblock / BN_adjusted; + const uint32_t local_b_col = tid_in_threadblock % BN_adjusted; + // FIXME: need fix for fp16? constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD; // Data move from GMEM to SMEM @@ -63,53 +73,59 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, // neighboring threads to ensure GMEM coalescing. // // TODO: Sharedmem swizzling is important here + + // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { + // No transpose at GMEM->SMEM movement // FIXME: !TRANSPOSE_AS code is old const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = threads_in_threadblock / BK; - const T *global_a = A + dim_k * global_a_row + (k + local_a_col); - volatile T *local_a_tmp = local_a + BK * local_a_row + local_a_col; + // this is equivalent to threadblock_dim_y (assuming threadblock_dim_x == + // BK) + constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; + const float *global_a = reinterpret_cast(A) + + dim_k_adjusted * global_a_row + + (k_adjusted + local_a_col); + volatile float *local_a_tmp = reinterpret_cast(local_a) + + BK_adjusted * local_a_row + local_a_col; #pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a) { - // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); - // local_a[BK * (local_a_row + local_row_offset) + local_a_col] = - // A[global_a_offset]; *local_a_tmp = *global_a; - global_a += dim_k * row_stride_a; - local_a_tmp += BK * row_stride_a; + // move to the next "row-chunk", when threadblock is smaller than BM*BK + global_a += dim_k_adjusted * row_stride_a; + local_a_tmp += BK_adjusted * row_stride_a; } } else { if constexpr (!GMEM_COALESCED_A) { + // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in + // GMEM, writes to neighboring cols in SMEM constexpr uint32_t row_stride_as = threads_in_threadblock / BM; const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - const T *global_a = A + dim_k * global_a_row + (k + local_as_row); - // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; - // const T *global_a = A + dim_k * global_a_row + (k + local_as_col); - volatile T *local_a_tmp = local_a + BM * local_as_row + local_as_col; + const float *global_a = + reinterpret_cast(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row); + volatile float *local_a_tmp = + reinterpret_cast(local_a) + BM * local_as_row + local_as_col; static_assert( - row_stride_as * 8 <= BK, + row_stride_as * 8 <= BK_adjusted, "manual loop unrolling condition not met; consider increasing BK"); static_assert( - (BK % (row_stride_as * 8)) == 0, + (BK_adjusted % (row_stride_as * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BK; + for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here // const uint32_t global_a_offset = - // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); // FIXME experimenting with global coalescing // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); + // dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_as_col); // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; @@ -146,11 +162,15 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, local_a_tmp += BM * row_stride_as * 8; } } else { - constexpr uint32_t row_stride_a = threads_in_threadblock / BK; + constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted; const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; - const T *global_a = A + dim_k * global_a_row + (k + local_a_col); + const float *global_a = reinterpret_cast(A) + + dim_k_adjusted * global_a_row + + (k_adjusted + local_a_col); // NOTE that SMEM writes are transposed - volatile T *local_a_tmp = local_a + BM * local_a_col + local_a_row; + volatile float *local_a_tmp = + reinterpret_cast(local_a) + BM * local_a_col + + local_a_row; static_assert( row_stride_a * 8 <= BM, @@ -163,27 +183,27 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE that SMEM writes are transposed // local_a[BM * (local_a_col) + local_a_row + local_row_offset] = // A[global_a_offset]; asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + global_a += dim_k_adjusted * row_stride_a; // stride along columns asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); @@ -197,62 +217,63 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, local_a_tmp += row_stride_a * 8; } } - } + } // end move A - constexpr uint32_t row_stride_b = threads_in_threadblock / BN; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; - const T *global_b = B + dim_n * (k + local_b_row) + global_b_col; - volatile T *local_b_tmp = local_b + BN * local_b_row + local_b_col; + // move B + constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted; + const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col; + // NOTE: not k_adjusted here; k is along the row dimension which is not + // compressed for fp16 + const float *global_b = reinterpret_cast(B) + + dim_n_adjusted * (k + local_b_row) + global_b_col; + volatile float *local_b_tmp = reinterpret_cast(local_b) + + BN_adjusted * local_b_row + local_b_col; static_assert( - row_stride_b * 8 <= BK, + row_stride_b * 8 <= BK_adjusted, "manual loop unrolling condition not met; consider increasing BK"); static_assert( - (BK % (row_stride_b * 8)) == 0, + (BK_adjusted % (row_stride_b * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b * 8) { - // const uint32_t global_b_offset = - // dim_n * (k + local_b_row + load_offset) + global_b_col; - // local_b[BN * (local_b_row + load_offset) + local_b_col] = - // B[global_b_offset]; - + // equivalent code: + // // *local_b_tmp = *global_b; - // global_b += dim_n * row_stride_b; // local_b_tmp += BN * row_stride_b; asm volatile ("flw ft0, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft1, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft2, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft3, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft4, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft5, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft6, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; asm volatile ("flw ft7, (%0)" :: "r"(global_b)); - global_b += dim_n * row_stride_b; + global_b += dim_n_adjusted * row_stride_b; - asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 2; - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 2; - asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 2; - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_adjusted * row_stride_b * 2; + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_adjusted * row_stride_b * 2; + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_adjusted * row_stride_b * 2; + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_adjusted * row_stride_b * 2; } } @@ -440,8 +461,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #endif } #else - global_dmem_load(dim_n, dim_k, block_k * BK, A, B, local_a, local_b, - tid_in_threadblock, block_n, block_m); + global_dmem_load(dim_n, dim_k, block_k * BK, A, B, local_a, local_b, + tid_in_threadblock, block_n, block_m); threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); #endif @@ -466,6 +487,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_a_consume = local_a + (block_k & 1) * (local_a_elems); local_b_consume = local_b + (block_k & 1) * (local_b_elems); } else { + // no double-buffering without DMA local_a_consume = local_a; local_b_consume = local_b; } @@ -477,12 +499,13 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { // SMEM -> RF - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, + tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // SMEM -> RF - vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, - tid_in_warp); + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); // perform mma vx_wmma(wm_iter); } @@ -506,8 +529,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - write_results(tid_in_warp, warp_col, warp_row, wn_iter, - wm_iter, dim_n, C, block_n, block_m); + write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_n, C, block_n, block_m); } } } diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index 4f52fa4a..7950bddc 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -35,20 +35,27 @@ #define BK_LOOP 1 // Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF // (consume). This is because the tensor core expects the A tile to be stored -// in column-major order in SMEM, whereas it is stored row-major in GMEM. +// in column-major order in SMEM, whereas it will be ultimately stored in +// row-major in the RF. // -// For correctness, only one of either should be 1. To model the case where -// the A matrix is already stored transposed in GMEM ("TN" kernel), set -// both to 0. -// -// For reference, PRODUCE 1 CONSUME 0 generates the performant NN kernel. +// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 +// generates the NN kernel where both A and B are stored row-major in GMEM. +// To model the case where the A matrix is already stored transposed in GMEM +// ("TN" kernel), set both to 0. #define TRANSPOSE_AT_PRODUCE 1 #define TRANSPOSE_AT_CONSUME 0 -// GMEM_COALESCED sets bank conflict-free accesses for -// 1: GMEM loads of A matrix -// 0: SMEM stores of A matrix +// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at +// GMEM->SMEM), determines whether we do bank-conflict-free accesses for +// 1: GMEM loads of A matrix, or +// 0: SMEM stores of A matrix. +// +// Usually, GMEM_COALESCED==1 yields better performance since the memory +// behavior of GMEM is more sensitive to bank conflicts. #define GMEM_COALESCED_A 1 +// "fake" fp16 type that only has the correct data width. +using float16_t = uint16_t; + inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -153,14 +160,23 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, const int tid = thread_in_warp; const int tg = tid / 4; - // TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b + // @perf: this is duplicately computed in vx_wmma_load_a and vx_wmma_load_b int row = 0; int col = 0; map_operand(tid, row, col); + // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do + // data movement at the fp32 granularity. Assuming that the matrix is stored + // row-major in GMEM, the packed fp16 pairs belong to the same row, + // neighboring columns; therefore, it essentially becomes equivalent to + // moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed + // by a factor of two. + constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); + constexpr uint32_t BK_adjusted = BK / packed_factor; + constexpr int smem_A_rows = BM; - constexpr int smem_A_cols = BK; - constexpr int smem_AS_rows = BK; + constexpr int smem_A_cols = BK_adjusted; + constexpr int smem_AS_rows = BK_adjusted; constexpr int smem_AS_cols = BM; if constexpr (TRANSPOSE_AT_CONSUME) { @@ -170,11 +186,11 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, // f8-f15 stores a single row of A const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( - &smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k]); - // NOTE: stride is fixed to word size , i.e. sizeof(float) = 4, - // regardless of fp16 or fp32. Since Vortex core does not support fp16, - // load things at word granularity and reinterpret bits inside the tensor - // core. + &reinterpret_cast( + smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + + local_k]); + // step to the next column + // threads read from different rows; bank conflicts asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr)); @@ -183,21 +199,17 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr)); - // asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); - // asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); - // asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); - // asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); - // asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); - // asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); - // asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); - // asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); } else { // read smem A tile as-is; bank-conflict-free AS load + // smem A tile is stored column-major // f8-f15 stores a single row of A const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( - &smem_A[((local_k + 0) * smem_AS_cols) + - (WM * warp_row + TCM * wm_iter) + row]); + &reinterpret_cast( + smem_A)[((local_k + 0) * smem_AS_cols) + + (WM * warp_row + TCM * wm_iter) + row]); + // step to the next row + // threads read from different columns; no bank conflicts asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); @@ -206,15 +218,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); - - // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - // asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); } } @@ -230,14 +233,21 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, int col = 0; map_operand(tid, row, col); + // see comment in vx_wmma_load_a + constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); + constexpr uint32_t BN_adjusted = BN / packed_factor; + constexpr int smem_B_rows = BK; - constexpr int smem_B_cols = BN; + constexpr int smem_B_cols = BN_adjusted; // f8-f15 stores a single column of B const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( - &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + - col]); + &reinterpret_cast( + smem_B)[((local_k + 0) * smem_B_cols) + + (WN * warp_col + TCN * wn_iter) + col]); + // step to the next row + // threads read from different columns; no bank conflicts asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr)); @@ -246,15 +256,6 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k, asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); - - // asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - // asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } inline void initialize_C(const int dest_reg) { @@ -280,11 +281,10 @@ inline void initialize_C(const int dest_reg) { } } -template inline void write_results(const int thread_in_warp, const int warp_col, const int warp_row, const int wn_iter, const int wm_iter, const int dim_n, - T *C, const int threadblock_id_x, + float *C, const int threadblock_id_x, const int threadblock_id_y) { int tid = thread_in_warp; @@ -296,14 +296,14 @@ inline void write_results(const int thread_in_warp, const int warp_col, int local_row = (WM * warp_row + TCM * wm_iter) + tid_row; int local_col = (WN * warp_col + TCN * wn_iter) + tid_col; - T *global_offset_C = + float *global_offset_C = C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x; // @perf: this likely causes a lot of gmem bank conflicts if (wm_iter == 0) { volatile uint8_t *gmem_addr = reinterpret_cast( &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]); - volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T); + volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float); asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp)); @@ -323,7 +323,7 @@ inline void write_results(const int thread_in_warp, const int warp_col, } else { volatile uint8_t *gmem_addr = reinterpret_cast( &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]); - volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T); + volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float); asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));