sgemm_tcore: Fix invocation with compile time threadblock size

This commit is contained in:
Hansung Kim
2024-09-02 17:03:46 -07:00
parent 70273fd00d
commit 9d71fa44a7
2 changed files with 29 additions and 26 deletions

View File

@@ -736,29 +736,27 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
}
}
template <typename T, bool write_to_gmem = true,
// by default, A/B tiles are placed at the start of the smem
uint32_t smem_a_offset = 0, // byte offset of A tile in shared
// memory
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
// double-buffer tile in shared
// memory
uint32_t smem_b_offset = sizeof(float) * BM *
BK, // byte offset of B tile
// in shared memory
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
BK // byte offset of B double-buffer
// tile in shared memory
>
template <
typename T, uint32_t threads_per_threadblock, bool write_to_gmem = true,
// by default, A/B tiles are placed at the start of the smem
uint32_t smem_a_offset = 0, // byte offset of A tile in shared
// memory
uint32_t smem_a_dbuf_offset = 0, // byte offset of A
// double-buffer tile in shared
// memory
uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile
// in shared memory
uint32_t smem_b_dbuf_offset = sizeof(float) * BM *
BK // byte offset of B double-buffer
// tile in shared memory
>
inline void thread_block_gemm(const T *A, const T *B, float *C,
const uint32_t dim_m, const uint32_t dim_n,
const uint32_t dim_k,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
const uint32_t threadblocks_per_cluster,
const uint32_t threadblock_id_in_cluster,
uint8_t *sharedmem_per_threadblock) {
// no double-buffering
const uint32_t threads_per_warpgroup = threads_per_threadblock;
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);