diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp
index 5e048fc5..4ded4758 100644
--- a/tests/regression/sgemm_tcore/kernel.cpp
+++ b/tests/regression/sgemm_tcore/kernel.cpp
@@ -239,9 +239,69 @@ inline void write_results(volatile float *local_warp_results,
   asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
 }
 
-void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
-  vx_fence();
-  vx_barrier(barrier_id, count);
+inline void threadblock_barrier(unsigned int tid_in_threadblock, // not read in this body
+                                unsigned int barrier_id, unsigned int count) {
+  vx_fence(); // NOTE(review): presumably orders earlier memory accesses before the barrier - confirm
+  vx_barrier(barrier_id, count); // barrier_id and count are forwarded unchanged
+}
+
+inline void
+global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
+                 const float *A, const float *B, volatile float *local_a,
+                 volatile float *local_b, const uint32_t threadblock_id_x,
+                 const uint32_t threadblock_id_y, const uint32_t local_a_row,
+                 const uint32_t local_a_col, const uint32_t local_as_row,
+                 const uint32_t local_as_col, const uint32_t local_b_row,
+                 const uint32_t local_b_col) {
+  // Copies the A and B elements selected by the local_* indices at offset k.
+  // Data move from GMEM to SMEM
+  // A is read with row pitch dim_k, B with row pitch dim_n.
+  // Make sure global offset values for A and B are contiguous between
+  // neighboring threads to ensure GMEM coalescing.
+  // With TRANSPOSE_AS, local_a is written with pitch BM instead of BK.
+  // TODO: Sharedmem swizzling is important here
+  if constexpr (!TRANSPOSE_AS) {
+    const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
+    // number of rows a full TB can read at a time
+    constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK;
+#pragma GCC unroll 1
+    for (uint32_t local_row_offset = 0; local_row_offset < BM;
+         local_row_offset += row_stride_a) {
+      const uint32_t global_a_offset =
+          dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
+      // NOTE: all threads in TB will do this load; make sure this is not
+      // out-of-bounds of BM*BK
+      local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
+          A[global_a_offset];
+    }
+  } else {
+    const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; // local_as_col picks the A row here
+    // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
+    constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; // step of the K-offset loop below
+#pragma GCC unroll 1
+    for (uint32_t local_row_offset = 0; local_row_offset < BK;
+         local_row_offset += row_stride_as) {
+      // @perf: bank conflicts here
+      const uint32_t global_a_offset =
+          dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
+      // FIXME experimenting with global coalescing
+      // const uint32_t global_a_offset =
+      //     dim_k * (global_a_row + local_row_offset) + (k + local_as_col);
+      local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
+          A[global_a_offset];
+    }
+  }
+
+  constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; // B rows advanced per iteration
+  const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
+#pragma GCC unroll 1
+  for (uint32_t load_offset = 0; load_offset < BK;
+       load_offset += row_stride_b) {
+    const uint32_t global_b_offset =
+        dim_n * (k + local_b_row + load_offset) + global_b_col;
+    local_b[BN * (local_b_row + load_offset) + local_b_col] =
+        B[global_b_offset];
+  }
 }
 
 void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
@@ -293,49 +353,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
 
 #pragma GCC unroll 1
   for (uint32_t k = 0; k < dim_k; k += BK) {
-    // Data move from GMEM to SMEM
-    //
-    // Make sure global offset values for A and B are contiguous between
-    // neighboring threads to ensure GMEM coalescing.
-    //
-    // TODO: Sharedmem swizzling is important here
-    if constexpr (!TRANSPOSE_AS) {
-      const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
-      // number of rows a full TB can read at a time
-      constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK;
-#pragma GCC unroll 1
-      for (uint32_t local_row_offset = 0; local_row_offset < BM;
-           local_row_offset += row_stride_a) {
-        const uint32_t global_a_offset =
-            dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
-        // NOTE: all threads in TB will do this load; make sure this is not
-        // out-of-bounds of BM*BK
-        local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
-            A[global_a_offset];
-      }
-    } else {
-      const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
-      constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM;
-#pragma GCC unroll 1
-      for (uint32_t local_row_offset = 0; local_row_offset < BK;
-           local_row_offset += row_stride_as) {
-        // @perf: bank conflicts here
-        const uint32_t global_a_offset =
-            dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
-        local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
-            A[global_a_offset];
-      }
-    }
-
-    constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN;
-    const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
-#pragma GCC unroll 1
-    for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) {
-      const uint32_t global_b_offset =
-          dim_n * (k + local_b_row + load_offset) + global_b_col;
-      local_b[BN * (local_b_row + load_offset) + local_b_col] =
-          B[global_b_offset];
-    }
+    global_dmem_load(dim_n, dim_k, k, A, B, local_a, local_b,
+                     threadblock_id_x, threadblock_id_y, local_a_row,
+                     local_a_col, local_as_row, local_as_col, local_b_row,
+                     local_b_col);
 
     threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
                         threadblock_dim_y);
@@ -370,8 +391,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
           // asm volatile("addi a0, a0, 0");
           // }
           // SMEM -> RF
-          vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter,
-                       wm_iter, tid_in_warp);
+          vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row,
+                       wn_iter, wm_iter, tid_in_warp);
           // compute
           vx_wmma();
 #if TC_SINGLE_WARP
@@ -382,6 +403,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
       }
     }
 
+    threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
+                        threadblock_dim_y);
+
 #else
     // Compute single tile*tile matmul
 
@@ -413,10 +437,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
         }
       }
     }
-#endif
 
     threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
                         threadblock_dim_y);
+#endif
   }
 
 #if USE_TENSOR_CORE