Commit Graph

16 Commits

Author SHA1 Message Date
Hansung Kim
f6cc61241b flash: 2nd-order taylor approx of exponential for P 2024-08-19 20:12:37 -07:00
Hansung Kim
64e48de8af flash: Do accumulation of PV into O using the single_tile API 2024-08-19 18:08:58 -07:00
Hansung Kim
a98da9e3ca flash: Add missing accum reg init and fix barrier count 2024-08-19 16:15:46 -07:00
Hansung Kim
1e042af571 flash: Write and verify O = O + PV step 2024-08-19 13:18:27 -07:00
Hansung Kim
b44b202a21 sgemm_impl: Rename to wmma 2024-08-18 16:21:22 -07:00
Hansung Kim
90f6effa97 flash: Pass smem_P arg to softmax func 2024-08-18 15:21:05 -07:00
Hansung Kim
d3de1b674a flash: Compute exponents using prev/next/this rowmax values
maybe there is a better way than storing all three in sharedmem?
2024-08-15 22:10:02 -07:00
Hansung Kim
be08204e65 flash: Do proper allocation and init of QK/V/O tile 2024-08-15 21:26:14 -07:00
Hansung Kim
e0daf226ef flash: Change kernel arg to contain qkv; strip stimulus gen from host code
test data is now generated by the python script instead of the host
binary.
2024-08-15 21:03:02 -07:00
Hansung Kim
ac44633b39 flash: Compile time flag for skipping GEMM 2024-08-15 17:40:32 -07:00
Hansung Kim
f844d96eea flash: Initialize rowmax/rowsum cache in sharedmem 2024-08-15 17:28:36 -07:00
Hansung Kim
745aa098ed flash: Optimize spad use, fix rowsum 2024-08-15 16:54:56 -07:00
Hansung Kim
e809d25305 flash: Fix rowsum and write fake exp
GEMM part is disabled for faster debugging, the kernel reads the result
of A*B directly from input binary.
2024-08-15 16:32:21 -07:00
Hansung Kim
53dfc690b9 flash: Allocate smem properly for rowsum and scratch 2024-08-14 21:50:20 -07:00
Hansung Kim
9cabe3413b Fix overlapping smem in rowmax 2024-08-14 21:09:53 -07:00
Hansung Kim
692d028afd Add flash attention kernel skeleton 2024-08-14 20:46:09 -07:00