From 3bd9ffab661aad0506c840abc7e5195e6c5a0fe4 Mon Sep 17 00:00:00 2001 From: eunaeuna Date: Mon, 25 Nov 2019 15:42:48 -0500 Subject: [PATCH] sgemm fixed --- benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.c | 66 ++++++++++---------- benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.h | 2 +- benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.s | 50 +++++++-------- 3 files changed, 57 insertions(+), 61 deletions(-) diff --git a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.c b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.c index 4df5042f..467b29a6 100644 --- a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.c +++ b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.c @@ -16,30 +16,30 @@ int main() { vx_tmc(1); - int m = 4; - int k = 4; - int n = 4; + int w = 4; + int h = 4; + int d = 4; - int* a1 = (int*)malloc(sizeof(int) * m * k); - int* b1 = (int*)malloc(sizeof(int) * k * n); - int* c1 = (int*)malloc(sizeof(int) * m * n); - int* d1 = (int*)malloc(sizeof(int) * m * n); //verfication + int* a1 = (int*)malloc(sizeof(int) * w * h); + int* b1 = (int*)malloc(sizeof(int) * h * d); + int* c1 = (int*)malloc(sizeof(int) * w * d); + int* d1 = (int*)malloc(sizeof(int) * w * d); //verfication - for (int i = 0; i < (m * k); ++i) a1[i] = i; - for (int i = 0; i < (k * n); ++i) b1[i] = 1; - for (int i = 0; i < (m * n); ++i) c1[i] = 0; - for (int i = 0; i < (m * n); ++i) d1[i] = 0; + for (int i = 0; i < (w * h); ++i) a1[i] = i; + for (int i = 0; i < (h * d); ++i) b1[i] = 1; + for (int i = 0; i < (w * d); ++i) c1[i] = 0; + for (int i = 0; i < (w * d); ++i) d1[i] = 0; -#if 1 - printf("sgemm_nn\na[%d]:", m*k); - for (int i = 0; i < m*k; ++i) { - if(!(i % k)) printf("\n"); +#if 0 + printf("sgemm_nn\na[%d]:", w*h); + for (int i = 0; i < w*h; ++i) { + if(!(i % h)) printf("\n"); printf("%d ", a1[i]); } - printf("\n\nb[%d]:", k*n); - for (int i = 0; i < k*n; ++i) { - if (!(i % n)) printf("\n"); + printf("\n\nb[%d]:", h*d); + for (int i = 0; i < h*d; ++i) { + if (!(i % d)) printf("\n"); printf("%d ", b1[i]); } #endif @@ -49,31 +49,29 @@ int main() int ldc = 4; //64; int vsize = 4; - for (int r = 0; r < m; r++) { - for (int c = 0; c < n; c++) { - for (int i = 0; i < k;) { -// d1[r*k+i] += a1[r*k+c]*b1[i*n+c]; + for (int n = 0; n < h; n++) { + for (int i = 0; i < w; i=+4) { + for (int m = 0; m < d; m++) { + vx_vec_sgemm_nn(i, m, n, a1, b1, c1, ldc, vsize); + //d1[i+n*ldc] += a1[m+n*ldc]*b1[m*ldc+i]; vx_vec_sgemm_nn(i, r, c, a1, b1, c1, ldc, vsize); i = i + vsize; } } } -// vx_vec_sgemm_nn(n, a1, b1, c1); -#if 1 - printf("\n\nc[%d]:", m*n); - for (int i = 0; i < m*n; ++i) { - if (!(i % n)) printf("\n"); +#if 1 + printf("\n\nc[%d]:", d*h); + for (int i = 0; i < d*h; ++i) { + if (!(i % h)) printf("\n"); printf("%d ", c1[i]); } #endif - for (int r = 0; r < m; r++) { - for (int c = 0; c < n; c++) { - for (int i = 0; i < k; i++) { - d1[c*ldc+i] += a1[c*ldc+r]*b1[i + (r*ldc)]; - //printf("d[%d] += a[%d]*b[%d]\n", c*ldc+i, c*ldc+r , i + (r*ldc)); - //printf("%d %d %d\n", d1[c*ldc+i] , a1[c*ldc+r] , b1[i + (r*ldc)]); + for (int r = 0; r < h; r++) { + for (int c = 0; c < w; c++) { + for (int i = 0; i < d; i++) { + d1[r*h+i] += a1[r*h+c]*b1[i*d+c]; } } } @@ -89,7 +87,7 @@ int main() #endif - for(int i = 0; i < m*n; ++i) + for(int i = 0; i < w*d; ++i) { if(c1[i] != d1[i]) { diff --git a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.h b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.h index 34ce8a9f..7c2873e2 100644 --- a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.h +++ b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.h @@ -6,7 +6,7 @@ extern "C" { #endif //void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int lda, int* b1, int ldb, int* c1, int ldc); -void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int* b1, int* c1, int ldc, int vsize); +void vx_vec_sgemm_nn(int i, int m, int n, int* a1, int* b1, int* c1, int ldc, int vsize); //void vx_vec_sgemm_nn(int n, int* a1, int* b1, int* c1); #ifdef __cplusplus } diff --git a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.s b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.s index 11f299c2..639676ba 100644 --- a/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.s +++ b/benchmarks/vector/sgemm_nn/vx_vec_sgemm_nn.s @@ -10,33 +10,31 @@ # } # } # } -# a3 = a, a4 = b, a5 = c -# a0 = i, a1 = m, a2 = n -# a6 = ldc +# a0 = i, a1 = m, a2 = n, a3 = a, a4 = b, a5 = c, a6 = ldc, a7 = vsize +# vx_vec_sgemm_nn: - vsetvli t0, a7, e32 - mul t1, a6, a2 # n*ldc - add t2, t1, a1 # i + (n*ldc) - slli t2, t2, 2 - add a3, t2, a3 # a[i+ n*ldc] - lw t3, (a3) + vsetvli t0, a7, e32 # <--- vsize + mul x11, a6, a2 # n*ldc + add x12, x11, a1 # i + (n*ldc) + add a3, x12, a3 # a[i+ n*ldc] + lw x13, (a3) - mul t4, a1, a6 # m*ldc - add t5, a0, t4 # i + m*ldc - slli t5, t5, 2 - add a4, t5, a4 # b[i + m*ldc] - # lw x6, (a4) - - vlw.v v0, (a4) - vmul.vx v1, v0, t3 - - mul t6, a2, a6 # n*ldc - add t0, a0, t6 # i + n*ldc - slli t0, t0, 2 - add a5, t0, a5 # c[i + m*ldc] - - vlw.v v2, (a5) #c - vadd.vv v2, v2, v1 - vsw.v v2, (a5) + mul x14, a1, a6 # m*ldc + add x15, a0, x14 # i + m*ldc + add a4, x15, a4 # b[i + m*ldc] + vlw.v v0, (a4) + vmul.vx v2, v1, x13 +## lw x6, (a4) +# lw x10, (a4) # b +# mul x11, x3, x10 + mul x6, a2, a6 # n*ldc + add x7, a0, x6 # i + n*ldc + add a5, x7, a5 # c[i + m*ldc] + vlw.v v3, (a5) # c + vadd.vv v3, v3, v2 + vsw.v v3, (a5) +# lw x12, (a5) +# add x12, x12, x11 +# sw x12, (a5) ret