diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 99621ec2..8913d95a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -151,54 +151,25 @@ inline void write_results(volatile float *local_warp_results, int tid = thread_in_warp; int tg = tid / 4; - asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); - asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); - asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); - asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); - asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); - asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); - asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); - asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); - - /* - col = ((threadgroup % 4) // 2) * 8 - row = (threadgroup * 8) % 16 - row += (threadgroup // 4) * 4 - offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] - offset = offsets[register-16] - row += offset[0] - col += offset[1] - thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] - thread_offset = thread_offsets[thread % 4] - row += thread_offset[0] - col += thread_offset[1] - return (row, col) - */ - + // these are [0, TCM/TCN) int local_row = 0; int local_col = 0; + map_c_32lanes(tid, local_row, local_col); - // C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + - // (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = - // reg_c[TN * res_idx_m + res_idx_n]; - // float *global_offset_C = C + - // (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n + - // threadblock_id_x * TCN * 2 + warp_x * TCN; float *global_offset_C = C + - (BM * threadblock_id_y /* 1 warp */) * dim_n + - BN * threadblock_id_x /* 1 warp */; - for (int i = 0; i < 8; i += 1) { - int row_offset = ((i / 2) % 2) * 2; - int col_offset = (i / 4) * 4 + i % 2; + (BM * threadblock_id_y) * dim_n + + BN * threadblock_id_x; - int adjusted_local_row = local_row + row_offset; - int adjusted_local_col = local_col + col_offset; - - // FIXME: do we need to store to SMEM at all? - float v = local_warp_results[tid * 8 + i]; - global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; - } + // @perf: this likely causes a lot of gmem bank conflicts + asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {