sgemm_tcore: Fix fp16 addr gen in vx_wmma_load
This commit is contained in:
@@ -37,6 +37,9 @@
|
||||
#error "threadblock size too big for cluster"
|
||||
#endif
|
||||
|
||||
// using float_type = float;
|
||||
using float_type = float16_t;
|
||||
|
||||
template <typename T>
|
||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const T *A, const T *B,
|
||||
@@ -391,7 +394,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
}
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t block_k = 0; (block_k * BK) < (dim_k); block_k++) {
|
||||
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
||||
|
||||
// producer code: GMEM->SMEM memory movement
|
||||
// ---------------------------------------------------------------------
|
||||
@@ -572,8 +575,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD);
|
||||
const uint32_t num_threadblocks = problem_size / threads_per_threadblock;
|
||||
|
||||
using float_type = float16_t;
|
||||
|
||||
// "static" shared memory allocation. This would determine threadblock
|
||||
// occupancy of a single cluster
|
||||
uint8_t *sharedmem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||
|
||||
Reference in New Issue
Block a user