sgemm_tg: 1-octet 8-lane kernel

This commit is contained in:
Hansung Kim
2024-05-13 14:52:33 -07:00
parent d848e88f72
commit 09b23ffe87

View File

@@ -18,14 +18,16 @@
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN
// NOTE(review): BM, TCM and TCN are each defined twice below. This listing
// comes from a commit diff, so the first value of each pair (16) is
// presumably the removed pre-commit line and the second (8) the value now in
// effect -- confirm against the checked-out file.
//
// BM / BN: dimensions of the output block; a later comment in this file
// assumes the block is square (BM == BN). BM is also used as the row count of
// the shared-memory A tile.
#define BM 16
#define BM 8
#define BN BM
// BK: column count of the shared-memory A tile.
#define BK 8
// TCM / TCN: upper bounds of the per-thread local row/column produced by the
// map_c lane mapping (local_row/local_col are in [0, TCM/TCN)).
#define TCM 16
#define TCN 16
#define TCM 8
#define TCN 8
// TM / TN: lengths of the per-thread reg_a / reg_b register arrays.
#define TM 1
#define TN 1
// NUM_LANES: selects the 32-lane or 8-lane mapping functions at compile time
// and is the divisor used to split a thread's index within the threadblock
// into a warp index and an index within that warp.
#define NUM_LANES 8
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
@@ -67,6 +69,16 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
col += tg * 4;
}
// Dispatches to the operand lane-mapping routine that matches the
// compile-time NUM_LANES setting.
//
// tid:      thread index, forwarded unchanged to the selected mapper.
//           NOTE(review): presumably the thread's index within its warp --
//           confirm against callers.
// row, col: passed by reference to the selected mapper, which may update
//           them in place.
//
// Only NUM_LANES == 32 and NUM_LANES == 8 have a mapper. Any other value used
// to compile into a silent no-op that left row/col untouched; it is now
// rejected at compile time instead.
inline constexpr void map_operand(const int tid, int &row, int &col) {
  static_assert(NUM_LANES == 32 || NUM_LANES == 8,
                "map_operand: NUM_LANES must be 8 or 32");
  if constexpr (NUM_LANES == 32) {
    map_operand_32lanes(tid, row, col);
  } else {
    // The static_assert above guarantees NUM_LANES == 8 on this path.
    map_operand_8lanes(tid, row, col);
  }
}
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
@@ -93,6 +105,16 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
col += ((tid % 4) / 2) * 2;
}
// Dispatches to the result (C) lane-mapping routine that matches the
// compile-time NUM_LANES setting.
//
// tid:      thread index, forwarded unchanged to the selected mapper.
//           NOTE(review): presumably the thread's index within its warp --
//           confirm against callers.
// row, col: passed by reference to the selected mapper, which may update
//           them in place.
//
// Only NUM_LANES == 32 and NUM_LANES == 8 have a mapper. Any other value used
// to compile into a silent no-op that left row/col untouched; it is now
// rejected at compile time instead.
inline constexpr void map_c(const int tid, int &row, int &col) {
  static_assert(NUM_LANES == 32 || NUM_LANES == 8,
                "map_c: NUM_LANES must be 8 or 32");
  if constexpr (NUM_LANES == 32) {
    map_c_32lanes(tid, row, col);
  } else {
    // The static_assert above guarantees NUM_LANES == 8 on this path.
    map_c_8lanes(tid, row, col);
  }
}
// Emits a single custom RISC-V R-type instruction through the assembler's
// `.insn r` directive: opcode RISCV_CUSTOM3, funct3 = 0, funct7 = 0, and
// rd, rs1, rs2 all set to x0. The opcode is supplied as an immediate ("i")
// input operand; no outputs or clobbers are declared.
// NOTE(review): judging by the name, this presumably triggers the hardware
// matrix multiply-accumulate step -- confirm against the ISA extension spec.
// NOTE(review): there is no "memory" clobber, so the compiler is not told
// that this instruction reads or writes memory -- verify that is intended.
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
@@ -104,7 +126,7 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
int row = 0;
int col = 0;
map_operand_32lanes(tid, row, col);
map_operand(tid, row, col);
int smem_A_rows = BM;
int smem_A_cols = BK;
@@ -154,8 +176,7 @@ inline void write_results(volatile float *local_warp_results,
// these are [0, TCM/TCN)
int local_row = 0;
int local_col = 0;
map_c_32lanes(tid, local_row, local_col);
map_c(tid, local_row, local_col);
float *global_offset_C = C +
(BM * threadblock_id_y) * dim_n +
@@ -189,19 +210,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const float *B = (const float *)arg->addr_b;
float *C = (float *)arg->addr_c;
// assumes NT == NW == matrix_dim
const uint32_t dim_m = arg->dim_m;
const uint32_t dim_n = arg->dim_n;
const uint32_t dim_k = arg->dim_k;
// FIXME: Output block size is assumed to be square, i.e. BM == BN
// const uint32_t BM = threadblock_dim_y;
// const uint32_t BN = threadblock_dim_y;
// const uint32_t BK = threadblock_dim_x;
// constexpr uint32_t BM = 8;
// constexpr uint32_t BN = 8;
// constexpr uint32_t BK = 2;
const uint32_t local_a_row = tid_in_threadblock / BK;
const uint32_t local_a_col = tid_in_threadblock % BK;
const uint32_t local_b_row = tid_in_threadblock / BN;
@@ -217,8 +229,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
float reg_a[TM] = { 0.0f };
float reg_b[TN] = { 0.0f };
const uint32_t warp_in_threadblock = tid_in_threadblock / 32;
const uint32_t tid_in_warp = tid_in_threadblock % 32;
const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES;
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
const uint32_t warp_x = warp_in_threadblock % 2;
const uint32_t warp_y = warp_in_threadblock / 2;
@@ -272,38 +284,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
vx_wmma();
}
#if 0
// Compute single tile*tile matmul
#pragma GCC unroll 4
for (uint32_t local_k = 0; local_k < BK; local_k++) {
// First, pump data from SMEM->RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
reg_a[res_idx_m] =
local_a[BK * (TM * local_c_row + res_idx_m) + local_k];
}
#pragma GCC unroll TN
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
reg_b[res_idx_n] =
local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
}
// Next, compute multiple result elements (TM*TN) by reusing data in RF
#pragma GCC unroll TM
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
#pragma GCC unroll TN
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
// NOTE use of local_b_row
reg_c[TN * res_idx_m + res_idx_n] +=
reg_a[res_idx_m] * reg_b[res_idx_n];
// reg_c[TN * res_idx_m + res_idx_n] +=
// local_a[BK * (TM * local_c_row + res_idx_m) + local_k] *
// local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
}
}
}
#endif
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
threadblock_dim_y);
}