Compare commits
25 Commits
dde3602046
...
ae-volta
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
326141b11f | ||
|
|
71f713b9fc | ||
|
|
d893780594 | ||
|
|
9847072eff | ||
|
|
9b7c22a7e9 | ||
|
|
f8c51669c1 | ||
|
|
b1ebabef26 | ||
|
|
17a9d31be5 | ||
|
|
9f524538a4 | ||
|
|
238b942133 | ||
|
|
51ebe18ebb | ||
|
|
7b0a95034b | ||
|
|
2c1ac4e938 | ||
|
|
c240069147 | ||
|
|
9cdee597b6 | ||
|
|
6bdc6af607 | ||
|
|
d86c33acf3 | ||
|
|
b73147cd06 | ||
|
|
471f89e371 | ||
|
|
b49e8a293c | ||
|
|
7e1fc54c97 | ||
|
|
19731b8e2f | ||
|
|
50c8f1c410 | ||
|
|
afc69507a3 | ||
|
|
6e279c905f |
1
kernels/flash_attention/args.bin
Symbolic link
1
kernels/flash_attention/args.bin
Symbolic link
@@ -0,0 +1 @@
|
||||
args.seq1024.headdim64.bin
|
||||
BIN
kernels/flash_attention/args.seq1024.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq1024.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq128.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq128.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq192.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq192.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq64.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq64.headdim64.bin
Normal file
Binary file not shown.
45
kernels/flash_attention/compile_flash.sh
Executable file
45
kernels/flash_attention/compile_flash.sh
Executable file
@@ -0,0 +1,45 @@
|
||||
#!/bin/bash
|
||||
|
||||
archs=("ampere" "virgo")
|
||||
|
||||
if [ -z "$TOOLDIR" ]; then
|
||||
echo "error: \$TOOLDIR not set. Did you run source ci/toolchain_env.sh?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
check_exists() {
|
||||
if ! [ -f "$1" ]; then
|
||||
echo "error: looked for file $1 that does not exist."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# generate operands
|
||||
echo "generating flash_attn operands for seqlen 1024, headdim 64"
|
||||
python3 flash_attn.py 1024 64 64
|
||||
mv -v input.a.col.bin input.a.rand.fp32.seqlen1024headdim64.col.bin
|
||||
mv -v input.a.row.bin input.a.rand.fp32.seqlen1024headdim64.row.bin
|
||||
mv -v input.b.bin input.b.rand.fp32.seqlen1024headdim64.row.bin
|
||||
mv -v input.c.bin input.c.rand.fp32.seqlen1024headdim64.row.bin
|
||||
ln -sf input.a.rand.fp32.seqlen1024headdim64.row.bin input.a.bin
|
||||
ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
|
||||
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
|
||||
|
||||
for arch in "${archs[@]}"; do
|
||||
git checkout ae-flash-$arch
|
||||
# git pull
|
||||
|
||||
# re-compile libvortexrt.a
|
||||
pushd ../../lib
|
||||
make
|
||||
popd
|
||||
|
||||
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
|
||||
|
||||
# touch source file to force re-building, as the Makefile does not track
|
||||
# binary changes
|
||||
touch kernel.cpp
|
||||
touch kernel.gemmini.cpp
|
||||
|
||||
make CONFIG=flash.$arch.seqlen1024.headdim64
|
||||
done
|
||||
159
kernels/flash_attention/flash_attn.py
Normal file
159
kernels/flash_attention/flash_attn.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
def parse_mnk():
|
||||
if len(sys.argv) != 4:
|
||||
print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
m = int(sys.argv[1])
|
||||
n = int(sys.argv[2])
|
||||
k = int(sys.argv[3])
|
||||
return (m, n, k)
|
||||
|
||||
|
||||
# Reorder array in a way that groups two adjacent elements along the column to
|
||||
# be now adjacent along the row. This way, when the resulting fp16 array is
|
||||
# read in column-major order with 32-bit granularity, the fp16 elements will be
|
||||
# read in the same order as regular fp32 elements in column-major.
|
||||
#
|
||||
# For example:
|
||||
# [[1 2]
|
||||
# [3 4]
|
||||
# [5 6]
|
||||
# [7 8]]
|
||||
# becomes
|
||||
# [[1 3 2 4]
|
||||
# [5 7 6 8]]
|
||||
def pack_fp16_by_column(array):
|
||||
rows = array.shape[0]
|
||||
cols = array.shape[1]
|
||||
|
||||
T = array.transpose([1, 0])
|
||||
T_packed = T.reshape([cols, -1, 2])
|
||||
result = T_packed.transpose([1, 0, 2])
|
||||
return result
|
||||
|
||||
|
||||
# Do the same as pack_fp16_by_column, but for every two elements along the row.
|
||||
def pack_fp16_by_row(array):
|
||||
rows = array.shape[0]
|
||||
cols = array.shape[1]
|
||||
|
||||
result = array.reshape([rows, -1, 2])
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seqlen, _, headdim = parse_mnk()
|
||||
|
||||
rand = True
|
||||
if not rand:
|
||||
A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
|
||||
B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen])
|
||||
C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim])
|
||||
else:
|
||||
np.random.seed(0)
|
||||
A_array = np.random.rand(seqlen, headdim) - 0.5
|
||||
B_array = np.random.rand(headdim, seqlen) - 0.5
|
||||
C_array = np.random.rand(seqlen, headdim) - 0.5
|
||||
# C_array = np.zeros([M, N])
|
||||
|
||||
fp16 = False
|
||||
if fp16:
|
||||
A_packed = pack_fp16_by_row(A_array)
|
||||
AT_packed = A_packed.transpose([1, 0, 2])
|
||||
AT_array = AT_packed.reshape([-1, seqlen * 2])
|
||||
AT_array.astype('float16').tofile("input.a.col.bin")
|
||||
# print('AT:')
|
||||
# print(AT_array)
|
||||
B_packed = pack_fp16_by_column(B_array)
|
||||
B_array = B_packed.reshape([-1, headdim * 2])
|
||||
B_array.astype('float16').tofile("input.b.row.bin")
|
||||
# print('B:')
|
||||
# print(B_array)
|
||||
else:
|
||||
A_array.astype('float32').tofile("input.a.row.bin")
|
||||
AT_array = A_array.transpose([1, 0])
|
||||
AT_array.astype('float32').tofile("input.a.col.bin")
|
||||
B_array.astype('float32').tofile("input.b.bin")
|
||||
C_array.astype('float32').tofile("input.c.bin")
|
||||
# print('AT:')
|
||||
# print(AT_array)
|
||||
# print('B:')
|
||||
# print(B_array)
|
||||
|
||||
assert((seqlen % 64) == 0)
|
||||
|
||||
Br = 64
|
||||
Bc = Br
|
||||
|
||||
rowmax = np.zeros([Br])
|
||||
rowsum = np.zeros([Br])
|
||||
O = np.zeros([Br, headdim])
|
||||
|
||||
def exp2(x):
|
||||
return (x**2) / 2.0 + x + 1.0
|
||||
|
||||
full_S = A_array @ B_array
|
||||
full_S_T = full_S.transpose([1, 0])
|
||||
full_S.astype('float32').tofile("full_S.bin")
|
||||
|
||||
col_to_save = 0
|
||||
|
||||
for col in range(0, seqlen, Bc):
|
||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||
|
||||
# FIXME: only work with the first 64 rows of Q for now
|
||||
Q_tile = A_array[0:64, :]
|
||||
K_tile = B_array[:, col:col+Bc]
|
||||
|
||||
S = Q_tile @ K_tile
|
||||
if col == col_to_save:
|
||||
print('S_expected:')
|
||||
print(S)
|
||||
S.astype('float32').tofile("S_expected.bin")
|
||||
|
||||
# generate rowmax result in online softmax
|
||||
rowmax_this = np.max(S, axis=1)
|
||||
rowmax_prev = rowmax.copy()
|
||||
rowmax = np.maximum(rowmax, rowmax_this)
|
||||
if col == col_to_save:
|
||||
rowmax.astype('float32').tofile("rowmax.bin")
|
||||
|
||||
# subtrace rowmax from each row by broadcasting
|
||||
# (placeholder for exp)
|
||||
x = S - rowmax[:, np.newaxis]
|
||||
P = exp2(x)
|
||||
# for i in range(3, 4):
|
||||
# P += (x**i) / np.math.factorial(i)
|
||||
# P = np.exp(exp)
|
||||
# print('P error:')
|
||||
# print(P / np.exp(x))
|
||||
if col == col_to_save:
|
||||
print('P_expected:')
|
||||
print(P)
|
||||
P.astype('float32').tofile("P_expected.bin")
|
||||
P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin")
|
||||
|
||||
rowsum_this = np.sum(P, axis=1)
|
||||
x = rowmax_prev - rowmax_this
|
||||
rowsum = exp2(x) * rowsum + rowsum_this
|
||||
if col == col_to_save:
|
||||
rowsum.astype('float32').tofile("rowsum.bin")
|
||||
|
||||
x = rowmax_prev - rowmax
|
||||
O = O / (exp2(x)[:, np.newaxis])
|
||||
if col == col_to_save:
|
||||
print('O_before_PV:')
|
||||
print(O)
|
||||
O.astype('float32').tofile("O_before_PV.bin")
|
||||
|
||||
V = C_array[col:col+Bc, :]
|
||||
if col == col_to_save:
|
||||
V.astype('float32').tofile("V_expected.bin")
|
||||
# O = P.transpose([1, 0]) @ V
|
||||
O = O + P @ V
|
||||
if col == col_to_save:
|
||||
print('O_after_PV:')
|
||||
print(O)
|
||||
O.astype('float32').tofile("O_after_PV.bin")
|
||||
@@ -1,5 +1,14 @@
|
||||
#!/bin/sh
|
||||
|
||||
# hopper and virgo has the same SIMT configurations
|
||||
git checkout ae-hopper
|
||||
# git pull
|
||||
|
||||
# re-compile libvortexrt.a
|
||||
pushd ../../lib
|
||||
make
|
||||
popd
|
||||
|
||||
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
|
||||
echo "input binaries not found, generating operands"
|
||||
python3 generate_operands.py
|
||||
|
||||
@@ -41,12 +41,22 @@ check_exists() {
|
||||
fi
|
||||
}
|
||||
|
||||
# generate operands
|
||||
for dim in "${dims[@]}"; do
|
||||
echo "generating operands for dim $dim"
|
||||
python3 generate_operands.py $dim $dim $dim
|
||||
mv -v input.a.col.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.col.swizzle_fp16.bin
|
||||
mv -v input.a.row.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
|
||||
mv -v input.b.row.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.bin
|
||||
mv -v input.b.row.swizzled.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
|
||||
done
|
||||
|
||||
for arch in "${archs[@]}"; do
|
||||
git checkout ae-$arch
|
||||
# git pull
|
||||
|
||||
# re-compile libvortexrt.a
|
||||
# FIXME after restructure
|
||||
pushd ../../libs
|
||||
pushd ../../lib
|
||||
make
|
||||
popd
|
||||
|
||||
|
||||
116
kernels/sgemm_tcore/generate_operands.py
Normal file
116
kernels/sgemm_tcore/generate_operands.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
def parse_mnk():
|
||||
if len(sys.argv) != 4:
|
||||
print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
m = int(sys.argv[1])
|
||||
n = int(sys.argv[2])
|
||||
k = int(sys.argv[3])
|
||||
return (m, n, k)
|
||||
|
||||
|
||||
# Reorder array in a way that groups two adjacent elements along the column to
|
||||
# be now adjacent along the row. This way, when the resulting fp16 array is
|
||||
# read in column-major order with 32-bit granularity, the fp16 elements will be
|
||||
# read in the same order as regular fp32 elements in column-major.
|
||||
#
|
||||
# For example:
|
||||
# [[1 2]
|
||||
# [3 4]
|
||||
# [5 6]
|
||||
# [7 8]]
|
||||
# becomes
|
||||
# [[1 3 2 4]
|
||||
# [5 7 6 8]]
|
||||
def pack_fp16_by_column(array):
|
||||
rows = array.shape[0]
|
||||
cols = array.shape[1]
|
||||
|
||||
T = array.transpose([1, 0])
|
||||
T_packed = T.reshape([cols, -1, 2])
|
||||
result = T_packed.transpose([1, 0, 2])
|
||||
return result
|
||||
|
||||
|
||||
# Do the same as pack_fp16_by_column, but for every two elements along the row.
|
||||
def pack_fp16_by_row(array):
|
||||
rows = array.shape[0]
|
||||
cols = array.shape[1]
|
||||
|
||||
result = array.reshape([rows, -1, 2])
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
M, N, K = parse_mnk()
|
||||
|
||||
rand = True
|
||||
if not rand:
|
||||
A_array = np.arange(M * K).reshape([M, K])
|
||||
B_array = np.arange(K * N).reshape([K, N])
|
||||
# C_array = np.arange(M * N).reshape([M, N])
|
||||
C_array = np.zeros([M, N])
|
||||
else:
|
||||
np.random.seed(0)
|
||||
A_array = np.random.rand(M, K)
|
||||
B_array = np.random.rand(K, N)
|
||||
C_array = np.random.rand(N, K)
|
||||
# C_array = np.zeros([M, N])
|
||||
|
||||
with open('a_matrix.h', 'w') as f:
|
||||
for i in range(A_array.shape[0]):
|
||||
for j in range(A_array.shape[1]):
|
||||
f.write(f'{A_array[i,j]:f}f, ')
|
||||
f.write('\n')
|
||||
with open('b_matrix.h', 'w') as f:
|
||||
for i in range(B_array.shape[0]):
|
||||
for j in range(B_array.shape[1]):
|
||||
f.write(f'{B_array[i,j]:f}f, ')
|
||||
f.write('\n')
|
||||
with open('c_matrix.h', 'w') as f:
|
||||
for i in range(C_array.shape[0]):
|
||||
for j in range(C_array.shape[1]):
|
||||
f.write(f'{C_array[i,j]:f}f, ')
|
||||
f.write('\n')
|
||||
|
||||
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
||||
|
||||
fp16 = True
|
||||
if fp16:
|
||||
A_packed = pack_fp16_by_row(A_array)
|
||||
A_swizzled = A_packed.reshape([-1, M * 2])
|
||||
A_swizzled.astype('float16').tofile("input.a.row.bin")
|
||||
AT_packed = A_packed.transpose([1, 0, 2])
|
||||
AT_swizzled = AT_packed.reshape([-1, M * 2])
|
||||
AT_swizzled.astype('float16').tofile("input.a.col.bin")
|
||||
print('A:')
|
||||
print(A_swizzled)
|
||||
print('AT:')
|
||||
print(AT_swizzled)
|
||||
B_array.astype('float16').tofile("input.b.row.bin")
|
||||
# B_packed_row = pack_fp16_by_row(B_array)
|
||||
# B_packed_row = B_packed_row.reshape([-1, N * 2])
|
||||
# B_packed_row.astype('float16').tofile("input.b.row.bin")
|
||||
B_packed = pack_fp16_by_column(B_array)
|
||||
B_swizzled = B_packed.reshape([-1, N * 2])
|
||||
B_swizzled.astype('float16').tofile("input.b.row.swizzled.bin")
|
||||
print('B:')
|
||||
print(B_swizzled)
|
||||
else:
|
||||
A_array.astype('float32').tofile("input.a.row.bin")
|
||||
AT_array = A_array.transpose([1, 0])
|
||||
AT_array.astype('float32').tofile("input.a.col.bin")
|
||||
B_array.astype('float32').tofile("input.b.bin")
|
||||
C_array.astype('float32').tofile("input.c.bin")
|
||||
print('AT:')
|
||||
print(AT_array)
|
||||
print('B:')
|
||||
print(B_array)
|
||||
|
||||
D_expected = A_array @ B_array
|
||||
D_expected.astype('float32').tofile("d_expected.bin")
|
||||
print('D_expected:')
|
||||
print(D_expected)
|
||||
|
||||
@@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
|
||||
// result matrix will be stored in a swizzled form in the global memory.
|
||||
#define WMMA_STORE_FAST 0
|
||||
|
||||
#define GEMMINI_DMA 1
|
||||
#define GEMMINI_DMA 0
|
||||
#define GEMMINI_DMA_FAST 1
|
||||
#if SMEM_SIZE == 0x4000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
@@ -1190,10 +1190,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | GEMMINI_CISC_SET_AB_STRIDE);
|
||||
gemmini_fence();
|
||||
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
GEMMINI_CISC_CMD_R((11 << 16) | (0 << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
|
||||
gemmini_fence();
|
||||
|
||||
#if 0
|
||||
@@ -1257,7 +1257,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||
gemmini_fence();
|
||||
// gemmini_fence();
|
||||
|
||||
// block_k is even: opcode 11 (write to local_a_buf)
|
||||
// block_k is odd: opcode 10 (write to local_a)
|
||||
@@ -1266,8 +1266,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
// the last iteration of the k-loop is prefetching for the first
|
||||
// iteration of the n-loop. The ping-poing indexing has to match for
|
||||
// the two loop end to connect.
|
||||
const uint32_t opcode = 11 - (block_k & 1);
|
||||
GEMMINI_CISC_CMD_I(opcode);
|
||||
const uint32_t a_hexadecile = 4 - ((block_k & 1) * 4);
|
||||
const uint32_t b_hexadecile = a_hexadecile + 11;
|
||||
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
|
||||
// // TODO: branch is probably slow
|
||||
// if (block_k & 1) {
|
||||
// GEMMINI_CISC_CMD_I(12);
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
#!/bin/sh
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2023 blaise
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
TOOLDIR=${TOOLDIR:=$HOME/build/vortex-toolchain-prebuilt}
|
||||
export TOOLDIR
|
||||
ENV_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
AE_TOOLCHAIN_DIR="$(realpath ${ENV_SCRIPT_DIR}/../../toolchain)"
|
||||
|
||||
export TOOLDIR=${AE_TOOLCHAIN_DIR}/vortex-toolchain-prebuilt
|
||||
|
||||
export VERILATOR_ROOT=$TOOLDIR/verilator
|
||||
export PATH=$VERILATOR_ROOT/bin:$PATH
|
||||
@@ -27,7 +29,7 @@ export YOSYS_PATH=$TOOLDIR/yosys
|
||||
export PATH=$YOSYS_PATH/bin:$PATH
|
||||
|
||||
# LLVM_POCL seems to be only used in tests/opencl
|
||||
export LLVM_POCL=/home/virgo-ae/build/llvm-vortex2
|
||||
export LLVM_VORTEX=/home/virgo-ae/build/llvm-vortex2
|
||||
export POCL_CC_PATH=/home/virgo-ae/build/pocl-vortex2/compiler
|
||||
export POCL_RT_PATH=/home/virgo-ae/build/pocl-vortex2/runtime
|
||||
export LLVM_POCL=${AE_TOOLCHAIN_DIR}/llvm-vortex2
|
||||
export LLVM_VORTEX=${AE_TOOLCHAIN_DIR}/llvm-vortex2
|
||||
export POCL_CC_PATH=${AE_TOOLCHAIN_DIR}/pocl-vortex2/compiler
|
||||
export POCL_RT_PATH=${AE_TOOLCHAIN_DIR}/pocl-vortex2/runtime
|
||||
|
||||
Reference in New Issue
Block a user