From 301f1ca26097f932bb3b4a37b3fbc93b67d5452a Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Fri, 16 Feb 2024 16:20:45 -0800
Subject: [PATCH] sgemm_wg: Implement blocking over k-dimension

---
Review notes (text between "---" and the first "diff --git" is not part of
the commit):

 * This copy was rebuilt from a text whose line breaks and indentation had
   been collapsed. Line structure, including blank context lines, was
   re-derived so that every "@@ -a,b +c,d @@" count balances. Indentation
   is a best guess -- TODO confirm against the tree before applying.
 * Template arguments missing from that text were restored by inference:
   std::vector<float>, static_cast<float>, reinterpret_cast<char*>,
   std::max<uint32_t>. The author e-mail on the From: line was also
   missing and is not recoverable from here.
 * The open-failure message added for "input.b.bin" now names that file
   (it said "args.bin"). The same message in the "input.a.bin" branch is
   an unchanged context line and still says "args.bin"; fixing it needs a
   follow-up change.
 * kernel.cpp: dim_m is read but never used, and rows of A and C are
   indexed with vx_warp_id() only (see the TODO added in the k-loop).

 tests/regression/sgemm_wg/common.h   |   5 +-
 tests/regression/sgemm_wg/kernel.cpp |  51 ++++++++----
 tests/regression/sgemm_wg/main.cpp   | 120 +++++++++++++++++----------
 3 files changed, 113 insertions(+), 63 deletions(-)

diff --git a/tests/regression/sgemm_wg/common.h b/tests/regression/sgemm_wg/common.h
index ef1e85a8..74941562 100644
--- a/tests/regression/sgemm_wg/common.h
+++ b/tests/regression/sgemm_wg/common.h
@@ -7,8 +7,11 @@
 #define DEV_SMEM_START_ADDR 0xff000000
 
 typedef struct {
-  uint32_t matrix_dim;
+  uint32_t dim_m;
+  uint32_t dim_n;
+  uint32_t dim_k;
   uint64_t addr_a;
+  uint64_t addr_b;
   uint64_t addr_c;
 } kernel_arg_t;
 
diff --git a/tests/regression/sgemm_wg/kernel.cpp b/tests/regression/sgemm_wg/kernel.cpp
index ee09cf99..368d9270 100644
--- a/tests/regression/sgemm_wg/kernel.cpp
+++ b/tests/regression/sgemm_wg/kernel.cpp
@@ -5,34 +5,49 @@
 
 void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
   const float *global_a = (const float *)arg->addr_a;
+  const float *global_b = (const float *)arg->addr_b;
   float *global_c = (float *)arg->addr_c;
 
   // assumes NT == NW == matrix_dim
-  const uint32_t dim = arg->matrix_dim;
-  const uint32_t row = vx_warp_id();
-  const uint32_t col = vx_thread_id();
+  const uint32_t dim_m = arg->dim_m;
+  const uint32_t dim_n = arg->dim_n;
+  const uint32_t dim_k = arg->dim_k;
+  const uint32_t block_dim = vx_num_warps();
+  const uint32_t local_row = vx_warp_id();
+  const uint32_t local_col = vx_thread_id();
 
-  float *local_c = (float *)DEV_SMEM_START_ADDR;
-  float *local_a = (float *)DEV_SMEM_START_ADDR + (dim * dim);
-  float *local_b = (float *)DEV_SMEM_START_ADDR + 2 * (dim * dim);
+  // each thread generates one output element
+  float reg_c = 0.0f;
 
-  local_a[dim * row + col] = global_a[dim * row + col];
-  local_c[dim * row + col] = 0.0f;
+  for (uint32_t k = 0; k < dim_k; k += block_dim) {
+    float *local_a = (float *)DEV_SMEM_START_ADDR;
+    float *local_b = (float *)DEV_SMEM_START_ADDR + (block_dim * block_dim);
 
-  vx_barrier(0, vx_num_warps());
+    // FIXME: assumes local block size is square shape
+    // TODO: "local_row" should be global_row
+    uint32_t offset_global_a = dim_k * local_row + (k + local_col);
+    uint32_t offset_global_b = dim_n * (local_row + k) + local_col;
+    local_a[block_dim * local_row + local_col] = global_a[offset_global_a];
+    local_b[block_dim * local_row + local_col] = global_b[offset_global_b];
 
-  for (uint32_t k = 0; k < dim; k++) {
-    local_c[dim * row + col] += local_a[dim * row + k] * local_a[dim * k + col];
+    vx_barrier(0, vx_num_warps());
+    vx_fence();
+
+    for (uint32_t local_k = 0; local_k < block_dim; local_k++) {
+      reg_c += local_a[block_dim * local_row + local_k] *
+               local_b[block_dim * local_k + local_col];
+    }
+
+    vx_barrier(0, vx_num_warps());
+    vx_fence();
   }
 
-  vx_barrier(0, vx_num_warps());
-
-  global_c[dim * row + col] = local_c[dim * row + col];
+  global_c[dim_n * local_row + local_col] = reg_c;
 }
 
 int main() {
-	kernel_arg_t* arg = (kernel_arg_t*)KERNEL_ARG_DEV_MEM_ADDR;
-	int threads_per_core = vx_num_warps() * vx_num_threads();
-	vx_spawn_tasks(threads_per_core, (vx_spawn_tasks_cb)kernel_body, arg);
-	return 0;
+  kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
+  int threads_per_core = vx_num_warps() * vx_num_threads();
+  vx_spawn_tasks(threads_per_core, (vx_spawn_tasks_cb)kernel_body, arg);
+  return 0;
 }
diff --git a/tests/regression/sgemm_wg/main.cpp b/tests/regression/sgemm_wg/main.cpp
index d12216e4..a6babcb0 100644
--- a/tests/regression/sgemm_wg/main.cpp
+++ b/tests/regression/sgemm_wg/main.cpp
@@ -21,7 +21,8 @@
 const char* kernel_file = "kernel.bin";
 uint32_t count = 0;
 
-std::vector<float> src_data;
+std::vector<float> src_a_data;
+std::vector<float> src_b_data;
 std::vector<float> ref_data;
 
 vx_device_h device = nullptr;
@@ -58,37 +59,43 @@ static void parse_args(int argc, char **argv) {
 void cleanup() {
   if (device) {
     vx_mem_free(device, kernel_arg.addr_a);
+    vx_mem_free(device, kernel_arg.addr_b);
     vx_mem_free(device, kernel_arg.addr_c);
     vx_dev_close(device);
   }
 }
 
-void generate_source_matrix(uint32_t dim) {
-  src_data.resize(dim * dim);
+void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
+  src_a_data.resize(dim_m * dim_k);
+  src_b_data.resize(dim_k * dim_n);
 
-  for (uint32_t i = 0; i < dim * dim; ++i) {
-    src_data[i] = static_cast<float>(i);
-    std::cout << i << ": value=" << src_data[i] << std::endl;
+  for (uint32_t i = 0; i < src_a_data.size(); ++i) {
+    src_a_data[i] = static_cast<float>(i);
+    std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl;
+  }
+  for (uint32_t i = 0; i < src_b_data.size(); ++i) {
+    src_b_data[i] = static_cast<float>(i);
+    std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl;
   }
 }
 
-void generate_reference_matmul(uint32_t dim) {
-  ref_data.resize(dim * dim);
+void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
+  ref_data.resize(dim_m * dim_n);
 
-  for (uint32_t i = 0; i < dim; ++i) {
-    for (uint32_t j = 0; j < dim; ++j) {
+  for (uint32_t i = 0; i < dim_m; ++i) {
+    for (uint32_t j = 0; j < dim_n; ++j) {
       float ref = 0.0f;
-      for (uint32_t k = 0; k < dim; ++k) {
-        ref += src_data[dim * i + k] * src_data[dim * k + j];
+      for (uint32_t k = 0; k < dim_k; ++k) {
+        ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j];
       }
-      ref_data.at(dim * i + j) = ref;
+      ref_data.at(dim_n * i + j) = ref;
     }
   }
 }
 
 int run_test(const kernel_arg_t& kernel_arg,
              uint32_t buf_size,
-             uint32_t dim) {
+             uint32_t dim_m, uint32_t dim_n) {
   // start device
   std::cout << "start device" << std::endl;
   RT_CHECK(vx_start(device));
@@ -106,7 +113,7 @@ int run_test(const kernel_arg_t& kernel_arg,
   {
     int errors = 0;
     auto buf_ptr = (float*)staging_buf.data();
-    for (uint32_t i = 0; i < dim * dim; ++i) {
+    for (uint32_t i = 0; i < dim_m * dim_n; ++i) {
       float ref = ref_data.at(i);
       float cur = buf_ptr[i];
       if (cur != ref) {
@@ -139,16 +146,17 @@ int main(int argc, char *argv[]) {
   std::cout << "open device connection" << std::endl;
   RT_CHECK(vx_dev_open(&device));
 
-  uint32_t matrix_size = count;
-  uint32_t matrix_dim = 4; // FIXME: hardcoded
+  uint32_t dim_m = 4; // FIXME: hardcoded
+  uint32_t dim_n = 4; // FIXME: hardcoded
+  uint32_t dim_k = 128; // FIXME: hardcoded
 
-  generate_source_matrix(matrix_dim);
-  generate_reference_matmul(matrix_dim);
+  generate_source_matrix(dim_m, dim_n, dim_k);
+  generate_reference_matmul(dim_m, dim_n, dim_k);
 
-  uint32_t src_buf_size = src_data.size() * sizeof(src_data[0]);
-  uint32_t dst_buf_size = ref_data.size() * sizeof(src_data[0]);
+  uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]);
+  uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]);
+  uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]);
 
-  std::cout << "number of elements: " << matrix_size << std::endl;
   std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl;
 
   // upload program
@@ -157,20 +165,26 @@ int main(int argc, char *argv[]) {
 
   // allocate device memory
   std::cout << "allocate device memory" << std::endl;
-  RT_CHECK(vx_mem_alloc(device, src_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a));
+  RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a));
+  RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b));
   RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c));
 
-  kernel_arg.matrix_dim = matrix_dim;
+  kernel_arg.dim_m = dim_m;
+  kernel_arg.dim_n = dim_n;
+  kernel_arg.dim_k = dim_k;
 
-  std::cout << "dev_src=0x" << std::hex << kernel_arg.addr_a << std::endl;
-  std::cout << "dev_dst=0x" << std::hex << kernel_arg.addr_c << std::endl;
+  std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl;
+  std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl;
+  std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl;
 
   // allocate staging buffer
   {
     std::cout << "allocate staging buffer" << std::endl;
-    uint32_t staging_buf_size = std::max<uint32_t>(src_buf_size,
-                                  std::max<uint32_t>(dst_buf_size,
-                                    sizeof(kernel_arg_t)));
+    uint32_t staging_buf_size = std::max<uint32_t>(
+        src_a_buf_size,
+        std::max<uint32_t>(
+            src_b_buf_size,
+            std::max<uint32_t>(dst_buf_size, sizeof(kernel_arg_t))));
     staging_buf.resize(staging_buf_size);
   }
 
@@ -196,28 +210,47 @@ int main(int argc, char *argv[]) {
 
   // upload source buffer
   {
-    std::cout << "upload source buffer" << std::endl;
-    auto buf_ptr = staging_buf.data();
-    memcpy(buf_ptr, src_data.data(), matrix_size * sizeof(float));
-    RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(), src_buf_size));
+    {
+      auto buf_ptr = staging_buf.data();
+      memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float));
+      RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(),
+                              src_a_buf_size));
 
-    std::cout << "uploading source buffer to device, device mem address="
-              << std::hex << kernel_arg.addr_a << ", size=" << std::dec
-              << src_buf_size << " bytes\n";
-    std::ofstream file("input.bin", std::ios::binary | std::ios::out);
-    if (!file) {
+      std::cout << "uploading source A matrix to device, device mem address="
+                << std::hex << kernel_arg.addr_a << ", size=" << std::dec
+                << src_a_buf_size << " bytes\n";
+      std::ofstream file("input.a.bin", std::ios::binary | std::ios::out);
+      if (!file) {
         std::cerr << "error: failed to open args.bin for writing\n";
         exit(EXIT_FAILURE);
+      }
+      file.write(reinterpret_cast<char*>(buf_ptr), src_a_buf_size);
+      file.close();
+    }
+    {
+      auto buf_ptr = staging_buf.data();
+      memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float));
+      RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(),
+                              src_b_buf_size));
+
+      std::cout << "uploading source B matrix to device, device mem address="
+                << std::hex << kernel_arg.addr_b << ", size=" << std::dec
+                << src_b_buf_size << " bytes\n";
+      std::ofstream file("input.b.bin", std::ios::binary | std::ios::out);
+      if (!file) {
+        std::cerr << "error: failed to open input.b.bin for writing\n";
+        exit(EXIT_FAILURE);
+      }
+      file.write(reinterpret_cast<char*>(buf_ptr), src_b_buf_size);
+      file.close();
     }
-    file.write(reinterpret_cast<char*>(buf_ptr), src_buf_size);
-    file.close();
   }
 
   // clear destination buffer
   {
     std::cout << "clear destination buffer" << std::endl;
     auto buf_ptr = (int32_t*)staging_buf.data();
-    for (uint32_t i = 0; i < matrix_size; ++i) {
+    for (uint32_t i = 0; i < ref_data.size(); ++i) {
       buf_ptr[i] = 0xdeadbeef;
     }
     RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size));
@@ -225,13 +258,12 @@ int main(int argc, char *argv[]) {
 
   // run tests
   std::cout << "run tests" << std::endl;
-  RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.matrix_dim));
+  RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n));
+  std::cout << "PASSED!" << std::endl;
 
   // cleanup
   std::cout << "cleanup" << std::endl;
   cleanup();
 
-  std::cout << "PASSED!" << std::endl;
-
   return 0;
 }