py: Write P_expected, don't rewrite vars

This commit is contained in:
Hansung Kim
2024-09-04 23:35:52 -07:00
parent dcd69ea304
commit bde6f0ea2e
2 changed files with 9 additions and 8 deletions

View File

@@ -133,6 +133,7 @@ if __name__ == "__main__":
print('P_expected:') print('P_expected:')
print(P) print(P)
P.astype('float32').tofile("P_expected.bin") P.astype('float32').tofile("P_expected.bin")
P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin")
rowsum_this = np.sum(P, axis=1) rowsum_this = np.sum(P, axis=1)
x = rowmax_prev - rowmax_this x = rowmax_prev - rowmax_this

View File

@@ -46,7 +46,7 @@ def pack_fp16_by_row(array):
if __name__ == "__main__": if __name__ == "__main__":
M, N, K = parse_mnk() M, N, K = parse_mnk()
rand = False rand = True
if not rand: if not rand:
A_array = np.arange(M * K).reshape([M, K]) A_array = np.arange(M * K).reshape([M, K])
B_array = np.arange(K * N).reshape([K, N]) B_array = np.arange(K * N).reshape([K, N])
@@ -77,19 +77,19 @@ if __name__ == "__main__":
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
fp16 = False fp16 = True
if fp16: if fp16:
A_packed = pack_fp16_by_row(A_array) A_packed = pack_fp16_by_row(A_array)
AT_packed = A_packed.transpose([1, 0, 2]) AT_packed = A_packed.transpose([1, 0, 2])
AT_array = AT_packed.reshape([-1, M * 2]) AT_swizzled = AT_packed.reshape([-1, M * 2])
AT_array.astype('float16').tofile("input.a.col.bin") AT_swizzled.astype('float16').tofile("input.a.col.bin")
print('AT:') print('AT:')
print(AT_array) print(AT_swizzled)
B_packed = pack_fp16_by_column(B_array) B_packed = pack_fp16_by_column(B_array)
B_array = B_packed.reshape([-1, N * 2]) B_swizzled = B_packed.reshape([-1, N * 2])
B_array.astype('float16').tofile("input.b.row.bin") B_swizzled.astype('float16').tofile("input.b.row.bin")
print('B:') print('B:')
print(B_array) print(B_swizzled)
else: else:
A_array.astype('float32').tofile("input.a.row.bin") A_array.astype('float32').tofile("input.a.row.bin")
AT_array = A_array.transpose([1, 0]) AT_array = A_array.transpose([1, 0])