bleh still not work

This commit is contained in:
joshua
2024-03-27 00:26:04 -07:00
parent b254281295
commit e16584ddd9
12 changed files with 485 additions and 64 deletions

View File

@@ -0,0 +1,8 @@
PROJECT = tensor
SRCS = main.cpp
DEPS = a_matrix.h
DEPS += b_matrix.h
DEPS += c_matrix.h
include ../common.mk

View File

@@ -0,0 +1,94 @@
import numpy as np
import struct
A_array = np.zeros((16, 8))
B_array = np.zeros((8, 16))
C_array = np.zeros((16, 16))
file = input("simulator output filename: ")
def hex2float(float_hex_str):
# print(float_hex_str.strip())
return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0]
def C_index(threadgroup, thread, register):
"""
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
"""
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
if C_array[row, col] != 0:
print("bad")
return (row, col)
with open(file) as f:
for line in f.readlines():
line = line.strip()
if "warp" in line:
a, b, c = line.split(',')
_, a = a.split(' ')
_, b = b.strip().split(' ')
c, d = c.strip().split(':')
_, c = c.split(' ')
warp = int(a)
thread = int(b)
register = int(c)
value = d.strip()
if warp != 0:
continue
if not (32 <= register < 32+24):
continue
register = register - 32
# threadgroups 0, 4, 1, 5 have all elements of A
threadgroup = thread // 4
if threadgroup in [0, 4, 1, 5]:
row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4
if 0 <= register < 8:
A_array[row, register] = hex2float(value)
if threadgroup in [0, 4, 2, 6]:
col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4
if 8 <= register < 16:
B_array[register-8, col] = hex2float(value)
if 16 <= register < 24:
# print(value)
C_array[C_index(threadgroup, thread, register)] = hex2float(value)
expected = np.load("abc.npz")
expected_A = expected['A_array']
expected_B = expected['B_array']
expected_C = expected['C_array']
expected_C = expected_C + expected_A @ expected_B
print(expected_C - C_array)
assert np.allclose(expected_A, A_array)
assert np.allclose(expected_B, B_array)
assert np.allclose(expected_C, C_array)

View File

@@ -0,0 +1,29 @@
import numpy as np
# A_array = np.random.rand(16, 8)
# B_array = np.random.rand(8, 16)
A_array = np.zeros((16, 8))
B_array = np.zeros((8, 16))
A_array[0,:] = 1.0
B_array[:,0] = 1.0
C_array = np.random.rand(16, 16)
with open('a_matrix.h', 'w') as f:
for i in range(A_array.shape[0]):
for j in range(A_array.shape[1]):
f.write(f'{A_array[i,j]}f, ')
f.write('\n')
with open('b_matrix.h', 'w') as f:
for i in range(B_array.shape[0]):
for j in range(B_array.shape[1]):
f.write(f'{B_array[i,j]}f, ')
f.write('\n')
with open('c_matrix.h', 'w') as f:
for i in range(C_array.shape[0]):
for j in range(C_array.shape[1]):
f.write(f'{C_array[i,j]}f, ')
f.write('\n')
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)

View File

@@ -0,0 +1,96 @@
#define RISCV_CUSTOM3 0x7B
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
#include "test_data.h"
void vx_wmma_load() {
int tid = vx_thread_id();
int tg = tid / 4;
// load A
int row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f0, %0" :: "m"(A[row][0]));
asm volatile ("flw f1, %0" :: "m"(A[row][1]));
asm volatile ("flw f2, %0" :: "m"(A[row][2]));
asm volatile ("flw f3, %0" :: "m"(A[row][3]));
asm volatile ("flw f4, %0" :: "m"(A[row][4]));
asm volatile ("flw f5, %0" :: "m"(A[row][5]));
asm volatile ("flw f6, %0" :: "m"(A[row][6]));
asm volatile ("flw f7, %0" :: "m"(A[row][7]));
// load B
int col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
asm volatile ("flw f10, %0" :: "m"(B[2][col]));
asm volatile ("flw f11, %0" :: "m"(B[3][col]));
asm volatile ("flw f12, %0" :: "m"(B[4][col]));
asm volatile ("flw f13, %0" :: "m"(B[5][col]));
asm volatile ("flw f14, %0" :: "m"(B[6][col]));
asm volatile ("flw f15, %0" :: "m"(B[7][col]));
// load C
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
}
float results[32*8];
void store_wmma_result() {
int tid = vx_thread_id();
asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
}
void print_wmma_result() {
for (int tid = 0; tid < 32; tid += 1) {
for (int reg = 0; reg < 8; reg += 1) {
vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
}
}
}
int main()
{
vx_tmc(-1);
vx_wmma_load();
vx_wmma();
store_wmma_result();
vx_tmc(1);
print_wmma_result();
return 0;
}

View File

@@ -0,0 +1,11 @@
float A[16][8] = {
#include "a_matrix.h"
};
float B[8][16] = {
#include "b_matrix.h"
};
float C[16][16] = {
#include "c_matrix.h"
};