sgemm fixed
This commit is contained in:
@@ -16,30 +16,30 @@ int main()
|
|||||||
{
|
{
|
||||||
vx_tmc(1);
|
vx_tmc(1);
|
||||||
|
|
||||||
int m = 4;
|
int w = 4;
|
||||||
int k = 4;
|
int h = 4;
|
||||||
int n = 4;
|
int d = 4;
|
||||||
|
|
||||||
int* a1 = (int*)malloc(sizeof(int) * m * k);
|
int* a1 = (int*)malloc(sizeof(int) * w * h);
|
||||||
int* b1 = (int*)malloc(sizeof(int) * k * n);
|
int* b1 = (int*)malloc(sizeof(int) * h * d);
|
||||||
int* c1 = (int*)malloc(sizeof(int) * m * n);
|
int* c1 = (int*)malloc(sizeof(int) * w * d);
|
||||||
int* d1 = (int*)malloc(sizeof(int) * m * n); //verfication
|
int* d1 = (int*)malloc(sizeof(int) * w * d); //verfication
|
||||||
|
|
||||||
for (int i = 0; i < (m * k); ++i) a1[i] = i;
|
for (int i = 0; i < (w * h); ++i) a1[i] = i;
|
||||||
for (int i = 0; i < (k * n); ++i) b1[i] = 1;
|
for (int i = 0; i < (h * d); ++i) b1[i] = 1;
|
||||||
for (int i = 0; i < (m * n); ++i) c1[i] = 0;
|
for (int i = 0; i < (w * d); ++i) c1[i] = 0;
|
||||||
for (int i = 0; i < (m * n); ++i) d1[i] = 0;
|
for (int i = 0; i < (w * d); ++i) d1[i] = 0;
|
||||||
|
|
||||||
|
|
||||||
#if 1
|
#if 0
|
||||||
printf("sgemm_nn\na[%d]:", m*k);
|
printf("sgemm_nn\na[%d]:", w*h);
|
||||||
for (int i = 0; i < m*k; ++i) {
|
for (int i = 0; i < w*h; ++i) {
|
||||||
if(!(i % k)) printf("\n");
|
if(!(i % h)) printf("\n");
|
||||||
printf("%d ", a1[i]);
|
printf("%d ", a1[i]);
|
||||||
}
|
}
|
||||||
printf("\n\nb[%d]:", k*n);
|
printf("\n\nb[%d]:", h*d);
|
||||||
for (int i = 0; i < k*n; ++i) {
|
for (int i = 0; i < h*d; ++i) {
|
||||||
if (!(i % n)) printf("\n");
|
if (!(i % d)) printf("\n");
|
||||||
printf("%d ", b1[i]);
|
printf("%d ", b1[i]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -49,31 +49,29 @@ int main()
|
|||||||
int ldc = 4; //64;
|
int ldc = 4; //64;
|
||||||
int vsize = 4;
|
int vsize = 4;
|
||||||
|
|
||||||
for (int r = 0; r < m; r++) {
|
for (int n = 0; n < h; n++) {
|
||||||
for (int c = 0; c < n; c++) {
|
for (int i = 0; i < w; i=+4) {
|
||||||
for (int i = 0; i < k;) {
|
for (int m = 0; m < d; m++) {
|
||||||
// d1[r*k+i] += a1[r*k+c]*b1[i*n+c];
|
vx_vec_sgemm_nn(i, m, n, a1, b1, c1, ldc, vsize);
|
||||||
|
//d1[i+n*ldc] += a1[m+n*ldc]*b1[m*ldc+i];
|
||||||
vx_vec_sgemm_nn(i, r, c, a1, b1, c1, ldc, vsize);
|
vx_vec_sgemm_nn(i, r, c, a1, b1, c1, ldc, vsize);
|
||||||
i = i + vsize;
|
i = i + vsize;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// vx_vec_sgemm_nn(n, a1, b1, c1);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
printf("\n\nc[%d]:", m*n);
|
printf("\n\nc[%d]:", d*h);
|
||||||
for (int i = 0; i < m*n; ++i) {
|
for (int i = 0; i < d*h; ++i) {
|
||||||
if (!(i % n)) printf("\n");
|
if (!(i % h)) printf("\n");
|
||||||
printf("%d ", c1[i]);
|
printf("%d ", c1[i]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (int r = 0; r < m; r++) {
|
for (int r = 0; r < h; r++) {
|
||||||
for (int c = 0; c < n; c++) {
|
for (int c = 0; c < w; c++) {
|
||||||
for (int i = 0; i < k; i++) {
|
for (int i = 0; i < d; i++) {
|
||||||
d1[c*ldc+i] += a1[c*ldc+r]*b1[i + (r*ldc)];
|
d1[r*h+i] += a1[r*h+c]*b1[i*d+c];
|
||||||
//printf("d[%d] += a[%d]*b[%d]\n", c*ldc+i, c*ldc+r , i + (r*ldc));
|
|
||||||
//printf("%d %d %d\n", d1[c*ldc+i] , a1[c*ldc+r] , b1[i + (r*ldc)]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,7 +87,7 @@ int main()
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
for(int i = 0; i < m*n; ++i)
|
for(int i = 0; i < w*d; ++i)
|
||||||
{
|
{
|
||||||
if(c1[i] != d1[i])
|
if(c1[i] != d1[i])
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ extern "C" {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
//void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int lda, int* b1, int ldb, int* c1, int ldc);
|
//void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int lda, int* b1, int ldb, int* c1, int ldc);
|
||||||
void vx_vec_sgemm_nn(int n, int m, int k, int* a1, int* b1, int* c1, int ldc, int vsize);
|
void vx_vec_sgemm_nn(int i, int m, int n, int* a1, int* b1, int* c1, int ldc, int vsize);
|
||||||
//void vx_vec_sgemm_nn(int n, int* a1, int* b1, int* c1);
|
//void vx_vec_sgemm_nn(int n, int* a1, int* b1, int* c1);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,33 +10,31 @@
|
|||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
# a3 = a, a4 = b, a5 = c
|
# a0 = i, a1 = m, a2 = n, a3 = a, a4 = b, a5 = c, a6 = ldc, a7 = vsize
|
||||||
# a0 = i, a1 = m, a2 = n
|
#
|
||||||
# a6 = ldc
|
|
||||||
vx_vec_sgemm_nn:
|
vx_vec_sgemm_nn:
|
||||||
vsetvli t0, a7, e32
|
vsetvli t0, a7, e32 # <--- vsize
|
||||||
mul t1, a6, a2 # n*ldc
|
mul x11, a6, a2 # n*ldc
|
||||||
add t2, t1, a1 # i + (n*ldc)
|
add x12, x11, a1 # i + (n*ldc)
|
||||||
slli t2, t2, 2
|
add a3, x12, a3 # a[i+ n*ldc]
|
||||||
add a3, t2, a3 # a[i+ n*ldc]
|
lw x13, (a3)
|
||||||
lw t3, (a3)
|
|
||||||
|
|
||||||
mul t4, a1, a6 # m*ldc
|
mul x14, a1, a6 # m*ldc
|
||||||
add t5, a0, t4 # i + m*ldc
|
add x15, a0, x14 # i + m*ldc
|
||||||
slli t5, t5, 2
|
add a4, x15, a4 # b[i + m*ldc]
|
||||||
add a4, t5, a4 # b[i + m*ldc]
|
vlw.v v0, (a4)
|
||||||
# lw x6, (a4)
|
vmul.vx v2, v1, x13
|
||||||
|
## lw x6, (a4)
|
||||||
vlw.v v0, (a4)
|
# lw x10, (a4) # b
|
||||||
vmul.vx v1, v0, t3
|
# mul x11, x3, x10
|
||||||
|
|
||||||
mul t6, a2, a6 # n*ldc
|
|
||||||
add t0, a0, t6 # i + n*ldc
|
|
||||||
slli t0, t0, 2
|
|
||||||
add a5, t0, a5 # c[i + m*ldc]
|
|
||||||
|
|
||||||
vlw.v v2, (a5) #c
|
|
||||||
vadd.vv v2, v2, v1
|
|
||||||
vsw.v v2, (a5)
|
|
||||||
|
|
||||||
|
mul x6, a2, a6 # n*ldc
|
||||||
|
add x7, a0, x6 # i + n*ldc
|
||||||
|
add a5, x7, a5 # c[i + m*ldc]
|
||||||
|
vlw.v v3, (a5) # c
|
||||||
|
vadd.vv v3, v3, v2
|
||||||
|
vsw.v v3, (a5)
|
||||||
|
# lw x12, (a5)
|
||||||
|
# add x12, x12, x11
|
||||||
|
# sw x12, (a5)
|
||||||
ret
|
ret
|
||||||
|
|||||||
Reference in New Issue
Block a user