diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6c677326..760c8467 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -370,43 +370,64 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; -#pragma GCC unroll 1 + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; + +#pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BM_d; local_row_offset += row_stride_a) { - const uint32_t global_a_offset = - dim_k * (global_a_row + local_row_offset) + (k + local_a_col); - // NOTE: all threads in TB will do this load; make sure this is not - // out-of-bounds of BM_d*BK - local_a[BK * (local_a_row + local_row_offset) + local_a_col] = - A[global_a_offset]; + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // local_a[BK * (local_a_row + local_row_offset) + local_a_col] = + // A[global_a_offset]; + *local_a_tmp = *global_a; + + global_a += dim_k * row_stride_a; + local_a_tmp += BK * row_stride_a; } } else { const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; -#pragma GCC unroll 4 + const float *global_a = A + dim_k * global_a_row + (k + local_as_row); + volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + +#pragma GCC ivdep for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as) { // @perf: bank conflicts here - const uint32_t global_a_offset = - dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // const uint32_t global_a_offset = + // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = - A[global_a_offset]; + // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // A[global_a_offset]; + + *local_a_tmp = *global_a; + + global_a += row_stride_as; + local_a_tmp += BM * row_stride_as; } } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; -#pragma GCC unroll 2 + const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; + volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + +#pragma GCC ivdep for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { - const uint32_t global_b_offset = - dim_n * (k + local_b_row + load_offset) + global_b_col; - local_b[BN_d * (local_b_row + load_offset) + local_b_col] = - B[global_b_offset]; + // const uint32_t global_b_offset = + // dim_n * (k + local_b_row + load_offset) + global_b_col; + // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = + // B[global_b_offset]; + + *local_b_tmp = *global_b; + + global_b += dim_n * row_stride_b; + local_b_tmp += BN_d * row_stride_b; } } @@ -480,6 +501,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { + // register volatile float *local_a_produce asm("t0"); + // register volatile float *local_b_produce asm("t1"); + // register volatile float *local_a_consume asm("t2"); + // register volatile float *local_b_consume asm("t3"); volatile float *local_a_produce; volatile float *local_b_produce; volatile float *local_a_consume;