sgemm_impl: 128x64 tile; fix unrolled asm, comment out actual gemm
This commit is contained in:
@@ -29,7 +29,7 @@ using float_type = float16_t;
|
|||||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
#define BM 64
|
#define BM 128
|
||||||
#define BN 64
|
#define BN 64
|
||||||
#if (FP_SIZE == 32)
|
#if (FP_SIZE == 32)
|
||||||
#define BK 64
|
#define BK 64
|
||||||
@@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
|||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 1
|
#define GEMMINI_DMA 0
|
||||||
#define GEMMINI_DMA_MN_MAJOR 1
|
#define GEMMINI_DMA_MN_MAJOR 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
@@ -299,14 +299,23 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
(WM * warp_row + TCM * wm_iter) + row]);
|
(WM * warp_row + TCM * wm_iter) + row]);
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
|
// asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
// asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
|
smem_addr += smem_AS_cols * 4 * sizeof(float);
|
||||||
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
} else {
|
} else {
|
||||||
static_assert(layout ==
|
static_assert(layout ==
|
||||||
MemLayout::K_major /* fake cond that is always false */,
|
MemLayout::K_major /* fake cond that is always false */,
|
||||||
@@ -638,34 +647,67 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
|||||||
// need to branch because address offset constant in the inline assembly
|
// need to branch because address offset constant in the inline assembly
|
||||||
// cannot be larger than a certain limit
|
// cannot be larger than a certain limit
|
||||||
if constexpr (!transposed_write) {
|
if constexpr (!transposed_write) {
|
||||||
|
// asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// local += smem_dim_col * row_stride * 2;
|
||||||
|
// asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// local += smem_dim_col * row_stride * 2;
|
||||||
|
// asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// local += smem_dim_col * row_stride * 2;
|
||||||
|
// asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
||||||
|
// sizeof(float)),
|
||||||
|
// "r"(local));
|
||||||
|
// local += smem_dim_col * row_stride * 2;
|
||||||
|
|
||||||
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
local += smem_dim_col * row_stride;
|
||||||
|
asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
local += smem_dim_col * row_stride * 2;
|
local += smem_dim_col * row_stride;
|
||||||
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
local += smem_dim_col * row_stride;
|
||||||
|
asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
local += smem_dim_col * row_stride * 2;
|
local += smem_dim_col * row_stride;
|
||||||
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
local += smem_dim_col * row_stride;
|
||||||
|
asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
local += smem_dim_col * row_stride * 2;
|
local += smem_dim_col * row_stride;
|
||||||
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 *
|
local += smem_dim_col * row_stride;
|
||||||
|
asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||||
sizeof(float)),
|
sizeof(float)),
|
||||||
"r"(local));
|
"r"(local));
|
||||||
local += smem_dim_col * row_stride * 2;
|
local += smem_dim_col * row_stride;
|
||||||
} else {
|
} else {
|
||||||
// currently, tensor core hardware only supports MN-major SMEM tile
|
// currently, tensor core hardware only supports MN-major SMEM tile
|
||||||
// layout for correct results
|
// layout for correct results
|
||||||
@@ -996,6 +1038,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if 0
|
||||||
// consumer code: SMEM->RF and compute
|
// consumer code: SMEM->RF and compute
|
||||||
// ----------------------------------------------------------------------
|
// ----------------------------------------------------------------------
|
||||||
// @perf: this loop spills to stack a lot because of all the flws in
|
// @perf: this loop spills to stack a lot because of all the flws in
|
||||||
@@ -1044,6 +1087,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_gmem) {
|
if constexpr (write_to_gmem) {
|
||||||
|
|||||||
Reference in New Issue
Block a user