sgemm_wg: Implement software barrier for inter-core synchronization

This commit is contained in:
Hansung Kim
2024-03-12 15:34:42 -07:00
parent fbe872c831
commit 510a834db5
2 changed files with 41 additions and 10 deletions

View File

@@ -1,5 +1,6 @@
#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_print.h>
#include <vx_spawn.h>
#include "common.h"
@@ -8,8 +9,35 @@
#define BK 2
// #define TM (BM/BK)
// #define TN (BN/BK)
#define TM 4
#define TN 4
#define TM 2
#define TN 2
#define DEV_BARRIER_MMIO_BASE_ADDR 0xff003f00UL
#define CORES_PER_CLUSTER 4
void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
vx_barrier(barrier_id, count);
vx_fence();
#if CORES_PER_CLUSTER != 1
if (vx_thread_id() == 0) {
volatile uint32_t *mmio = (volatile uint32_t *)(DEV_BARRIER_MMIO_BASE_ADDR);
int core_id = vx_core_id();
const uint32_t barrier_stride = CORES_PER_CLUSTER;
const uint32_t barrier_offset = barrier_stride * barrier_id;
// 1 : 0x00 is reserved for mmio read reg
mmio[barrier_offset + 1 + core_id] = 1;
vx_printf("========== barrier written! barrier_id=%u, count=%u\n", barrier_id, count);
// wait for other cores in the cluster to finish by waiting on the
// all-synced read-only mmio reg
while (mmio[barrier_offset] == 0);
// reset per-core flag back to zero for the next barrier
mmio[barrier_offset + 1 + core_id] = 0;
}
#endif
}
void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t tid_in_threadblock,
@@ -73,8 +101,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
B[global_b_offset];
}
vx_barrier(threadblock_id_in_core, threadblock_dim_y);
vx_fence();
threadblock_barrier(threadblock_id_in_core, threadblock_dim_y);
for (uint32_t local_k = 0; local_k < BK; local_k++) {
#pragma GCC unroll TM
@@ -103,8 +130,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
}
}
vx_barrier(threadblock_id_in_core, threadblock_dim_y);
vx_fence();
threadblock_barrier(threadblock_id_in_core, threadblock_dim_y);
}
#pragma GCC unroll TM
@@ -123,7 +149,7 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
// @perf: All threads are running these compute whose result is mostly same
// across the threadblock
const uint32_t threads_per_threadblock = ((BM * BN) / (TM * TN));
const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN);
const uint32_t threadblocks_per_core =
vx_num_threads() * vx_num_warps() / threads_per_threadblock;
const uint32_t threadblock_dim_x = vx_num_threads();
@@ -138,6 +164,11 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
const int threadblock_id_x = threadblock_id % dim_n_in_blocks;
const int threadblock_id_y = threadblock_id / dim_n_in_blocks;
// initialize barrier MMIO
volatile uint32_t *barrier_mmio = (volatile uint32_t *)(DEV_BARRIER_MMIO_BASE_ADDR);
*barrier_mmio = 0;
vx_fence();
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR +
(2 * BM * BK) * threadblock_id_in_core;

View File

@@ -147,9 +147,9 @@ int main(int argc, char *argv[]) {
RT_CHECK(vx_dev_open(&device));
// FIXME: hardcoded
uint32_t dim_m = 64;
uint32_t dim_n = 64;
uint32_t dim_k = 64;
uint32_t dim_m = 32;
uint32_t dim_n = 32;
uint32_t dim_k = 32;
generate_source_matrix(dim_m, dim_n, dim_k);
generate_reference_matmul(dim_m, dim_n, dim_k);