sgemm_tcore: Deconstruct addr calc for GMEM->SMEM

This commit is contained in:
Hansung Kim
2024-06-05 15:11:01 -07:00
parent ff6e5bf6dc
commit e44173c65e

View File

@@ -370,43 +370,64 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row;
// number of rows a full TB can read at a time
constexpr uint32_t row_stride_a = threads_in_warpgroup / BK;
#pragma GCC unroll 1
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
#pragma GCC unroll 2
for (uint32_t local_row_offset = 0; local_row_offset < BM_d;
local_row_offset += row_stride_a) {
const uint32_t global_a_offset =
dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
// NOTE: all threads in TB will do this load; make sure this is not
// out-of-bounds of BM_d*BK
local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
A[global_a_offset];
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
// A[global_a_offset];
*local_a_tmp = *global_a;
global_a += dim_k * row_stride_a;
local_a_tmp += BK * row_stride_a;
}
} else {
const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col;
// const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row;
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d;
#pragma GCC unroll 4
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col;
#pragma GCC ivdep
for (uint32_t local_row_offset = 0; local_row_offset < BK;
local_row_offset += row_stride_as) {
// @perf: bank conflicts here
const uint32_t global_a_offset =
dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
// const uint32_t global_a_offset =
// dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
// FIXME experimenting with global coalescing
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_as_col);
local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] =
A[global_a_offset];
// local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] =
// A[global_a_offset];
*local_a_tmp = *global_a;
global_a += row_stride_as;
local_a_tmp += BM * row_stride_as;
}
}
constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d;
const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col;
#pragma GCC unroll 2
const float *global_b = B + dim_n * (k + local_b_row) + global_b_col;
volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col;
#pragma GCC ivdep
for (uint32_t load_offset = 0; load_offset < BK;
load_offset += row_stride_b) {
const uint32_t global_b_offset =
dim_n * (k + local_b_row + load_offset) + global_b_col;
local_b[BN_d * (local_b_row + load_offset) + local_b_col] =
B[global_b_offset];
// const uint32_t global_b_offset =
// dim_n * (k + local_b_row + load_offset) + global_b_col;
// local_b[BN_d * (local_b_row + load_offset) + local_b_col] =
// B[global_b_offset];
*local_b_tmp = *global_b;
global_b += dim_n * row_stride_b;
local_b_tmp += BN_d * row_stride_b;
}
}
@@ -480,6 +501,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#pragma GCC unroll 1
for (uint32_t k = 0; k < dim_k; k += BK) {
// register volatile float *local_a_produce asm("t0");
// register volatile float *local_b_produce asm("t1");
// register volatile float *local_a_consume asm("t2");
// register volatile float *local_b_consume asm("t3");
volatile float *local_a_produce;
volatile float *local_b_produce;
volatile float *local_a_consume;