Commit Graph

7 Commits

Author SHA1 Message Date
Hansung Kim
ac44633b39 flash: Compile time flag for skipping GEMM 2024-08-15 17:40:32 -07:00
Hansung Kim
f844d96eea flash: Initialize rowmax/rowsum cache in sharedmem 2024-08-15 17:28:36 -07:00
Hansung Kim
745aa098ed flash: Optimize spad use, fix rowsum 2024-08-15 16:54:56 -07:00
Hansung Kim
e809d25305 flash: Fix rowsum and write fake exp
GEMM part is disabled for faster debugging, the kernel reads the result
of A*B directly from input binary.
2024-08-15 16:32:21 -07:00
Hansung Kim
53dfc690b9 flash: Allocate smem properly for rowsum and scratch 2024-08-14 21:50:20 -07:00
Hansung Kim
9cabe3413b Fix overlapping smem in rowmax 2024-08-14 21:09:53 -07:00
Hansung Kim
692d028afd Add flash attention kernel skeleton 2024-08-14 20:46:09 -07:00