Fix range() arguments in the tile for-loop (stop and step were swapped)
I'm a Python noob
This commit is contained in:
@@ -64,23 +64,23 @@ if __name__ == "__main__":
|
|||||||
AT_packed = A_packed.transpose([1, 0, 2])
|
AT_packed = A_packed.transpose([1, 0, 2])
|
||||||
AT_array = AT_packed.reshape([-1, seqlen * 2])
|
AT_array = AT_packed.reshape([-1, seqlen * 2])
|
||||||
AT_array.astype('float16').tofile("input.a.col.bin")
|
AT_array.astype('float16').tofile("input.a.col.bin")
|
||||||
print('AT:')
|
# print('AT:')
|
||||||
print(AT_array)
|
# print(AT_array)
|
||||||
B_packed = pack_fp16_by_column(B_array)
|
B_packed = pack_fp16_by_column(B_array)
|
||||||
B_array = B_packed.reshape([-1, headdim * 2])
|
B_array = B_packed.reshape([-1, headdim * 2])
|
||||||
B_array.astype('float16').tofile("input.b.row.bin")
|
B_array.astype('float16').tofile("input.b.row.bin")
|
||||||
print('B:')
|
# print('B:')
|
||||||
print(B_array)
|
# print(B_array)
|
||||||
else:
|
else:
|
||||||
A_array.astype('float32').tofile("input.a.row.bin")
|
A_array.astype('float32').tofile("input.a.row.bin")
|
||||||
AT_array = A_array.transpose([1, 0])
|
AT_array = A_array.transpose([1, 0])
|
||||||
AT_array.astype('float32').tofile("input.a.col.bin")
|
AT_array.astype('float32').tofile("input.a.col.bin")
|
||||||
B_array.astype('float32').tofile("input.b.bin")
|
B_array.astype('float32').tofile("input.b.bin")
|
||||||
C_array.astype('float32').tofile("input.c.bin")
|
C_array.astype('float32').tofile("input.c.bin")
|
||||||
print('AT:')
|
# print('AT:')
|
||||||
print(AT_array)
|
# print(AT_array)
|
||||||
print('B:')
|
# print('B:')
|
||||||
print(B_array)
|
# print(B_array)
|
||||||
|
|
||||||
assert((seqlen % 64) == 0)
|
assert((seqlen % 64) == 0)
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ if __name__ == "__main__":
|
|||||||
def exp2(x):
|
def exp2(x):
|
||||||
return (x**2) / 2.0 + x + 1.0
|
return (x**2) / 2.0 + x + 1.0
|
||||||
|
|
||||||
for col in range(0, Bc, seqlen):
|
for col in range(0, seqlen, Bc):
|
||||||
print(f"tile iteration {col}~{col + Bc} ======================================")
|
print(f"tile iteration {col}~{col + Bc} ======================================")
|
||||||
|
|
||||||
# FIXME: only work with the first 64 rows of Q for now
|
# FIXME: only work with the first 64 rows of Q for now
|
||||||
|
|||||||
Reference in New Issue
Block a user