flash.py: Write V to file
This commit is contained in:
@@ -98,7 +98,7 @@ if __name__ == "__main__":
|
|||||||
full_S_T = full_S.transpose([1, 0])
|
full_S_T = full_S.transpose([1, 0])
|
||||||
full_S.astype('float32').tofile("full_S.bin")
|
full_S.astype('float32').tofile("full_S.bin")
|
||||||
|
|
||||||
col_to_save = 128
|
col_to_save = 0
|
||||||
|
|
||||||
for col in range(0, seqlen, Bc):
|
for col in range(0, seqlen, Bc):
|
||||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||||
@@ -148,6 +148,8 @@ if __name__ == "__main__":
|
|||||||
O.astype('float32').tofile("O_before_PV.bin")
|
O.astype('float32').tofile("O_before_PV.bin")
|
||||||
|
|
||||||
V = C_array[col:col+Bc, :]
|
V = C_array[col:col+Bc, :]
|
||||||
|
if col == col_to_save:
|
||||||
|
V.astype('float32').tofile("V_expected.bin")
|
||||||
# O = P.transpose([1, 0]) @ V
|
# O = P.transpose([1, 0]) @ V
|
||||||
O = O + P @ V
|
O = O + P @ V
|
||||||
if col == col_to_save:
|
if col == col_to_save:
|
||||||
|
|||||||
Reference in New Issue
Block a user