sgemm_tcore: Move C from regF->GMEM directly
This commit is contained in:
@@ -151,54 +151,25 @@ inline void write_results(volatile float *local_warp_results,
|
|||||||
int tid = thread_in_warp;
|
int tid = thread_in_warp;
|
||||||
int tg = tid / 4;
|
int tg = tid / 4;
|
||||||
|
|
||||||
asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0]));
|
// these are [0, TCM/TCN)
|
||||||
asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1]));
|
|
||||||
asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2]));
|
|
||||||
asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3]));
|
|
||||||
asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4]));
|
|
||||||
asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5]));
|
|
||||||
asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6]));
|
|
||||||
asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7]));
|
|
||||||
|
|
||||||
/*
|
|
||||||
col = ((threadgroup % 4) // 2) * 8
|
|
||||||
row = (threadgroup * 8) % 16
|
|
||||||
row += (threadgroup // 4) * 4
|
|
||||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
|
||||||
offset = offsets[register-16]
|
|
||||||
row += offset[0]
|
|
||||||
col += offset[1]
|
|
||||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
|
||||||
thread_offset = thread_offsets[thread % 4]
|
|
||||||
row += thread_offset[0]
|
|
||||||
col += thread_offset[1]
|
|
||||||
return (row, col)
|
|
||||||
*/
|
|
||||||
|
|
||||||
int local_row = 0;
|
int local_row = 0;
|
||||||
int local_col = 0;
|
int local_col = 0;
|
||||||
|
|
||||||
map_c_32lanes(tid, local_row, local_col);
|
map_c_32lanes(tid, local_row, local_col);
|
||||||
|
|
||||||
// C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) +
|
|
||||||
// (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] =
|
|
||||||
// reg_c[TN * res_idx_m + res_idx_n];
|
|
||||||
// float *global_offset_C = C +
|
|
||||||
// (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n +
|
|
||||||
// threadblock_id_x * TCN * 2 + warp_x * TCN;
|
|
||||||
float *global_offset_C = C +
|
float *global_offset_C = C +
|
||||||
(BM * threadblock_id_y /* 1 warp */) * dim_n +
|
(BM * threadblock_id_y) * dim_n +
|
||||||
BN * threadblock_id_x /* 1 warp */;
|
BN * threadblock_id_x;
|
||||||
for (int i = 0; i < 8; i += 1) {
|
|
||||||
int row_offset = ((i / 2) % 2) * 2;
|
|
||||||
int col_offset = (i / 4) * 4 + i % 2;
|
|
||||||
|
|
||||||
int adjusted_local_row = local_row + row_offset;
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
int adjusted_local_col = local_col + col_offset;
|
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||||
|
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||||
// FIXME: do we need to store to SMEM at all?
|
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||||
float v = local_warp_results[tid * 8 + i];
|
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
||||||
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
||||||
}
|
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
||||||
|
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||||
|
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
|
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
|
||||||
|
|||||||
Reference in New Issue
Block a user