refactoring fixes
This commit is contained in:
220
miscs/rvvector/benchmark_temp/vx_vec_sgemm_float.s
Normal file
220
miscs/rvvector/benchmark_temp/vx_vec_sgemm_float.s
Normal file
@@ -0,0 +1,220 @@
|
||||
.type vx_vec_sgemm_nn_float, @function
|
||||
.global vx_vec_sgemm_nn_float
|
||||
# RV64IDV system
|
||||
#
|
||||
# void
|
||||
# sgemm_nn(size_t n,
|
||||
# size_t m,
|
||||
# size_t k,
|
||||
# const float*a, // m * k matrix
|
||||
# size_t lda,
|
||||
# const float*b, // k * n matrix
|
||||
# size_t ldb,
|
||||
# float*c, // m * n matrix
|
||||
# size_t ldc)
|
||||
#
|
||||
# c += a*b (alpha=1, no transpose on input matrices)
|
||||
# matrices stored in C row-major order
|
||||
|
||||
#define n a0
|
||||
#define m a1
|
||||
#define k a2
|
||||
#define ap a3
|
||||
#define astride a4
|
||||
#define bp a5
|
||||
#define bstride a6
|
||||
#define cp a7
|
||||
#define cstride t0
|
||||
#define kt t1
|
||||
#define nt t2
|
||||
#define bnp t3
|
||||
#define cnp t4
|
||||
#define akp t5
|
||||
#define bkp s0
|
||||
#define nvl s1
|
||||
#define ccp s2
|
||||
#define amp s3
|
||||
|
||||
# Use args as additional temporaries
|
||||
#define ft12 fa0
|
||||
#define ft13 fa1
|
||||
#define ft14 fa2
|
||||
#define ft15 fa3
|
||||
|
||||
# This version holds a 16*VLMAX block of C matrix in vector registers
|
||||
# in inner loop, but otherwise does not cache or TLB tiling.
|
||||
vx_vec_sgemm_nn_float:
|
||||
#sgemm_nn:
|
||||
addi sp, sp, -FRAMESIZE
|
||||
sd s0, OFFSET(sp)
|
||||
sd s1, OFFSET(sp)
|
||||
sd s2, OFFSET(sp)
|
||||
|
||||
# Check for zero size matrices
|
||||
beqz n, exit
|
||||
beqz m, exit
|
||||
beqz k, exit
|
||||
|
||||
# Convert elements strides to byte strides.
|
||||
ld cstride, OFFSET(sp) # Get arg from stack frame
|
||||
slli astride, astride, 2
|
||||
slli bstride, bstride, 2
|
||||
slli cstride, cstride, 2
|
||||
|
||||
slti t6, m, 16
|
||||
bnez t6, end_rows
|
||||
|
||||
c_row_loop: # Loop across rows of C blocks
|
||||
|
||||
mv nt, n # Initialize n counter for next row of C blocks
|
||||
|
||||
mv bnp, bp # Initialize B n-loop pointer to start
|
||||
mv cnp, cp # Initialize C n-loop pointer
|
||||
|
||||
c_col_loop: # Loop across one row of C blocks
|
||||
vsetvli nvl, nt, e32 # 32-bit vectors, LMUL=1
|
||||
|
||||
mv akp, ap # reset pointer into A to beginning
|
||||
mv bkp, bnp # step to next column in B matrix
|
||||
|
||||
# Initalize current C submatrix block from memory.
|
||||
vlw.v v0, (cnp); add ccp, cnp, cstride;
|
||||
vlw.v v1, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v2, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v3, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v4, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v5, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v6, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v7, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v8, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v9, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v10, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v11, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v12, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v13, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v14, (ccp); add ccp, ccp, cstride;
|
||||
vlw.v v15, (ccp)
|
||||
|
||||
|
||||
mv kt, k # Initialize inner loop counter
|
||||
|
||||
# Inner loop scheduled assuming 4-clock occupancy of vfmacc instruction and single-issue pipeline
|
||||
# Software pipeline loads
|
||||
flw ft0, (akp); add amp, akp, astride;
|
||||
flw ft1, (amp); add amp, amp, astride;
|
||||
flw ft2, (amp); add amp, amp, astride;
|
||||
flw ft3, (amp); add amp, amp, astride;
|
||||
# Get vector from B matrix
|
||||
vlw.v v16, (bkp)
|
||||
|
||||
# Loop on inner dimension for current C block
|
||||
k_loop:
|
||||
vfmacc.vf v0, ft0, v16
|
||||
add bkp, bkp, bstride
|
||||
flw ft4, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v1, ft1, v16
|
||||
addi kt, kt, -1 # Decrement k counter
|
||||
flw ft5, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v2, ft2, v16
|
||||
flw ft6, (amp)
|
||||
add amp, amp, astride
|
||||
flw ft7, (amp)
|
||||
vfmacc.vf v3, ft3, v16
|
||||
add amp, amp, astride
|
||||
flw ft8, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v4, ft4, v16
|
||||
flw ft9, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v5, ft5, v16
|
||||
flw ft10, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v6, ft6, v16
|
||||
flw ft11, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v7, ft7, v16
|
||||
flw ft12, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v8, ft8, v16
|
||||
flw ft13, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v9, ft9, v16
|
||||
flw ft14, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v10, ft10, v16
|
||||
flw ft15, (amp)
|
||||
add amp, amp, astride
|
||||
addi akp, akp, 4 # Move to next column of a
|
||||
vfmacc.vf v11, ft11, v16
|
||||
beqz kt, 1f # Don't load past end of matrix
|
||||
flw ft0, (akp)
|
||||
add amp, akp, astride
|
||||
1: vfmacc.vf v12, ft12, v16
|
||||
beqz kt, 1f
|
||||
flw ft1, (amp)
|
||||
add amp, amp, astride
|
||||
1: vfmacc.vf v13, ft13, v16
|
||||
beqz kt, 1f
|
||||
flw ft2, (amp)
|
||||
add amp, amp, astride
|
||||
1: vfmacc.vf v14, ft14, v16
|
||||
beqz kt, 1f # Exit out of loop
|
||||
flw ft3, (amp)
|
||||
add amp, amp, astride
|
||||
vfmacc.vf v15, ft15, v16
|
||||
vlw.v v16, (bkp) # Get next vector from B matrix, overlap loads with jump stalls
|
||||
j k_loop
|
||||
|
||||
1: vfmacc.vf v15, ft15, v16
|
||||
|
||||
# Save C matrix block back to memory
|
||||
vsw.v v0, (cnp); add ccp, cnp, cstride;
|
||||
vsw.v v1, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v2, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v3, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v4, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v5, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v6, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v7, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v8, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v9, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v10, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v11, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v12, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v13, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v14, (ccp); add ccp, ccp, cstride;
|
||||
vsw.v v15, (ccp)
|
||||
|
||||
# Following tail instructions should be scheduled earlier in free slots during C block save.
|
||||
# Leaving here for clarity.
|
||||
|
||||
# Bump pointers for loop across blocks in one row
|
||||
slli t6, nvl, 2
|
||||
add cnp, cnp, t6 # Move C block pointer over
|
||||
add bnp, bnp, t6 # Move B block pointer over
|
||||
sub nt, nt, nvl # Decrement element count in n dimension
|
||||
bnez nt, c_col_loop # Any more to do?
|
||||
|
||||
# Move to next set of rows
|
||||
addi m, m, -16 # Did 16 rows above
|
||||
slli t6, astride, 4 # Multiply astride by 16
|
||||
add ap, ap, t6 # Move A matrix pointer down 16 rows
|
||||
slli t6, cstride, 4 # Multiply cstride by 16
|
||||
add cp, cp, t6 # Move C matrix pointer down 16 rows
|
||||
|
||||
slti t6, m, 16
|
||||
beqz t6, c_row_loop
|
||||
|
||||
# Handle end of matrix with fewer than 16 rows.
|
||||
# Can use smaller versions of above decreasing in powers-of-2 depending on code-size concerns.
|
||||
end_rows:
|
||||
# Not done.
|
||||
|
||||
exit:
|
||||
ld s0, OFFSET(sp)
|
||||
ld s1, OFFSET(sp)
|
||||
ld s2, OFFSET(sp)
|
||||
addi sp, sp, FRAMESIZE
|
||||
ret
|
||||
Reference in New Issue
Block a user