diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 760c8467..e5a9cf33 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -386,15 +386,21 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); + // FIXME experimenting with global coalescing + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + static_assert( + row_stride_as * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + #pragma GCC ivdep for (uint32_t local_row_offset = 0; local_row_offset < BK; - local_row_offset += row_stride_as) { + local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here // const uint32_t global_a_offset = // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); @@ -404,10 +410,33 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; - *local_a_tmp = *global_a; - + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); global_a += row_stride_as; - local_a_tmp += BM * row_stride_as; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BM * row_stride_as * 8; } } @@ -416,18 +445,49 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + static_assert( + row_stride_b * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + #pragma GCC ivdep for (uint32_t load_offset = 0; load_offset < BK; - load_offset += row_stride_b) { + load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = // dim_n * (k + local_b_row + load_offset) + global_b_col; // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = // B[global_b_offset]; - *local_b_tmp = *global_b; + // *local_b_tmp = *global_b; + // global_b += dim_n * row_stride_b; + // local_b_tmp += BN_d * row_stride_b; + + asm volatile ("flw ft0, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; - local_b_tmp += BN_d * row_stride_b; + asm volatile ("flw ft1, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft2, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft3, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft4, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft5, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft6, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft7, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_d * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_d * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_d * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_d * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_d * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_d * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_d * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_d * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_d * row_stride_b * 8; } } @@ -514,6 +574,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_consume = (k_index % 2) ? local_a_buf : local_a; local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // local_a_consume = local_a_produce; + // local_b_consume = local_b_produce; } else { local_a_produce = local_a; local_b_produce = local_b;