update kernels

This commit is contained in:
Richard Yan
2025-01-28 16:37:22 -08:00
parent c114a7a4ab
commit b1e6495630
9 changed files with 67 additions and 50 deletions

View File

@@ -15,20 +15,20 @@ def truncated_matrix_multiplication(matrix_a, matrix_b, size):
result = np.matmul(truncated_a, truncated_b)
return result.astype(np.float16)
# Generate the 512x512 matrices
size = 512
matrix_a = generate_fp16_matrix(size)
matrix_b = generate_fp16_matrix(size)
# Save the operand matrices to binary files
save_matrix_to_bin("input.a.bin", matrix_a)
save_matrix_to_bin("input.b.bin", matrix_b)
# Generate and save the reference matrices for 128x128, 256x256, and 512x512 sizes
sizes = [128, 256, 512]
sizes = [128, 256, 512, 1024]
for s in sizes:
np.random.seed(0)
matrix_a = generate_fp16_matrix(s)
matrix_b = generate_fp16_matrix(s)
# Save the operand matrices to binary files
save_matrix_to_bin("input.a.bin", matrix_a)
save_matrix_to_bin(f"input.a/{s}", matrix_a)
save_matrix_to_bin("input.b.bin", matrix_b)
save_matrix_to_bin(f"input.b/{s}", matrix_b)
ref_matrix = truncated_matrix_multiplication(matrix_a, matrix_b, s)
print(ref_matrix)
save_matrix_to_bin(f"ref{s}.bin", ref_matrix)
print("All files generated successfully.")