flash: Reduce smem use for rowmax; verify result
This commit is contained in:
@@ -13,6 +13,8 @@
|
|||||||
// FIXME
|
// FIXME
|
||||||
#define HEADDIM B_COL
|
#define HEADDIM B_COL
|
||||||
|
|
||||||
|
constexpr bool DEBUG = true;
|
||||||
|
|
||||||
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threads_per_threadblock,
|
||||||
float *smem_O,
|
float *smem_O,
|
||||||
@@ -53,6 +55,66 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||||
|
const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblocks_per_cluster,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
|
asm volatile("threadblock_copy_rowmax_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
|
constexpr uint32_t num_warps = B_ROW / NUM_THREADS;
|
||||||
|
if (warp_id < num_warps) {
|
||||||
|
uint32_t offset = NUM_THREADS * warp_id + tid_in_warp;
|
||||||
|
dest[offset] = src[offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||||
|
const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threads_per_threadblock,
|
||||||
|
const uint32_t threadblocks_per_cluster,
|
||||||
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
|
asm volatile("threadblock_copy_tile_start_%=:" ::);
|
||||||
|
|
||||||
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
|
// FIXME: dedup this pattern
|
||||||
|
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||||
|
warp_offset += warps_in_threadblock) {
|
||||||
|
const uint32_t row = warp_offset + warp_id;
|
||||||
|
const uint32_t first_thread_offset = B_COL * row;
|
||||||
|
|
||||||
|
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||||
|
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||||
|
float per_thread_max = FLT_MIN;
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
|
dest[thread_offset] = src[thread_offset];
|
||||||
|
thread_offset += NUM_THREADS;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("threadblock_copy_tile_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
template <int order>
|
template <int order>
|
||||||
inline float exponential_taylor_term(const float x) {
|
inline float exponential_taylor_term(const float x) {
|
||||||
asm volatile("exponential_taylor_term_start_%=:" ::);
|
asm volatile("exponential_taylor_term_start_%=:" ::);
|
||||||
@@ -73,38 +135,7 @@ inline float exponential_taylor_term(const float x) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void thread_block_copy_data(const float *src, float *dest,
|
__attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||||
const uint32_t tid_in_threadblock,
|
|
||||||
const uint32_t threads_per_threadblock,
|
|
||||||
const uint32_t threadblocks_per_cluster,
|
|
||||||
const uint32_t threadblock_id_in_cluster) {
|
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
|
||||||
const uint32_t warp_id = tid_in_threadblock / NUM_THREADS;
|
|
||||||
const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS;
|
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
|
||||||
NUM_WARPS / threadblocks_per_cluster;
|
|
||||||
|
|
||||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
|
||||||
warp_offset += warps_in_threadblock) {
|
|
||||||
const uint32_t row = warp_offset + warp_id;
|
|
||||||
const uint32_t first_thread_offset = B_COL * row;
|
|
||||||
|
|
||||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
|
||||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
|
||||||
float per_thread_max = FLT_MIN;
|
|
||||||
#pragma GCC unroll
|
|
||||||
for (int i = 0; i < per_row_iter; i++) {
|
|
||||||
const float f = src[thread_offset];
|
|
||||||
dest[thread_offset] = f;
|
|
||||||
thread_offset += NUM_THREADS;
|
|
||||||
}
|
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
|
||||||
warps_per_threadblock_per_core);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void thread_block_online_softmax(
|
|
||||||
const float *smem_S, float *smem_O, float *smem_P,
|
const float *smem_S, float *smem_O, float *smem_P,
|
||||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||||
const uint32_t threadblocks_per_cluster,
|
const uint32_t threadblocks_per_cluster,
|
||||||
@@ -128,9 +159,7 @@ inline void thread_block_online_softmax(
|
|||||||
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
|
// asm volatile("fmv.s %0, f22" : "=f"(ft[6]));
|
||||||
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
|
// asm volatile("fmv.s %0, f23" : "=f"(ft[7]));
|
||||||
|
|
||||||
float *smem_rowmax_prev = smem_rowmax;
|
float *smem_rowmax_this = smem_rowmax + B_ROW;
|
||||||
float *smem_rowmax_new = smem_rowmax + B_ROW;
|
|
||||||
float *smem_rowmax_this = smem_rowmax + 2 * B_ROW;
|
|
||||||
|
|
||||||
for (int warp_offset = 0; warp_offset < B_ROW;
|
for (int warp_offset = 0; warp_offset < B_ROW;
|
||||||
warp_offset += warps_in_threadblock) {
|
warp_offset += warps_in_threadblock) {
|
||||||
@@ -192,26 +221,34 @@ inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
// update previous rowmax
|
// update previous rowmax
|
||||||
// i.e. mi_new = max(mi, mij)
|
// i.e. mi_new = max(mi, mij)
|
||||||
float prev_rowmax = smem_rowmax_prev[row];
|
float prev_rowmax = smem_rowmax[row];
|
||||||
|
// stage prev rowmax in scratchpad for warp-wide broadcast
|
||||||
|
warp_smem[0] = prev_rowmax;
|
||||||
asm volatile("fmax.s %0, %1, %2"
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
: "=f"(rowmax)
|
: "=f"(rowmax)
|
||||||
: "f"(rowmax), "f"(prev_rowmax));
|
: "f"(rowmax), "f"(prev_rowmax));
|
||||||
smem_rowmax_new[row] = rowmax;
|
smem_rowmax[row] = rowmax;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// FIXME: unnecessary?
|
// FIXME: unnecessary?
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// broadcast prev rowmax to all threads in the warp
|
||||||
|
// NOTE: memory consistency is a little sketchy here
|
||||||
|
const float rowmax_prev = warp_smem[0];
|
||||||
|
const float rowmax_this = smem_rowmax_this[row];
|
||||||
|
|
||||||
// exponential
|
// exponential
|
||||||
//
|
//
|
||||||
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
|
// B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock))
|
||||||
// const uint32_t row_stride =
|
// const uint32_t row_stride =
|
||||||
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
// (exp_elem_per_thread * threads_per_threadblock) / B_COL;
|
||||||
|
|
||||||
// broadcast rowmax to all threads in the warp
|
// broadcast updated rowmax to all threads in the warp
|
||||||
const float rowmax_new = smem_rowmax_new[row];
|
const float rowmax_new = smem_rowmax[row];
|
||||||
|
|
||||||
// each thread computes two fp32 elements, downconverts it to fp16, then
|
// each thread computes two fp32 elements, downconverts it to fp16, then
|
||||||
// packs them into one fp32
|
// packs them into one fp32
|
||||||
@@ -279,8 +316,9 @@ inline void thread_block_online_softmax(
|
|||||||
rowsum += other;
|
rowsum += other;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float mi_prev = smem_rowmax_prev[row];
|
const float mi_prev = rowmax_prev;
|
||||||
const float mi_this = smem_rowmax_this[row];
|
// TODO: replace this with a register?
|
||||||
|
const float mi_this = rowmax_this;
|
||||||
|
|
||||||
const float x = mi_prev - mi_this;
|
const float x = mi_prev - mi_this;
|
||||||
// 2nd-order Taylor approximation
|
// 2nd-order Taylor approximation
|
||||||
@@ -309,8 +347,8 @@ inline void thread_block_online_softmax(
|
|||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
float o = smem_O[thread_offset];
|
float o = smem_O[thread_offset];
|
||||||
|
|
||||||
const float mi_prev = smem_rowmax_prev[row];
|
const float mi_prev = rowmax_prev;
|
||||||
const float mi_new = smem_rowmax_new[row];
|
const float mi_new = rowmax_new;
|
||||||
|
|
||||||
const float x = mi_prev - mi_new;
|
const float x = mi_prev - mi_new;
|
||||||
// 2nd-order Taylor approximation
|
// 2nd-order Taylor approximation
|
||||||
@@ -398,9 +436,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||||
// in rowsum
|
// in rowsum
|
||||||
// NOTE: out-of bounds is not checked
|
// NOTE: out-of bounds is not checked
|
||||||
|
// TODO: reduce this from B_ROW to NUM_WARPS
|
||||||
constexpr uint32_t smem_scratchpad_size =
|
constexpr uint32_t smem_scratchpad_size =
|
||||||
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
||||||
float *smem_scratchpad = smem_rowmax - smem_scratchpad_size;
|
float *smem_scratchpad = smem_rowsum - smem_scratchpad_size;
|
||||||
|
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threadblocks_per_cluster;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
@@ -414,6 +453,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||||
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||||
|
|
||||||
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||||
|
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||||
|
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||||
|
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||||
|
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||||
|
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||||
|
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||||
|
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
|
for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) {
|
||||||
|
|
||||||
@@ -469,6 +519,43 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (tile_k == 0) {
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_P, gmem_tmp_d0, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_O, gmem_tmp_d2, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_O, gmem_tmp_d3, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_threadblock,
|
||||||
|
threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
|
|
||||||
// GEMM II: O = O + P*V
|
// GEMM II: O = O + P*V
|
||||||
|
|
||||||
// clear out accumulators
|
// clear out accumulators
|
||||||
@@ -495,18 +582,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (tile_k == 0) {
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile(
|
||||||
|
smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float *gmem_tmp0 = reinterpret_cast<float *>(0xd0000000UL);
|
|
||||||
float *gmem_tmp1 = reinterpret_cast<float *>(0xe0000000UL);
|
|
||||||
|
|
||||||
// copy out tile data to GMEM for debugging
|
|
||||||
thread_block_copy_data(smem_P, gmem_tmp0, tid_in_threadblock,
|
|
||||||
threads_per_threadblock, threadblocks_per_cluster,
|
|
||||||
threadblock_id_in_cluster);
|
|
||||||
thread_block_copy_data(smem_O, gmem_tmp1, tid_in_threadblock,
|
|
||||||
threads_per_threadblock, threadblocks_per_cluster,
|
|
||||||
threadblock_id_in_cluster);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|||||||
Reference in New Issue
Block a user