diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 06e3a579..5c141e01 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -280,11 +280,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a = sharedmem_per_threadblock; constexpr size_t local_a_elems = (BM * BK); - volatile float *local_b = sharedmem_per_threadblock + local_a_elems; - constexpr size_t local_b_elems = (BK * BN); + volatile float *local_a_buf = local_a + local_a_elems; - volatile float *local_a_buf = local_b + local_b_elems; - volatile float *local_b_buf = local_a_buf + local_a_elems; + volatile float *local_b = local_a_buf + local_a_elems; + constexpr size_t local_b_elems = (BK * BN); + volatile float *local_b_buf = local_b + local_b_elems; /* must equal local_b + local_b_elems: local_b_consume selects it via (block_k & 1) * local_b_elems */ constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, @@ -453,8 +453,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); - local_a_consume = local_a + (block_k & 1) * (local_a_elems + local_b_elems); - local_b_consume = local_b + (block_k & 1) * (local_b_elems); + local_a_consume = local_a + (block_k & 1) * (local_a_elems); + local_b_consume = local_b + (block_k & 1) * (local_b_elems); } else { local_a_consume = local_a; local_b_consume = local_b;