sgemm test

This commit is contained in:
Euna Kim
2019-11-24 22:58:04 -05:00
parent ec068d2428
commit 4ce79b24b4
5 changed files with 62 additions and 11 deletions

View File

@@ -10,7 +10,7 @@
# a2 a
# a3 b
vx_vec_saxpy:
vsetvli a4, a0, e32, m8
vsetvli a4, a0, e32
loop:
vlw.v v0, (a2)
sub a0, a0, a4

View File

@@ -16,7 +16,7 @@ VX_IO = $(LIB_PATH)/io/vx_io.s $(LIB_PATH)/io/vx_io.c
VX_API = $(LIB_PATH)/vx_api/vx_api.c
VX_TEST = $(LIB_PATH)/tests/tests.c
VX_FIO = $(LIB_PATH)/fileio/fileio.s
VX_VEC = vx_vec_sgemm_nn.s #float --> int
VX_VEC = test_asm.s #vx_vec_sgemm_nn_backup.s #float --> int
LIBS = /nethome/ekim79/riscv-gnu-toolchain/drops/riscv32-unknown-elf/lib/libc.a /nethome/ekim79/riscv-gnu-toolchain/drops/riscv32-unknown-elf/lib/libstdc++.a -static-libgcc -lgcc
VX_MAIN = vx_vec_sgemm_nn

View File

@@ -0,0 +1,38 @@
.type vx_vec_sgemm_nn, @function
.global vx_vec_sgemm_nn
#-----------------------------------------------------------------------
# vx_vec_sgemm_nn -- one vectorised multiply-accumulate step of a GEMM
# on 32-bit elements:
#
#     c[n*ldc + i + j] += a[n*ldc + m] * b[m*ldc + i + j]   for j in [0, vl)
#
# The caller drives the loop nest, invoking this routine once per
# (n, m, i) and advancing i between calls, e.g.:
#
#     for (n ...) for (m ...) for (i ...; i += step)
#         vx_vec_sgemm_nn(i, m, n, a, b, c, ldc, ...);
#
# In:    a0 = i    column (element) offset of the strip in b and c
#        a1 = m    column index into a / row index into b
#        a2 = n    row index into a and c
#        a3 = a    base address of a
#        a4 = b    base address of b
#        a5 = c    base address of c
#        a6 = ldc  row stride in elements, applied to a, b and c alike
# Out:   the strip of c is updated in memory; no value is returned.
# Clobb: t0-t3, a3-a5, v0, v2, v3, vector length/type state.
#
# Only caller-saved temporaries are used as scratch. ra (x1), sp (x2),
# gp (x3) and tp (x4) must not be written here: ra holds the return
# address consumed by `ret`, and the others are owned by the runtime.
#-----------------------------------------------------------------------
vx_vec_sgemm_nn:
        # NOTE(review): the requested vector length comes from a6 (ldc),
        # as before. If the C prototype passes a separate strip width as
        # an eighth argument it would presumably arrive in a7 -- confirm
        # which one is intended as the AVL.
        vsetvli t0, a6, e32

        mul     t1, a2, a6              # t1 = n*ldc (row offset shared by a and c)

        # Scalar operand: a[n*ldc + m]
        add     t2, t1, a1              # element index m + n*ldc
        slli    t2, t2, 2               # elements are 4 bytes wide -> byte offset
        add     a3, a3, t2
        lw      t3, 0(a3)               # t3 = a[n*ldc + m]

        # Vector operand: b[m*ldc + i ...]
        mul     t2, a1, a6              # m*ldc
        add     t2, t2, a0              # element index i + m*ldc
        slli    t2, t2, 2               # byte offset
        add     a4, a4, t2
        vlw.v   v0, (a4)                # v0 = strip of b

        vmul.vx v2, v0, t3              # v2 = b strip * a scalar

        # Accumulator: c[n*ldc + i ...]
        add     t1, t1, a0              # element index i + n*ldc
        slli    t1, t1, 2               # byte offset
        add     a5, a5, t1
        vlw.v   v3, (a5)                # v3 = current strip of c
        vadd.vv v3, v3, v2              # accumulate the product
        vsw.v   v3, (a5)                # write the updated strip back to c
        ret

View File

@@ -16,9 +16,9 @@ int main()
{
vx_tmc(1);
int m = 3;
int k = 3;
int n = 3;
int m = 4;
int k = 4;
int n = 4;
int* a1 = (int*)malloc(sizeof(int) * m * k);
int* b1 = (int*)malloc(sizeof(int) * k * n);
@@ -31,7 +31,7 @@ int main()
for (int i = 0; i < (m * n); ++i) d1[i] = 0;
#if 1
#if 0
printf("sgemm_nn\na[%d]:", m*k);
for (int i = 0; i < m*k; ++i) {
if(!(i % k)) printf("\n");
@@ -44,11 +44,24 @@ int main()
}
#endif
vx_vec_sgemm_nn(n, m, k, a1, b1, c1);
int lda = 4;
int ldb = 4;
int ldc = 4; //64;
int vsize = 4;
for (int r = 0; r < k; r++) {
for (int c = 0; c < m; c++) {
for (int i = 0; i < n;) {
// d1[r*k+i] += a1[r*k+c]*b1[i*n+c];
vx_vec_sgemm_nn(i, c, r, a1, b1, c1, ldc, vsize);
i = i + vsize;
}
}
}
// vx_vec_sgemm_nn(n, a1, b1, c1);
#if 1
printf("\n\nc[%d]:\n", m*n);
printf("\n\nc[%d]:", m*n);
for (int i = 0; i < m*n; ++i) {
if (!(i % n)) printf("\n");
printf("%d ", c1[i]);
@@ -63,11 +76,11 @@ int main()
}
}
#if 1
#if 0
printf("\n\nc[%d]:\n", m*n);
for(int i = 0; i < m; ++i) {
for(int j = 0; j < n; ++j) {
printf("%d ", c1[i*m+j]);
printf("%d ", d1[i*m+j]);
}
printf("\n");
}

View File

@@ -6,7 +6,7 @@ extern "C" {
#endif
//void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int lda, int* b1, int ldb, int* c1, int ldc);
void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int* b1, int* c1);
void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int* b1, int* c1, int ldc, int vsize);
//void vx_vec_sgemm_nn(int n, int* a1, int* b1, int* c1);
#ifdef __cplusplus
}