From 60aec1de8db5858af2f3c0f1c5417bd2f8d4c8c6 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 20 Aug 2024 14:49:25 -0700 Subject: [PATCH] flash.py: Fix row-wise scaling of O, col_to_save --- tests/kernel/tensor/flash_attn.py | 42 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/kernel/tensor/flash_attn.py b/tests/kernel/tensor/flash_attn.py index 9934599d..e7fb8300 100644 --- a/tests/kernel/tensor/flash_attn.py +++ b/tests/kernel/tensor/flash_attn.py @@ -50,7 +50,7 @@ if __name__ == "__main__": if not rand: A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim]) B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen]) - C_array = np.arange(seqlen * seqlen).reshape([seqlen, seqlen]) + C_array = np.arange(seqlen * seqlen).reshape([seqlen, headdim]) else: np.random.seed(0) A_array = np.random.rand(seqlen, headdim) - 0.5 @@ -94,6 +94,8 @@ if __name__ == "__main__": def exp2(x): return (x**2) / 2.0 + x + 1.0 + col_to_save = 64 + for col in range(0, seqlen, Bc): print(f"tile iteration {col}~{col + Bc} ======================================") @@ -102,15 +104,17 @@ if __name__ == "__main__": K_tile = B_array[:, col:col+Bc] S = Q_tile @ K_tile - S.astype('float32').tofile("S_expected.bin") - print('S_expected:') - print(S) + if col == col_to_save: + print('S_expected:') + print(S) + S.astype('float32').tofile("S_expected.bin") # generate rowmax result in online softmax rowmax_this = np.max(S, axis=1) - rowmax_this.astype('float32').tofile("rowmax.bin") rowmax_prev = rowmax.copy() rowmax = np.maximum(rowmax, rowmax_this) + if col == col_to_save: + rowmax.astype('float32').tofile("rowmax.bin") # subtrace rowmax from each row by broadcasting # (placeholder for exp) @@ -119,24 +123,32 @@ if __name__ == "__main__": # for i in range(3, 4): # P += (x**i) / np.math.factorial(i) # P = np.exp(exp) - P.astype('float32').tofile("P_expected.bin") # print('P error:') # print(P / np.exp(x)) - print('P_expected:') - print(P) + if col == col_to_save: + print('P_expected:') + print(P) + P.astype('float32').tofile("P_expected.bin") rowsum_this = np.sum(P, axis=1) x = rowmax_prev - rowmax_this rowsum = exp2(x) * rowsum + rowsum_this - rowsum.astype('float32').tofile("rowsum.bin") + if col == col_to_save: + rowsum.astype('float32').tofile("rowsum.bin") x = rowmax_prev - rowmax - O = O / exp2(x) + print("haha") + print(exp2(x)) + O = O / (exp2(x)[:, np.newaxis]) + if col == col_to_save: + print('O_before_PV:') + print(O) + O.astype('float32').tofile("O_before_PV.bin") - # FIXME - V = C_array[0:64, :] + V = C_array[col:col+Bc, :] # O = P.transpose([1, 0]) @ V O = O + P @ V - O.astype('float32').tofile("O_expected.bin") - print('O_expected:') - print(O) + if col == col_to_save: + print('O_after_PV:') + print(O) + O.astype('float32').tofile("O_after_PV.bin")