diff --git a/tests/kernel/tensor/flash_attn.py b/tests/kernel/tensor/flash_attn.py index e7fb8300..9c934adb 100644 --- a/tests/kernel/tensor/flash_attn.py +++ b/tests/kernel/tensor/flash_attn.py @@ -94,7 +94,11 @@ if __name__ == "__main__": def exp2(x): return (x**2) / 2.0 + x + 1.0 - col_to_save = 64 + full_S = A_array @ B_array + full_S_T = full_S.transpose([1, 0]) + full_S.astype('float32').tofile("full_S.bin") + + col_to_save = 128 for col in range(0, seqlen, Bc): print(f"tile iteration {col}~{col + Bc} ======================================") @@ -137,8 +141,6 @@ if __name__ == "__main__": rowsum.astype('float32').tofile("rowsum.bin") x = rowmax_prev - rowmax - print("haha") - print(exp2(x)) O = O / (exp2(x)[:, np.newaxis]) if col == col_to_save: print('O_before_PV:') diff --git a/tests/kernel/tensor/generate_matrix.py b/tests/kernel/tensor/generate_matrix.py index 7d49e193..8626ba43 100644 --- a/tests/kernel/tensor/generate_matrix.py +++ b/tests/kernel/tensor/generate_matrix.py @@ -46,11 +46,12 @@ def pack_fp16_by_row(array): if __name__ == "__main__": M, N, K = parse_mnk() - rand = True + rand = False if not rand: A_array = np.arange(M * K).reshape([M, K]) B_array = np.arange(K * N).reshape([K, N]) - C_array = np.arange(M * N).reshape([M, N]) + # C_array = np.arange(M * N).reshape([M, N]) + C_array = np.zeros([M, N]) else: np.random.seed(0) A_array = np.random.rand(M, K) - 0.5 @@ -76,11 +77,6 @@ if __name__ == "__main__": np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) - D_expected = A_array @ B_array - D_expected.astype('float32').tofile("d_expected.bin") - print('D_expected:') - print(D_expected) - fp16 = False if fp16: A_packed = pack_fp16_by_row(A_array) @@ -105,29 +101,8 @@ if __name__ == "__main__": print('B:') print(B_array) - # generate rowmax result in online softmax - row_max = np.max(D_expected, axis=1) - row_max.astype('float32').tofile("rowmax.bin") + D_expected = A_array @ B_array + D_expected.astype('float32').tofile("d_expected.bin") + print('D_expected:') + print(D_expected) - # subtrace row_max from each row by broadcasting - # (placeholder for exp) - x = D_expected - row_max[:, np.newaxis] - P = (x**2) / 2.0 + x + 1.0 - # for i in range(3, 4): - # P += (x**i) / np.math.factorial(i) - # P = np.exp(exp) - P.astype('float32').tofile("P_expected.bin") - # print('P error:') - # print(P / np.exp(x)) - print('P_expected:') - print(P) - - row_sum = np.sum(P, axis=1) - row_sum.astype('float32').tofile("rowsum.bin") - - V = C_array - # O = P.transpose([1, 0]) @ V - O = P @ V - O.astype('float32').tofile("O_expected.bin") - print('O_expected:') - print(O)