diff --git a/tests/regression/flash_attention/common.h b/tests/regression/flash_attention/common.h index 5c84f3b7..9c09726f 100644 --- a/tests/regression/flash_attention/common.h +++ b/tests/regression/flash_attention/common.h @@ -7,12 +7,12 @@ #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { - uint32_t dim_m; - uint32_t dim_n; - uint32_t dim_k; - uint64_t addr_a; - uint64_t addr_b; - uint64_t addr_c; + uint32_t dim_seqlen; + uint32_t dim_headdim; + uint64_t addr_q; + uint64_t addr_k; + uint64_t addr_v; + uint64_t addr_o; } kernel_arg_t; #endif diff --git a/tests/regression/flash_attention/flash_attention b/tests/regression/flash_attention/flash_attention deleted file mode 100644 index 993f22bf..00000000 Binary files a/tests/regression/flash_attention/flash_attention and /dev/null differ diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 888db94a..ffd3e495 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,9 +8,6 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -// using float_type = float; -using float_type = float16_t; - #define B_ROW BM #define B_COL BN @@ -90,8 +87,8 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, } #else - static_assert((B_ROW % NUM_THREADS) == 0, - "B_ROW must be a multiple of NUM_THREADS"); + static_assert((B_COL % NUM_THREADS) == 0, + "B_COL must be a multiple of NUM_THREADS"); constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; uint32_t thread_offset = first_thread_offset + tid_in_warp; float per_thread_max = FLT_MIN; @@ -122,7 +119,7 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, : "f"(rowmax), "f"(other)); } - // update previous rowsum + // update previous rowmax // i.e. mi_new = max(mi, mij) float prev_rowmax = sharedmem_rowmax[row]; asm volatile("fmax.s %0, %1, %2" @@ -147,17 +144,32 @@ inline void thread_block_flashattn(float *S, const uint32_t tid_in_threadblock, // broadcast rowmax to all threads in the warp const float row_max = sharedmem_rowmax[row]; - thread_offset = first_thread_offset + tid_in_warp; + // each thread computes two fp32 elements, downconverts it to fp16, then + // packs them into one fp32 + constexpr uint32_t elem_per_thread = 1; + static_assert((B_COL % (elem_per_thread * NUM_THREADS)) == 0, + "B_COL condition not met for P compute"); + + thread_offset = first_thread_offset + (elem_per_thread * tid_in_warp); + constexpr uint32_t exp_per_row_iter = + B_COL / (elem_per_thread * NUM_THREADS); #pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - float val = S[thread_offset]; + for (int i = 0; i < exp_per_row_iter; i++) { + float f0 = S[thread_offset]; + // float f1 = S[thread_offset + 1]; // FIXME: placeholder for proper exp - val -= row_max; + f0 -= row_max; + // f1 -= row_max; + // float16_t h0 = NN_float_to_half(f0); + // float16_t h1 = NN_float_to_half(f1); - // update S in-place to P - S[thread_offset] = val; - gmem_tmp1[thread_offset] = val; + // Store S transposed to the shared memory + + // update S in-place into P + S[thread_offset] = f0; + // S[thread_offset + 1] = f1; + gmem_tmp1[thread_offset] = f0; thread_offset += NUM_THREADS; } @@ -230,13 +242,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_id % threadblocks_per_cluster; const int tid_in_threadblock = task_id % threads_per_threadblock; - const uint32_t dim_m = arg->dim_m; - const uint32_t dim_n = arg->dim_n; - const uint32_t dim_n_in_blocks = dim_n / BN; - const int threadblock_id_x = threadblock_id % dim_n_in_blocks; - const int threadblock_id_y = threadblock_id / dim_n_in_blocks; - const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD); - const uint32_t num_threadblocks = problem_size / threads_per_threadblock; + const uint32_t dim_seqlen = arg->dim_seqlen; + const uint32_t dim_headdim = arg->dim_headdim; // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster @@ -272,7 +279,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #define SKIP_GEMM #ifndef SKIP_GEMM thread_block_gemm( - (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, + (const float_type *)arg->addr_q, (const float_type *)arg->addr_k, (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, @@ -284,7 +291,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *tile_S = (float *)smem_S; #else - float *tile_S = (float *)arg->addr_a; + float *tile_S = (float *)arg->addr_q; #endif thread_block_flashattn(tile_S, tid_in_threadblock, @@ -296,7 +303,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t problem_size = (arg->dim_m * arg->dim_n) / (ELEM_PER_THREAD); + // FIXME:: use actuall seqlen/headdim + const uint32_t problem_size = (B_ROW * B_COL) / (ELEM_PER_THREAD); const uint32_t hw_threads_per_cluster = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); // prevent launching more threads than the necessary problem size diff --git a/tests/regression/flash_attention/main.cpp b/tests/regression/flash_attention/main.cpp index 7399ce6d..3747bbb6 100644 --- a/tests/regression/flash_attention/main.cpp +++ b/tests/regression/flash_attention/main.cpp @@ -26,8 +26,6 @@ using half_float::half_cast; const char* kernel_file = "kernel.bin"; uint32_t count = 0; -template std::vector src_a_data; -template std::vector src_b_data; std::vector ref_data; vx_device_h device = nullptr; @@ -70,54 +68,8 @@ void cleanup() { } } -template -void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { - static_assert(std::is_same_v || std::is_same_v, - "unsupported floating point datatype"); - - src_a_data.resize(dim_m * dim_k); - src_b_data.resize(dim_k * dim_n); - - for (uint32_t i = 0; i < src_a_data.size(); ++i) { - if constexpr (std::is_same_v) { - src_a_data[i] = half_cast(static_cast(i)); - } else if (std::is_same_v) { - src_a_data[i] = static_cast(i); - } - std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl; - } - for (uint32_t i = 0; i < src_b_data.size(); ++i) { - if constexpr (std::is_same_v) { - src_b_data[i] = half_cast(static_cast(i)); - } else if (std::is_same_v) { - src_b_data[i] = static_cast(i); - } - std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl; - } -} - -template -void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { - static_assert(std::is_same_v || std::is_same_v, - "unsupported floating point datatype"); - - ref_data.resize(dim_m * dim_n); - - for (uint32_t i = 0; i < dim_m; ++i) { - for (uint32_t j = 0; j < dim_n; ++j) { - float ref = 0.0f; - for (uint32_t k = 0; k < dim_k; ++k) { - ref += static_cast(src_a_data[dim_k * i + k]) * - static_cast(src_b_data[dim_n * k + j]); - } - ref_data.at(dim_n * i + j) = ref; - } - } -} - int run_test(const kernel_arg_t& kernel_arg, - uint32_t buf_size, - uint32_t dim_m, uint32_t dim_n) { + uint32_t buf_size) { // start device std::cout << "start device" << std::endl; RT_CHECK(vx_start(device)); @@ -128,28 +80,7 @@ int run_test(const kernel_arg_t& kernel_arg, // download destination buffer std::cout << "download destination buffer" << std::endl; - RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); - - // verify result - std::cout << "verify result" << std::endl; - { - int errors = 0; - auto buf_ptr = (float*)staging_buf.data(); - for (uint32_t i = 0; i < dim_m * dim_n; ++i) { - float ref = ref_data.at(i); - float cur = buf_ptr[i]; - if (std::abs((cur - ref) / ref) > 1e-6) { - std::cout << "error at result #" << std::dec << i - << std::hex << ": actual=" << cur << ", expected=" << ref << std::endl; - ++errors; - } - } - if (errors != 0) { - std::cout << "Found " << std::dec << errors << " errors!" << std::endl; - std::cout << "FAILED!" << std::endl; - return 1; - } - } + RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_o, buf_size)); return 0; } @@ -168,30 +99,13 @@ int main(int argc, char *argv[]) { std::cout << "open device connection" << std::endl; RT_CHECK(vx_dev_open(&device)); - // FIXME: hardcoded - uint32_t dim_m = 128; - uint32_t dim_n = 128; - uint32_t dim_k = 128; + uint32_t dim_seqlen = 64; + uint32_t dim_headdim = 64; using float_type = half; - generate_source_matrix(dim_m, dim_n, dim_k); - generate_reference_matmul(dim_m, dim_n, dim_k); - - std::cout << "write reference output" << std::endl; - std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out); - if (!ref_file) { - std::cerr << "error: failed to open reference.c.bin for writing\n"; - exit(EXIT_FAILURE); - } - ref_file.write(reinterpret_cast(ref_data.data()), ref_data.size() * sizeof(ref_data[0])); - ref_file.close(); - - uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]); - uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]); - uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]); - - std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl; + uint32_t dst_buf_size = + dim_seqlen * dim_headdim * sizeof(ref_data[0]); // upload program std::cout << "upload program" << std::endl; @@ -199,29 +113,23 @@ int main(int argc, char *argv[]) { // allocate device memory std::cout << "allocate device memory" << std::endl; - // RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); - // RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); - // RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); - kernel_arg.addr_a = 0xa0000000; - kernel_arg.addr_b = 0xa1000000; - kernel_arg.addr_c = 0xc0000000; + kernel_arg.addr_q = 0xa0000000; + kernel_arg.addr_k = 0xa1000000; + kernel_arg.addr_v = 0xa2000000; + kernel_arg.addr_o = 0xc0000000; - kernel_arg.dim_m = dim_m; - kernel_arg.dim_n = dim_n; - kernel_arg.dim_k = dim_k; + kernel_arg.dim_seqlen = dim_seqlen; + kernel_arg.dim_headdim = dim_headdim; - std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl; - std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl; - std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl; + std::cout << "dev_addr_q=0x" << std::hex << kernel_arg.addr_q << std::endl; + std::cout << "dev_addr_k=0x" << std::hex << kernel_arg.addr_k << std::endl; + std::cout << "dev_addr_v=0x" << std::hex << kernel_arg.addr_v << std::endl; + std::cout << "dev_addr_o=0x" << std::hex << kernel_arg.addr_o << std::endl; // allocate staging buffer { std::cout << "allocate staging buffer" << std::endl; - uint32_t staging_buf_size = std::max( - src_a_buf_size, - std::max( - src_b_buf_size, - std::max(dst_buf_size, sizeof(kernel_arg_t)))); + uint32_t staging_buf_size = sizeof(kernel_arg_t); staging_buf.resize(staging_buf_size); } @@ -245,59 +153,9 @@ int main(int argc, char *argv[]) { file.close(); } - // upload source buffer - { - { - auto buf_ptr = staging_buf.data(); - memcpy(buf_ptr, src_a_data.data(), - src_a_data.size() * sizeof(float_type)); - RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(), - src_a_buf_size)); - - std::cout << "uploading source A matrix to device, device mem address=" - << std::hex << kernel_arg.addr_a << ", size=" << std::dec - << src_a_buf_size << " bytes\n"; - std::ofstream file("input.a.bin", std::ios::binary | std::ios::out); - if (!file) { - std::cerr << "error: failed to open args.bin for writing\n"; - exit(EXIT_FAILURE); - } - file.write(reinterpret_cast(buf_ptr), src_a_buf_size); - file.close(); - } - { - auto buf_ptr = staging_buf.data(); - memcpy(buf_ptr, src_b_data.data(), - src_b_data.size() * sizeof(float_type)); - RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(), - src_b_buf_size)); - - std::cout << "uploading source B matrix to device, device mem address=" - << std::hex << kernel_arg.addr_b << ", size=" << std::dec - << src_b_buf_size << " bytes\n"; - std::ofstream file("input.b.bin", std::ios::binary | std::ios::out); - if (!file) { - std::cerr << "error: failed to open args.bin for writing\n"; - exit(EXIT_FAILURE); - } - file.write(reinterpret_cast(buf_ptr), src_b_buf_size); - file.close(); - } - } - - // clear destination buffer - { - std::cout << "clear destination buffer" << std::endl; - auto buf_ptr = (int32_t*)staging_buf.data(); - for (uint32_t i = 0; i < ref_data.size(); ++i) { - buf_ptr[i] = 0xdeadbeef; - } - RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size)); - } - // run tests std::cout << "run tests" << std::endl; - RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n)); + RT_CHECK(run_test(kernel_arg, dst_buf_size)); std::cout << "PASSED!" << std::endl; // cleanup