Generate S matrix, pull out FA stuff from basic script

This commit is contained in:
Hansung Kim
2024-08-28 16:13:38 -07:00
parent 3f20dd59c0
commit 4260bf7d6e
2 changed files with 12 additions and 35 deletions

View File

@@ -94,7 +94,11 @@ if __name__ == "__main__":
def exp2(x):
return (x**2) / 2.0 + x + 1.0
col_to_save = 64
full_S = A_array @ B_array
full_S_T = full_S.transpose([1, 0])
full_S.astype('float32').tofile("full_S.bin")
col_to_save = 128
for col in range(0, seqlen, Bc):
print(f"tile iteration {col}~{col + Bc} ======================================")
@@ -137,8 +141,6 @@ if __name__ == "__main__":
rowsum.astype('float32').tofile("rowsum.bin")
x = rowmax_prev - rowmax
print("haha")
print(exp2(x))
O = O / (exp2(x)[:, np.newaxis])
if col == col_to_save:
print('O_before_PV:')