25 Commits

Author SHA1 Message Date
Virgo-AE Eval
326141b11f Merge branch 'ae' into ae-volta 2025-02-07 14:52:25 -08:00
Virgo-AE Eval
71f713b9fc Disable git pull for archive
Only use local refs in the archive for reproducibility.
2025-02-07 14:51:25 -08:00
Richard Yan
d893780594 Merge branch 'ae' into ae-volta 2025-01-31 03:53:26 -08:00
Richard Yan
9847072eff fix hexadecile 2025-01-31 02:02:18 -08:00
Richard Yan
9b7c22a7e9 Merge branch 'ae' into ae-volta 2025-01-30 23:41:16 -08:00
Richard Yan
f8c51669c1 fix toolchain env sh 2025-01-30 21:17:12 -08:00
Richard Yan
b1ebabef26 Merge branch 'ae' into ae-volta 2025-01-30 15:35:22 -08:00
Richard Yan
17a9d31be5 fix dma invocation 2025-01-30 15:33:58 -08:00
Hansung Kim
9f524538a4 Merge branch 'ae' into ae-volta 2025-01-30 13:24:46 -08:00
Hansung Kim
238b942133 Add missing library remake 2025-01-30 13:24:23 -08:00
Hansung Kim
51ebe18ebb Merge remote-tracking branch 'origin/ae-volta' into ae-volta 2025-01-30 01:48:25 -08:00
Hansung Kim
7b0a95034b Merge branch 'ae' into ae-volta 2025-01-30 01:48:05 -08:00
Hansung Kim
2c1ac4e938 Do git pull to make sure up-to-date 2025-01-30 01:47:35 -08:00
Richard Yan
c240069147 Merge branch 'ae' into ae-volta 2025-01-30 01:35:35 -08:00
Richard Yan
9cdee597b6 Merge branch 'ae' of https://github.com/richardyrh/virgo-kernels into ae 2025-01-30 01:34:29 -08:00
Hansung Kim
6bdc6af607 Fix branch name and dims for flash script 2025-01-30 01:15:57 -08:00
Hansung Kim
d86c33acf3 Merge branch 'ae' into ae-volta 2025-01-30 01:05:27 -08:00
Hansung Kim
b73147cd06 Add compile and operand generate script for flash 2025-01-30 01:04:20 -08:00
Hansung Kim
471f89e371 Add arg binary for flash 2025-01-30 01:02:12 -08:00
Hansung Kim
b49e8a293c Merge branch 'ae' into ae-volta 2025-01-30 00:49:19 -08:00
Hansung Kim
7e1fc54c97 Fix typo in path 2025-01-30 00:41:42 -08:00
Hansung Kim
19731b8e2f Merge branch 'ae' into ae-volta 2025-01-30 00:35:00 -08:00
Hansung Kim
50c8f1c410 Add operand generate script for tcore 2025-01-29 23:33:09 -08:00
Richard Yan
afc69507a3 Merge branch 'ae' into ae-volta 2025-01-29 23:31:34 -08:00
Richard Yan
6e279c905f volta change 2025-01-29 22:16:39 -08:00
12 changed files with 361 additions and 18 deletions

View File

@@ -0,0 +1 @@
args.seq1024.headdim64.bin

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,45 @@
#!/bin/bash
archs=("ampere" "virgo")
if [ -z "$TOOLDIR" ]; then
echo "error: \$TOOLDIR not set. Did you run source ci/toolchain_env.sh?"
exit 1
fi
check_exists() {
if ! [ -f "$1" ]; then
echo "error: looked for file $1 that does not exist."
exit 1
fi
}
# generate operands
echo "generating flash_attn operands for seqlen 1024, headdim 64"
python3 flash_attn.py 1024 64 64
mv -v input.a.col.bin input.a.rand.fp32.seqlen1024headdim64.col.bin
mv -v input.a.row.bin input.a.rand.fp32.seqlen1024headdim64.row.bin
mv -v input.b.bin input.b.rand.fp32.seqlen1024headdim64.row.bin
mv -v input.c.bin input.c.rand.fp32.seqlen1024headdim64.row.bin
ln -sf input.a.rand.fp32.seqlen1024headdim64.row.bin input.a.bin
ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
for arch in "${archs[@]}"; do
git checkout ae-flash-$arch
# git pull
# re-compile libvortexrt.a
pushd ../../lib
make
popd
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
# touch source file to force re-building, as the Makefile does not track
# binary changes
touch kernel.cpp
touch kernel.gemmini.cpp
make CONFIG=flash.$arch.seqlen1024.headdim64
done

View File

@@ -0,0 +1,159 @@
import sys
import numpy as np
def parse_mnk():
if len(sys.argv) != 4:
print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
sys.exit(1)
m = int(sys.argv[1])
n = int(sys.argv[2])
k = int(sys.argv[3])
return (m, n, k)
# Reorder array in a way that groups two adjacent elements along the column to
# be now adjacent along the row. This way, when the resulting fp16 array is
# read in column-major order with 32-bit granularity, the fp16 elements will be
# read in the same order as regular fp32 elements in column-major.
#
# For example:
# [[1 2]
# [3 4]
# [5 6]
# [7 8]]
# becomes
# [[1 3 2 4]
# [5 7 6 8]]
def pack_fp16_by_column(array):
rows = array.shape[0]
cols = array.shape[1]
T = array.transpose([1, 0])
T_packed = T.reshape([cols, -1, 2])
result = T_packed.transpose([1, 0, 2])
return result
# Do the same as pack_fp16_by_column, but for every two elements along the row.
def pack_fp16_by_row(array):
rows = array.shape[0]
cols = array.shape[1]
result = array.reshape([rows, -1, 2])
return result
if __name__ == "__main__":
seqlen, _, headdim = parse_mnk()
rand = True
if not rand:
A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen])
C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim])
else:
np.random.seed(0)
A_array = np.random.rand(seqlen, headdim) - 0.5
B_array = np.random.rand(headdim, seqlen) - 0.5
C_array = np.random.rand(seqlen, headdim) - 0.5
# C_array = np.zeros([M, N])
fp16 = False
if fp16:
A_packed = pack_fp16_by_row(A_array)
AT_packed = A_packed.transpose([1, 0, 2])
AT_array = AT_packed.reshape([-1, seqlen * 2])
AT_array.astype('float16').tofile("input.a.col.bin")
# print('AT:')
# print(AT_array)
B_packed = pack_fp16_by_column(B_array)
B_array = B_packed.reshape([-1, headdim * 2])
B_array.astype('float16').tofile("input.b.row.bin")
# print('B:')
# print(B_array)
else:
A_array.astype('float32').tofile("input.a.row.bin")
AT_array = A_array.transpose([1, 0])
AT_array.astype('float32').tofile("input.a.col.bin")
B_array.astype('float32').tofile("input.b.bin")
C_array.astype('float32').tofile("input.c.bin")
# print('AT:')
# print(AT_array)
# print('B:')
# print(B_array)
assert((seqlen % 64) == 0)
Br = 64
Bc = Br
rowmax = np.zeros([Br])
rowsum = np.zeros([Br])
O = np.zeros([Br, headdim])
def exp2(x):
return (x**2) / 2.0 + x + 1.0
full_S = A_array @ B_array
full_S_T = full_S.transpose([1, 0])
full_S.astype('float32').tofile("full_S.bin")
col_to_save = 0
for col in range(0, seqlen, Bc):
print(f"tile iteration {col}~{col + Bc} ======================================")
# FIXME: only work with the first 64 rows of Q for now
Q_tile = A_array[0:64, :]
K_tile = B_array[:, col:col+Bc]
S = Q_tile @ K_tile
if col == col_to_save:
print('S_expected:')
print(S)
S.astype('float32').tofile("S_expected.bin")
# generate rowmax result in online softmax
rowmax_this = np.max(S, axis=1)
rowmax_prev = rowmax.copy()
rowmax = np.maximum(rowmax, rowmax_this)
if col == col_to_save:
rowmax.astype('float32').tofile("rowmax.bin")
# subtrace rowmax from each row by broadcasting
# (placeholder for exp)
x = S - rowmax[:, np.newaxis]
P = exp2(x)
# for i in range(3, 4):
# P += (x**i) / np.math.factorial(i)
# P = np.exp(exp)
# print('P error:')
# print(P / np.exp(x))
if col == col_to_save:
print('P_expected:')
print(P)
P.astype('float32').tofile("P_expected.bin")
P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin")
rowsum_this = np.sum(P, axis=1)
x = rowmax_prev - rowmax_this
rowsum = exp2(x) * rowsum + rowsum_this
if col == col_to_save:
rowsum.astype('float32').tofile("rowsum.bin")
x = rowmax_prev - rowmax
O = O / (exp2(x)[:, np.newaxis])
if col == col_to_save:
print('O_before_PV:')
print(O)
O.astype('float32').tofile("O_before_PV.bin")
V = C_array[col:col+Bc, :]
if col == col_to_save:
V.astype('float32').tofile("V_expected.bin")
# O = P.transpose([1, 0]) @ V
O = O + P @ V
if col == col_to_save:
print('O_after_PV:')
print(O)
O.astype('float32').tofile("O_after_PV.bin")

View File

@@ -1,5 +1,14 @@
#!/bin/sh
# hopper and virgo has the same SIMT configurations
git checkout ae-hopper
# git pull
# re-compile libvortexrt.a
pushd ../../lib
make
popd
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
echo "input binaries not found, generating operands"
python3 generate_operands.py

View File

@@ -41,12 +41,22 @@ check_exists() {
fi
}
# generate operands
for dim in "${dims[@]}"; do
echo "generating operands for dim $dim"
python3 generate_operands.py $dim $dim $dim
mv -v input.a.col.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.col.swizzle_fp16.bin
mv -v input.a.row.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
mv -v input.b.row.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.bin
mv -v input.b.row.swizzled.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
done
for arch in "${archs[@]}"; do
git checkout ae-$arch
# git pull
# re-compile libvortexrt.a
# FIXME after restructure
pushd ../../libs
pushd ../../lib
make
popd

View File

@@ -0,0 +1,116 @@
import sys
import numpy as np
def parse_mnk():
if len(sys.argv) != 4:
print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
sys.exit(1)
m = int(sys.argv[1])
n = int(sys.argv[2])
k = int(sys.argv[3])
return (m, n, k)
# Reorder array in a way that groups two adjacent elements along the column to
# be now adjacent along the row. This way, when the resulting fp16 array is
# read in column-major order with 32-bit granularity, the fp16 elements will be
# read in the same order as regular fp32 elements in column-major.
#
# For example:
# [[1 2]
# [3 4]
# [5 6]
# [7 8]]
# becomes
# [[1 3 2 4]
# [5 7 6 8]]
def pack_fp16_by_column(array):
rows = array.shape[0]
cols = array.shape[1]
T = array.transpose([1, 0])
T_packed = T.reshape([cols, -1, 2])
result = T_packed.transpose([1, 0, 2])
return result
# Do the same as pack_fp16_by_column, but for every two elements along the row.
def pack_fp16_by_row(array):
rows = array.shape[0]
cols = array.shape[1]
result = array.reshape([rows, -1, 2])
return result
if __name__ == "__main__":
M, N, K = parse_mnk()
rand = True
if not rand:
A_array = np.arange(M * K).reshape([M, K])
B_array = np.arange(K * N).reshape([K, N])
# C_array = np.arange(M * N).reshape([M, N])
C_array = np.zeros([M, N])
else:
np.random.seed(0)
A_array = np.random.rand(M, K)
B_array = np.random.rand(K, N)
C_array = np.random.rand(N, K)
# C_array = np.zeros([M, N])
with open('a_matrix.h', 'w') as f:
for i in range(A_array.shape[0]):
for j in range(A_array.shape[1]):
f.write(f'{A_array[i,j]:f}f, ')
f.write('\n')
with open('b_matrix.h', 'w') as f:
for i in range(B_array.shape[0]):
for j in range(B_array.shape[1]):
f.write(f'{B_array[i,j]:f}f, ')
f.write('\n')
with open('c_matrix.h', 'w') as f:
for i in range(C_array.shape[0]):
for j in range(C_array.shape[1]):
f.write(f'{C_array[i,j]:f}f, ')
f.write('\n')
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
fp16 = True
if fp16:
A_packed = pack_fp16_by_row(A_array)
A_swizzled = A_packed.reshape([-1, M * 2])
A_swizzled.astype('float16').tofile("input.a.row.bin")
AT_packed = A_packed.transpose([1, 0, 2])
AT_swizzled = AT_packed.reshape([-1, M * 2])
AT_swizzled.astype('float16').tofile("input.a.col.bin")
print('A:')
print(A_swizzled)
print('AT:')
print(AT_swizzled)
B_array.astype('float16').tofile("input.b.row.bin")
# B_packed_row = pack_fp16_by_row(B_array)
# B_packed_row = B_packed_row.reshape([-1, N * 2])
# B_packed_row.astype('float16').tofile("input.b.row.bin")
B_packed = pack_fp16_by_column(B_array)
B_swizzled = B_packed.reshape([-1, N * 2])
B_swizzled.astype('float16').tofile("input.b.row.swizzled.bin")
print('B:')
print(B_swizzled)
else:
A_array.astype('float32').tofile("input.a.row.bin")
AT_array = A_array.transpose([1, 0])
AT_array.astype('float32').tofile("input.a.col.bin")
B_array.astype('float32').tofile("input.b.bin")
C_array.astype('float32').tofile("input.c.bin")
print('AT:')
print(AT_array)
print('B:')
print(B_array)
D_expected = A_array @ B_array
D_expected.astype('float32').tofile("d_expected.bin")
print('D_expected:')
print(D_expected)

View File

@@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
// result matrix will be stored in a swizzled form in the global memory.
#define WMMA_STORE_FAST 0
#define GEMMINI_DMA 1
#define GEMMINI_DMA 0
#define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
@@ -1190,10 +1190,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | GEMMINI_CISC_SET_AB_STRIDE);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
GEMMINI_CISC_CMD_R((11 << 16) | (0 << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
gemmini_fence();
#if 0
@@ -1257,7 +1257,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
// gemmini_fence();
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
@@ -1266,8 +1266,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// the last iteration of the k-loop is prefetching for the first
// iteration of the n-loop. The ping-poing indexing has to match for
// the two loop end to connect.
const uint32_t opcode = 11 - (block_k & 1);
GEMMINI_CISC_CMD_I(opcode);
const uint32_t a_hexadecile = 4 - ((block_k & 1) * 4);
const uint32_t b_hexadecile = a_hexadecile + 11;
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
// // TODO: branch is probably slow
// if (block_k & 1) {
// GEMMINI_CISC_CMD_I(12);

View File

@@ -1,21 +1,23 @@
#!/bin/sh
#!/bin/bash
# Copyright 2023 blaise
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TOOLDIR=${TOOLDIR:=$HOME/build/vortex-toolchain-prebuilt}
export TOOLDIR
ENV_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
AE_TOOLCHAIN_DIR="$(realpath ${ENV_SCRIPT_DIR}/../../toolchain)"
export TOOLDIR=${AE_TOOLCHAIN_DIR}/vortex-toolchain-prebuilt
export VERILATOR_ROOT=$TOOLDIR/verilator
export PATH=$VERILATOR_ROOT/bin:$PATH
@@ -27,7 +29,7 @@ export YOSYS_PATH=$TOOLDIR/yosys
export PATH=$YOSYS_PATH/bin:$PATH
# LLVM_POCL seems to be only used in tests/opencl
export LLVM_POCL=/home/virgo-ae/build/llvm-vortex2
export LLVM_VORTEX=/home/virgo-ae/build/llvm-vortex2
export POCL_CC_PATH=/home/virgo-ae/build/pocl-vortex2/compiler
export POCL_RT_PATH=/home/virgo-ae/build/pocl-vortex2/runtime
export LLVM_POCL=${AE_TOOLCHAIN_DIR}/llvm-vortex2
export LLVM_VORTEX=${AE_TOOLCHAIN_DIR}/llvm-vortex2
export POCL_CC_PATH=${AE_TOOLCHAIN_DIR}/pocl-vortex2/compiler
export POCL_RT_PATH=${AE_TOOLCHAIN_DIR}/pocl-vortex2/runtime