sgemm_gemmini_dma: Update activation to match tcore

This commit is contained in:
Hansung Kim
2024-06-18 15:30:12 -07:00
parent 50b843d8c4
commit b586e0f881

View File

@@ -145,47 +145,48 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
const uint32_t row_in_warptile = 0;
const uint32_t C_row = (tile_i * TILE_M) + (warp_row * WM) + row_in_warptile;
const uint32_t C_col = (tile_j * TILE_N) + (warp_col * WN) + col_in_warptile;
const float *global_C = C + dim_n * C_row + C_col;
const float *const global_C = C + dim_n * C_row + C_col;
// read in elements from GMEM to RF
asm volatile("flw f0, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f1, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f2, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f3, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f4, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f5, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f6, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f7, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f8, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f9, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f10, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f11, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f12, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f13, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f14, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("flw f15, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("mv t6, %0" :: "r"(global_C));
asm volatile("flw f0, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f1, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f2, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f3, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f4, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f5, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f6, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f7, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f8, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f9, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f10, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f11, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f12, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f13, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f14, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("flw f15, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fcvt.s.w f16, %0, rtz" :: "r"(2));
// do elem-wise compute in RF
#pragma GCC unroll
for (uint32_t count = 0; count < 8; count++) {
#pragma GCC unroll 4
for (uint32_t count = 0; count < 128; count++) {
asm volatile("fmul.s f0, f0, f16");
asm volatile("fmul.s f1, f1, f16");
asm volatile("fmul.s f2, f2, f16");
@@ -205,40 +206,39 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
}
// move back from RF to GMEM
global_C = C + dim_n * C_row + C_col;
asm volatile("fsw f0, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f1, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f2, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f3, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f4, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f5, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f6, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f7, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f8, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f9, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f10, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f11, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f12, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f13, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f14, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("fsw f15, (%0)" :: "r"(global_C));
global_C += dim_n;
asm volatile("mv t6, %0" :: "r"(global_C));
asm volatile("fsw f0, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f1, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f2, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f3, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f4, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f5, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f6, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f7, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f8, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f9, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f10, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f11, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f12, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f13, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f14, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
asm volatile("fsw f15, (t6)");
asm volatile("add t6, t6, %0" :: "r"(dim_n * sizeof(float)));
}
if (HW_TID() == 0) {