diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index d26bae36..4838e9d8 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -339,9 +339,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, volatile float *local_b, const uint32_t tid_in_threadblock, const uint32_t threadblock_id_x, const uint32_t threadblock_id_y) { - constexpr uint32_t BM_d = BM; - constexpr uint32_t BN_d = BN; - const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_as_row = tid_in_threadblock / BM; @@ -359,14 +356,16 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // // TODO: Sharedmem swizzling is important here if constexpr (!TRANSPOSE_AS) { - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + // FIXME: !TRANSPOSE_AS code is old + + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; const float *global_a = A + dim_k * global_a_row + (k + local_a_col); volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; #pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a) { // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); @@ -379,13 +378,13 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } else { if constexpr (!GMEM_COALESCED_A) { - constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM; + const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); - volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col; static_assert( row_stride_as * 8 <= BK, @@ -403,7 +402,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; // *local_a_tmp = *global_a; @@ -436,25 +435,25 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } else { constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const float *global_a = A + dim_k * global_a_row + (k + local_a_col); // NOTE that SMEM writes are transposed - volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; + volatile float *local_a_tmp = local_a + BM * local_a_col + local_a_row; static_assert( - row_stride_a * 8 <= BM_d, + row_stride_a * 8 <= BM, "manual loop unrolling condition not met; consider increasing BM"); static_assert( - (BM_d % (row_stride_a * 8)) == 0, + (BM % (row_stride_a * 8)) == 0, "manual loop unrolling condition not met; BM should be power-of-two"); #pragma GCC unroll 4 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE that SMEM writes are transposed - // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // local_a[BM * (local_a_col) + local_a_row + local_row_offset] = // A[global_a_offset]; asm volatile ("flw ft0, (%0)" :: "r"(global_a)); @@ -488,10 +487,10 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } - constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; - const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; + constexpr uint32_t row_stride_b = threads_in_warpgroup / BN; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; - volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + volatile float *local_b_tmp = local_b + BN * local_b_row + local_b_col; static_assert( row_stride_b * 8 <= BK, @@ -505,13 +504,13 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = // dim_n * (k + local_b_row + load_offset) + global_b_col; - // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = + // local_b[BN * (local_b_row + load_offset) + local_b_col] = // B[global_b_offset]; // *local_b_tmp = *global_b; // global_b += dim_n * row_stride_b; - // local_b_tmp += BN_d * row_stride_b; + // local_b_tmp += BN * row_stride_b; asm volatile ("flw ft0, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; @@ -530,15 +529,15 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, asm volatile ("flw ft7, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; - asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_d * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_d * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_d * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_d * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_d * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_d * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_d * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_d * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_d * row_stride_b * 8; + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 8; } }