flash.py: Fix row-wise scaling of O, col_to_save
This commit is contained in:
@@ -50,7 +50,7 @@ if __name__ == "__main__":
|
|||||||
if not rand:
|
if not rand:
|
||||||
A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
|
A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
|
||||||
B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen])
|
B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen])
|
||||||
C_array = np.arange(seqlen * seqlen).reshape([seqlen, seqlen])
|
C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim])
|
||||||
else:
|
else:
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
A_array = np.random.rand(seqlen, headdim) - 0.5
|
A_array = np.random.rand(seqlen, headdim) - 0.5
|
||||||
@@ -94,6 +94,8 @@ if __name__ == "__main__":
|
|||||||
def exp2(x):
|
def exp2(x):
|
||||||
return (x**2) / 2.0 + x + 1.0
|
return (x**2) / 2.0 + x + 1.0
|
||||||
|
|
||||||
|
col_to_save = 64
|
||||||
|
|
||||||
for col in range(0, seqlen, Bc):
|
for col in range(0, seqlen, Bc):
|
||||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||||
|
|
||||||
@@ -102,15 +104,17 @@ if __name__ == "__main__":
|
|||||||
K_tile = B_array[:, col:col+Bc]
|
K_tile = B_array[:, col:col+Bc]
|
||||||
|
|
||||||
S = Q_tile @ K_tile
|
S = Q_tile @ K_tile
|
||||||
S.astype('float32').tofile("S_expected.bin")
|
if col == col_to_save:
|
||||||
print('S_expected:')
|
print('S_expected:')
|
||||||
print(S)
|
print(S)
|
||||||
|
S.astype('float32').tofile("S_expected.bin")
|
||||||
|
|
||||||
# generate rowmax result in online softmax
|
# generate rowmax result in online softmax
|
||||||
rowmax_this = np.max(S, axis=1)
|
rowmax_this = np.max(S, axis=1)
|
||||||
rowmax_this.astype('float32').tofile("rowmax.bin")
|
|
||||||
rowmax_prev = rowmax.copy()
|
rowmax_prev = rowmax.copy()
|
||||||
rowmax = np.maximum(rowmax, rowmax_this)
|
rowmax = np.maximum(rowmax, rowmax_this)
|
||||||
|
if col == col_to_save:
|
||||||
|
rowmax.astype('float32').tofile("rowmax.bin")
|
||||||
|
|
||||||
# subtrace rowmax from each row by broadcasting
|
# subtrace rowmax from each row by broadcasting
|
||||||
# (placeholder for exp)
|
# (placeholder for exp)
|
||||||
@@ -119,24 +123,32 @@ if __name__ == "__main__":
|
|||||||
# for i in range(3, 4):
|
# for i in range(3, 4):
|
||||||
# P += (x**i) / np.math.factorial(i)
|
# P += (x**i) / np.math.factorial(i)
|
||||||
# P = np.exp(exp)
|
# P = np.exp(exp)
|
||||||
P.astype('float32').tofile("P_expected.bin")
|
|
||||||
# print('P error:')
|
# print('P error:')
|
||||||
# print(P / np.exp(x))
|
# print(P / np.exp(x))
|
||||||
print('P_expected:')
|
if col == col_to_save:
|
||||||
print(P)
|
print('P_expected:')
|
||||||
|
print(P)
|
||||||
|
P.astype('float32').tofile("P_expected.bin")
|
||||||
|
|
||||||
rowsum_this = np.sum(P, axis=1)
|
rowsum_this = np.sum(P, axis=1)
|
||||||
x = rowmax_prev - rowmax_this
|
x = rowmax_prev - rowmax_this
|
||||||
rowsum = exp2(x) * rowsum + rowsum_this
|
rowsum = exp2(x) * rowsum + rowsum_this
|
||||||
rowsum.astype('float32').tofile("rowsum.bin")
|
if col == col_to_save:
|
||||||
|
rowsum.astype('float32').tofile("rowsum.bin")
|
||||||
|
|
||||||
x = rowmax_prev - rowmax
|
x = rowmax_prev - rowmax
|
||||||
O = O / exp2(x)
|
print("haha")
|
||||||
|
print(exp2(x))
|
||||||
|
O = O / (exp2(x)[:, np.newaxis])
|
||||||
|
if col == col_to_save:
|
||||||
|
print('O_before_PV:')
|
||||||
|
print(O)
|
||||||
|
O.astype('float32').tofile("O_before_PV.bin")
|
||||||
|
|
||||||
# FIXME
|
V = C_array[col:col+Bc, :]
|
||||||
V = C_array[0:64, :]
|
|
||||||
# O = P.transpose([1, 0]) @ V
|
# O = P.transpose([1, 0]) @ V
|
||||||
O = O + P @ V
|
O = O + P @ V
|
||||||
O.astype('float32').tofile("O_expected.bin")
|
if col == col_to_save:
|
||||||
print('O_expected:')
|
print('O_after_PV:')
|
||||||
print(O)
|
print(O)
|
||||||
|
O.astype('float32').tofile("O_after_PV.bin")
|
||||||
|
|||||||
Reference in New Issue
Block a user