sgemm_tcore: Replace hardcoded NUM_LANES with NUM_THREADS
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
|
||||
#define DOUBLE_BUFFER 1
|
||||
#undef ELEM_PER_THREAD
|
||||
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1))
|
||||
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_THREADS) / (DOUBLE_BUFFER ? 2 : 1))
|
||||
|
||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||
@@ -291,11 +291,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1);
|
||||
const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup;
|
||||
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME
|
||||
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES;
|
||||
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_THREADS;
|
||||
// FIXME: warp_row / BN should be warp-specialized?
|
||||
const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
|
||||
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||
|
||||
Reference in New Issue
Block a user