tensor: Attempt row-major mapping for C store (WIP)
Doesn't work because 1x2 jagged mapping is required to achieve throughput for storing the bigger C matrix (2x4, vs. 2x2 in A).
This commit is contained in:
@@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
|||||||
col += ((tid % 4) / 2) * 2;
|
col += ((tid % 4) / 2) * 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) {
|
||||||
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// A (row major)
|
||||||
|
// row 0~ 3: threadgroup 0
|
||||||
|
// row 4~ 7: threadgroup 1
|
||||||
|
row = tid % 4;
|
||||||
|
row += tg * 4;
|
||||||
|
|
||||||
|
// B (column major)
|
||||||
|
// col 0~ 3: threadgroup 0
|
||||||
|
// col 4~ 7: threadgroup 1
|
||||||
|
col = tid % 4;
|
||||||
|
col += tg * 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void vx_wmma_load() {
|
void vx_wmma_load() {
|
||||||
int tid = vx_thread_id();
|
int tid = vx_thread_id();
|
||||||
int tg = tid / 4;
|
int tg = tid / 4;
|
||||||
@@ -174,11 +191,31 @@ void store_wmma_result() {
|
|||||||
int row = 0;
|
int row = 0;
|
||||||
int col = 0;
|
int col = 0;
|
||||||
|
|
||||||
map_c_8lanes(tid, row, col);
|
// map_c_8lanes(tid, row, col);
|
||||||
|
map_c_rowmajor_8lanes(tid, row, col);
|
||||||
|
|
||||||
// store C
|
// store C
|
||||||
float *const results_wid = results + (DIM_M * DIM_N * wid);
|
float *const results_wid = results + (DIM_M * DIM_N * wid);
|
||||||
// uncomment to have two accum buffers in rf
|
|
||||||
|
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||||
|
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||||
|
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||||
|
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||||
|
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||||
|
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||||
|
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||||
|
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||||
|
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||||
|
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||||
|
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||||
|
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||||
|
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||||
|
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||||
|
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||||
|
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||||
|
|
||||||
|
|
||||||
|
// 1x2 jagged mapping
|
||||||
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||||
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||||
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||||
@@ -187,14 +224,14 @@ void store_wmma_result() {
|
|||||||
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||||
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||||
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||||
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
// asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||||
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
// asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||||
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
// asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||||
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
// asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
||||||
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
// asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
||||||
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
// asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||||
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
// asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||||
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
// asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_wmma_result() {
|
void print_wmma_result() {
|
||||||
|
|||||||
Reference in New Issue
Block a user