sgemm_tcore: Move C from regF->GMEM directly

This commit is contained in:
Hansung Kim
2024-05-13 14:00:50 -07:00
parent 9e60b1834c
commit d848e88f72

View File

@@ -151,54 +151,25 @@ inline void write_results(volatile float *local_warp_results,
int tid = thread_in_warp;
int tg = tid / 4;
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
/*
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
return (row, col)
*/
// these are [0, TCM/TCN)
int local_row = 0;
int local_col = 0;
map_c_32lanes(tid, local_row, local_col);
// C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) +
// (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] =
// reg_c[TN * res_idx_m + res_idx_n];
// float *global_offset_C = C +
// (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n +
// threadblock_id_x * TCN * 2 + warp_x * TCN;
float *global_offset_C = C +
(BM * threadblock_id_y /* 1 warp */) * dim_n +
BN * threadblock_id_x /* 1 warp */;
for (int i = 0; i < 8; i += 1) {
int row_offset = ((i / 2) % 2) * 2;
int col_offset = (i / 4) * 4 + i % 2;
(BM * threadblock_id_y) * dim_n +
BN * threadblock_id_x;
int adjusted_local_row = local_row + row_offset;
int adjusted_local_col = local_col + col_offset;
// FIXME: do we need to store to SMEM at all?
float v = local_warp_results[tid * 8 + i];
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
}
// @perf: this likely causes a lot of gmem bank conflicts
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
}
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {