flash: Write and verify O = O + PV step
This commit is contained in:
@@ -105,7 +105,6 @@ inline void thread_block_online_softmax(
|
|||||||
: "f"(max), "f"(S[first_thread_offset + i]));
|
: "f"(max), "f"(S[first_thread_offset + i]));
|
||||||
}
|
}
|
||||||
smem_rowmax[row] = max;
|
smem_rowmax[row] = max;
|
||||||
gmem_tmp0[row] = max;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
@@ -188,7 +187,7 @@ inline void thread_block_online_softmax(
|
|||||||
|
|
||||||
// Store S transposed to the shared memory
|
// Store S transposed to the shared memory
|
||||||
|
|
||||||
smem_S[thread_offset] = f0;
|
smem_P[thread_offset] = f0;
|
||||||
// S[thread_offset + 1] = f1;
|
// S[thread_offset + 1] = f1;
|
||||||
gmem_tmp1[thread_offset] = f0;
|
gmem_tmp1[thread_offset] = f0;
|
||||||
|
|
||||||
@@ -206,7 +205,7 @@ inline void thread_block_online_softmax(
|
|||||||
float per_thread_sum = 0.0f;
|
float per_thread_sum = 0.0f;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int i = 0; i < per_row_iter; i++) {
|
for (int i = 0; i < per_row_iter; i++) {
|
||||||
per_thread_sum += smem_S[thread_offset];
|
per_thread_sum += smem_P[thread_offset];
|
||||||
thread_offset += NUM_THREADS;
|
thread_offset += NUM_THREADS;
|
||||||
}
|
}
|
||||||
// stage per-thread sum value in smem
|
// stage per-thread sum value in smem
|
||||||
@@ -355,6 +354,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
float *tile_S = (float *)arg->addr_q;
|
float *tile_S = (float *)arg->addr_q;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// FIXME: V is stored in d0000000 for debugging purpose
|
||||||
|
const float *gmem_V = reinterpret_cast<float *>(arg->addr_k);
|
||||||
|
|
||||||
thread_block_online_softmax(
|
thread_block_online_softmax(
|
||||||
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum);
|
threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum);
|
||||||
@@ -365,9 +367,20 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
float *gmem_tmp2 = reinterpret_cast<float *>(0xf0000000UL);
|
float *gmem_tmp2 = reinterpret_cast<float *>(0xf0000000UL);
|
||||||
|
|
||||||
thread_block_gemm_single_tile<float, /*write_to_smem=*/true>(
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||||
|
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// FIXME: support MN_major for A for ideal performance
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock,
|
smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock,
|
||||||
threads_per_threadblock);
|
threads_per_threadblock);
|
||||||
|
|
||||||
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|||||||
Reference in New Issue
Block a user