py: Write P_expected, don't rewrite vars
This commit is contained in:
@@ -133,6 +133,7 @@ if __name__ == "__main__":
|
|||||||
print('P_expected:')
|
print('P_expected:')
|
||||||
print(P)
|
print(P)
|
||||||
P.astype('float32').tofile("P_expected.bin")
|
P.astype('float32').tofile("P_expected.bin")
|
||||||
|
P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin")
|
||||||
|
|
||||||
rowsum_this = np.sum(P, axis=1)
|
rowsum_this = np.sum(P, axis=1)
|
||||||
x = rowmax_prev - rowmax_this
|
x = rowmax_prev - rowmax_this
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def pack_fp16_by_row(array):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
M, N, K = parse_mnk()
|
M, N, K = parse_mnk()
|
||||||
|
|
||||||
rand = False
|
rand = True
|
||||||
if not rand:
|
if not rand:
|
||||||
A_array = np.arange(M * K).reshape([M, K])
|
A_array = np.arange(M * K).reshape([M, K])
|
||||||
B_array = np.arange(K * N).reshape([K, N])
|
B_array = np.arange(K * N).reshape([K, N])
|
||||||
@@ -77,19 +77,19 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
||||||
|
|
||||||
fp16 = False
|
fp16 = True
|
||||||
if fp16:
|
if fp16:
|
||||||
A_packed = pack_fp16_by_row(A_array)
|
A_packed = pack_fp16_by_row(A_array)
|
||||||
AT_packed = A_packed.transpose([1, 0, 2])
|
AT_packed = A_packed.transpose([1, 0, 2])
|
||||||
AT_array = AT_packed.reshape([-1, M * 2])
|
AT_swizzled = AT_packed.reshape([-1, M * 2])
|
||||||
AT_array.astype('float16').tofile("input.a.col.bin")
|
AT_swizzled.astype('float16').tofile("input.a.col.bin")
|
||||||
print('AT:')
|
print('AT:')
|
||||||
print(AT_array)
|
print(AT_swizzled)
|
||||||
B_packed = pack_fp16_by_column(B_array)
|
B_packed = pack_fp16_by_column(B_array)
|
||||||
B_array = B_packed.reshape([-1, N * 2])
|
B_swizzled = B_packed.reshape([-1, N * 2])
|
||||||
B_array.astype('float16').tofile("input.b.row.bin")
|
B_swizzled.astype('float16').tofile("input.b.row.bin")
|
||||||
print('B:')
|
print('B:')
|
||||||
print(B_array)
|
print(B_swizzled)
|
||||||
else:
|
else:
|
||||||
A_array.astype('float32').tofile("input.a.row.bin")
|
A_array.astype('float32').tofile("input.a.row.bin")
|
||||||
AT_array = A_array.transpose([1, 0])
|
AT_array = A_array.transpose([1, 0])
|
||||||
|
|||||||
Reference in New Issue
Block a user