sgemm_tcore: Move C from regF->GMEM directly
This commit is contained in:
@@ -151,54 +151,25 @@ inline void write_results(volatile float *local_warp_results,
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
|
||||
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
|
||||
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
|
||||
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
|
||||
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
|
||||
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
|
||||
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
|
||||
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
|
||||
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
|
||||
|
||||
/*
|
||||
col = ((threadgroup % 4) // 2) * 8
|
||||
row = (threadgroup * 8) % 16
|
||||
row += (threadgroup // 4) * 4
|
||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||
offset = offsets[register-16]
|
||||
row += offset[0]
|
||||
col += offset[1]
|
||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||
thread_offset = thread_offsets[thread % 4]
|
||||
row += thread_offset[0]
|
||||
col += thread_offset[1]
|
||||
return (row, col)
|
||||
*/
|
||||
|
||||
// these are [0, TCM/TCN)
|
||||
int local_row = 0;
|
||||
int local_col = 0;
|
||||
|
||||
map_c_32lanes(tid, local_row, local_col);
|
||||
|
||||
// C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) +
|
||||
// (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] =
|
||||
// reg_c[TN * res_idx_m + res_idx_n];
|
||||
// float *global_offset_C = C +
|
||||
// (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n +
|
||||
// threadblock_id_x * TCN * 2 + warp_x * TCN;
|
||||
float *global_offset_C = C +
|
||||
(BM * threadblock_id_y /* 1 warp */) * dim_n +
|
||||
BN * threadblock_id_x /* 1 warp */;
|
||||
for (int i = 0; i < 8; i += 1) {
|
||||
int row_offset = ((i / 2) % 2) * 2;
|
||||
int col_offset = (i / 4) * 4 + i % 2;
|
||||
(BM * threadblock_id_y) * dim_n +
|
||||
BN * threadblock_id_x;
|
||||
|
||||
int adjusted_local_row = local_row + row_offset;
|
||||
int adjusted_local_col = local_col + col_offset;
|
||||
|
||||
// FIXME: do we need to store to SMEM at all?
|
||||
float v = local_warp_results[tid * 8 + i];
|
||||
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
||||
}
|
||||
// @perf: this likely causes a lot of gmem bank conflicts
|
||||
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
||||
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
||||
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
||||
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||
}
|
||||
|
||||
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
|
||||
|
||||
Reference in New Issue
Block a user