sgemm_gemmini_dma: Update activation to match tcore
This commit is contained in:
@@ -145,47 +145,48 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t row_in_warptile = 0;
|
||||
const uint32_t C_row = (tile_i * TILE_M) + (warp_row * WM) + row_in_warptile;
|
||||
const uint32_t C_col = (tile_j * TILE_N) + (warp_col * WN) + col_in_warptile;
|
||||
const float *global_C = C + dim_n * C_row + C_col;
|
||||
const float *const global_C = C + dim_n * C_row + C_col;
|
||||
|
||||
// read in elements from GMEM to RF
|
||||
asm volatile("flw f0, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f1, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f2, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f3, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f4, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f5, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f6, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f7, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f8, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f9, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f10, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f11, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f12, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f13, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f14, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("flw f15, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("mv t6, %0" :: "r"(global_C));
|
||||
asm volatile("flw f0, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f1, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f2, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f3, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f4, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f5, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f6, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f7, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f8, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f9, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f10, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f11, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f12, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f13, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f14, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("flw f15, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
|
||||
asm volatile("fcvt.s.w f16, %0, rtz" :: "r"(2));
|
||||
|
||||
// do elem-wise compute in RF
|
||||
#pragma GCC unroll
|
||||
for (uint32_t count = 0; count < 8; count++) {
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t count = 0; count < 128; count++) {
|
||||
asm volatile("fmul.s f0, f0, f16");
|
||||
asm volatile("fmul.s f1, f1, f16");
|
||||
asm volatile("fmul.s f2, f2, f16");
|
||||
@@ -205,40 +206,39 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
}
|
||||
|
||||
// move back from RF to gmem
|
||||
global_C = C + dim_n * C_row + C_col;
|
||||
asm volatile("fsw f0, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f1, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f2, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f3, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f4, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f5, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f6, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f7, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f8, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f9, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f10, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f11, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f12, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f13, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f14, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
asm volatile("fsw f15, (%0)" :: "r"(global_C));
|
||||
global_C += dim_n;
|
||||
|
||||
asm volatile("mv t6, %0" :: "r"(global_C));
|
||||
asm volatile("fsw f0, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f1, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f2, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f3, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f4, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f5, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f6, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f7, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f8, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f9, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f10, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f11, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f12, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f13, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f14, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
asm volatile("fsw f15, (t6)");
|
||||
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
|
||||
}
|
||||
|
||||
if (HW_TID() == 0) {
|
||||
|
||||
Reference in New Issue
Block a user