sgemm_tg: 1-octet 8-lane kernel
This commit is contained in:
@@ -18,14 +18,16 @@
|
||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock size, with BN == BM, yields
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 16
|
||||
#define BM 8
|
||||
#define BN BM
|
||||
#define BK 8
|
||||
#define TCM 16
|
||||
#define TCN 16
|
||||
#define TCM 8
|
||||
#define TCN 8
|
||||
#define TM 1
|
||||
#define TN 1
|
||||
|
||||
#define NUM_LANES 8
|
||||
|
||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -67,6 +69,16 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
|
||||
col += tg * 4;
|
||||
}
|
||||
|
||||
inline constexpr void map_operand(const int tid, int &row, int &col) {
|
||||
if constexpr (NUM_LANES == 32) {
|
||||
map_operand_32lanes(tid, row, col);
|
||||
} else if constexpr (NUM_LANES == 8) {
|
||||
map_operand_8lanes(tid, row, col);
|
||||
} else {
|
||||
// FIXME: not allowed
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -93,6 +105,16 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
}
|
||||
|
||||
inline constexpr void map_c(const int tid, int &row, int &col) {
|
||||
if constexpr (NUM_LANES == 32) {
|
||||
map_c_32lanes(tid, row, col);
|
||||
} else if constexpr (NUM_LANES == 8) {
|
||||
map_c_8lanes(tid, row, col);
|
||||
} else {
|
||||
// FIXME: not allowed
|
||||
}
|
||||
}
|
||||
|
||||
inline void vx_wmma() {
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
@@ -104,7 +126,7 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x,
|
||||
|
||||
int row = 0;
|
||||
int col = 0;
|
||||
map_operand_32lanes(tid, row, col);
|
||||
map_operand(tid, row, col);
|
||||
|
||||
int smem_A_rows = BM;
|
||||
int smem_A_cols = BK;
|
||||
@@ -154,8 +176,7 @@ inline void write_results(volatile float *local_warp_results,
|
||||
// these are [0, TCM/TCN)
|
||||
int local_row = 0;
|
||||
int local_col = 0;
|
||||
|
||||
map_c_32lanes(tid, local_row, local_col);
|
||||
map_c(tid, local_row, local_col);
|
||||
|
||||
float *global_offset_C = C +
|
||||
(BM * threadblock_id_y) * dim_n +
|
||||
@@ -189,19 +210,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const float *B = (const float *)arg->addr_b;
|
||||
float *C = (float *)arg->addr_c;
|
||||
|
||||
// assumes NT == NW == matrix_dim
|
||||
const uint32_t dim_m = arg->dim_m;
|
||||
const uint32_t dim_n = arg->dim_n;
|
||||
const uint32_t dim_k = arg->dim_k;
|
||||
|
||||
// FIXME: Output block size is assumed to be square, i.e. BM == BN
|
||||
// const uint32_t BM = threadblock_dim_y;
|
||||
// const uint32_t BN = threadblock_dim_y;
|
||||
// const uint32_t BK = threadblock_dim_x;
|
||||
// constexpr uint32_t BM = 8;
|
||||
// constexpr uint32_t BN = 8;
|
||||
// constexpr uint32_t BK = 2;
|
||||
|
||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
||||
@@ -217,8 +229,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
float reg_a[TM] = { 0.0f };
|
||||
float reg_b[TN] = { 0.0f };
|
||||
|
||||
const uint32_t warp_in_threadblock = tid_in_threadblock / 32;
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % 32;
|
||||
const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES;
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
|
||||
const uint32_t warp_x = warp_in_threadblock % 2;
|
||||
const uint32_t warp_y = warp_in_threadblock / 2;
|
||||
|
||||
@@ -272,38 +284,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
vx_wmma();
|
||||
}
|
||||
|
||||
#if 0
|
||||
// Compute single tile*tile matmul
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k++) {
|
||||
// First, pump data from SMEM->RF
|
||||
#pragma GCC unroll TM
|
||||
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
|
||||
reg_a[res_idx_m] =
|
||||
local_a[BK * (TM * local_c_row + res_idx_m) + local_k];
|
||||
}
|
||||
#pragma GCC unroll TN
|
||||
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
|
||||
reg_b[res_idx_n] =
|
||||
local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
|
||||
}
|
||||
|
||||
// Next, compute multiple result elements (TM*TN) by reusing data in RF
|
||||
#pragma GCC unroll TM
|
||||
for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) {
|
||||
#pragma GCC unroll TN
|
||||
for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) {
|
||||
// NOTE use of local_b_row
|
||||
reg_c[TN * res_idx_m + res_idx_n] +=
|
||||
reg_a[res_idx_m] * reg_b[res_idx_n];
|
||||
// reg_c[TN * res_idx_m + res_idx_n] +=
|
||||
// local_a[BK * (TM * local_c_row + res_idx_m) + local_k] *
|
||||
// local_b[BN * local_k + (TN * local_c_col + res_idx_n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
|
||||
threadblock_dim_y);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user