Generate S matrix, pull out FA stuff from basic script
This commit is contained in:
@@ -94,7 +94,11 @@ if __name__ == "__main__":
|
||||
def exp2(x):
|
||||
return (x**2) / 2.0 + x + 1.0
|
||||
|
||||
col_to_save = 64
|
||||
full_S = A_array @ B_array
|
||||
full_S_T = full_S.transpose([1, 0])
|
||||
full_S.astype('float32').tofile("full_S.bin")
|
||||
|
||||
col_to_save = 128
|
||||
|
||||
for col in range(0, seqlen, Bc):
|
||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||
@@ -137,8 +141,6 @@ if __name__ == "__main__":
|
||||
rowsum.astype('float32').tofile("rowsum.bin")
|
||||
|
||||
x = rowmax_prev - rowmax
|
||||
print("haha")
|
||||
print(exp2(x))
|
||||
O = O / (exp2(x)[:, np.newaxis])
|
||||
if col == col_to_save:
|
||||
print('O_before_PV:')
|
||||
|
||||
Reference in New Issue
Block a user