runtime instrinsics refactoring using RISC-V custom instruction assmebly directives
This commit is contained in:
@@ -1,99 +0,0 @@
|
||||
#include <VX_config.h>
|
||||
|
||||
.section .text
|
||||
|
||||
.type vx_wspawn, @function
|
||||
.global vx_wspawn
|
||||
vx_wspawn:
|
||||
.word 0x00b5106b # wspawn a0(num_warps), a1(func_ptr)
|
||||
ret
|
||||
|
||||
.type vx_tmc, @function
|
||||
.global vx_tmc
|
||||
vx_tmc:
|
||||
.word 0x0005006b # tmc a0
|
||||
ret
|
||||
|
||||
.type vx_barrier, @function
|
||||
.global vx_barrier
|
||||
vx_barrier:
|
||||
.word 0x00b5406b # barrier a0(barrier_id), a1(num_warps)
|
||||
ret
|
||||
|
||||
.type vx_split, @function
|
||||
.global vx_split
|
||||
vx_split:
|
||||
.word 0x0005206b # split a0
|
||||
ret
|
||||
|
||||
.type vx_join, @function
|
||||
.global vx_join
|
||||
vx_join:
|
||||
.word 0x0000306b #join
|
||||
ret
|
||||
|
||||
.type vx_warp_id, @function
|
||||
.global vx_warp_id
|
||||
vx_warp_id:
|
||||
csrr a0, CSR_LWID
|
||||
ret
|
||||
|
||||
.type vx_warp_gid, @function
|
||||
.global vx_warp_gid
|
||||
vx_warp_gid:
|
||||
csrr a0, CSR_GWID
|
||||
ret
|
||||
|
||||
.type vx_thread_id, @function
|
||||
.global vx_thread_id
|
||||
vx_thread_id:
|
||||
csrr a0, CSR_WTID
|
||||
ret
|
||||
|
||||
.type vx_thread_lid, @function
|
||||
.global vx_thread_lid
|
||||
vx_thread_lid:
|
||||
csrr a0, CSR_LTID
|
||||
ret
|
||||
|
||||
.type vx_thread_gid, @function
|
||||
.global vx_thread_gid
|
||||
vx_thread_gid:
|
||||
csrr a0, CSR_GTID
|
||||
ret
|
||||
|
||||
.type vx_core_id, @function
|
||||
.global vx_core_id
|
||||
vx_core_id:
|
||||
csrr a0, CSR_GCID
|
||||
ret
|
||||
|
||||
.type vx_num_threads, @function
|
||||
.global vx_num_threads
|
||||
vx_num_threads:
|
||||
csrr a0, CSR_NT
|
||||
ret
|
||||
|
||||
.type vx_num_warps, @function
|
||||
.global vx_num_warps
|
||||
vx_num_warps:
|
||||
csrr a0, CSR_NW
|
||||
ret
|
||||
|
||||
.type vx_num_cores, @function
|
||||
.global vx_num_cores
|
||||
vx_num_cores:
|
||||
csrr a0, CSR_NC
|
||||
ret
|
||||
|
||||
.type vx_num_cycles, @function
|
||||
.global vx_num_cycles
|
||||
vx_num_cycles:
|
||||
csrr a0, CSR_CYCLE
|
||||
ret
|
||||
|
||||
.type vx_num_instrs, @function
|
||||
.global vx_num_instrs
|
||||
vx_num_instrs:
|
||||
csrr a0, CSR_INSTRET
|
||||
ret
|
||||
@@ -12,13 +12,34 @@ extern "C" {
|
||||
|
||||
typedef struct {
|
||||
pfn_callback callback;
|
||||
void * args;
|
||||
const void * args;
|
||||
int offset;
|
||||
int N;
|
||||
int R;
|
||||
} wspawn_args_t;
|
||||
} wspawn_tasks_args_t;
|
||||
|
||||
wspawn_args_t* g_wspawn_args[NUM_CORES_MAX];
|
||||
typedef struct {
|
||||
struct context_t * ctx;
|
||||
pfn_workgroup_func wg_func;
|
||||
const void * args;
|
||||
int offset;
|
||||
int N;
|
||||
int R;
|
||||
char isXYpow2;
|
||||
char isXpow2;
|
||||
char log2XY;
|
||||
char log2X;
|
||||
} wspawn_kernel_args_t;
|
||||
|
||||
void* g_wspawn_args[NUM_CORES_MAX];
|
||||
|
||||
inline char is_log2(int x) {
|
||||
return ((x & (x-1)) == 0);
|
||||
}
|
||||
|
||||
inline int fast_log2(int x) {
|
||||
return (*(int*)(&x)>>23) - 127;
|
||||
}
|
||||
|
||||
void spawn_tasks_callback() {
|
||||
vx_tmc(vx_num_threads());
|
||||
@@ -28,7 +49,7 @@ void spawn_tasks_callback() {
|
||||
int tid = vx_thread_id();
|
||||
int NT = vx_num_threads();
|
||||
|
||||
wspawn_args_t* p_wspawn_args = g_wspawn_args[core_id];
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[core_id];
|
||||
|
||||
int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid);
|
||||
int tK = p_wspawn_args->N + (wid < p_wspawn_args->R);
|
||||
@@ -47,7 +68,7 @@ void spawn_remaining_tasks_callback(int nthreads) {
|
||||
int core_id = vx_core_id();
|
||||
int tid = vx_thread_gid();
|
||||
|
||||
wspawn_args_t* p_wspawn_args = g_wspawn_args[core_id];
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[core_id];
|
||||
|
||||
int task_id = p_wspawn_args->offset + tid;
|
||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->args);
|
||||
@@ -55,7 +76,7 @@ void spawn_remaining_tasks_callback(int nthreads) {
|
||||
vx_tmc(1);
|
||||
}
|
||||
|
||||
void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
|
||||
void vx_spawn_tasks(int num_tasks, pfn_callback callback , const void * args) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
@@ -90,7 +111,7 @@ void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
|
||||
fW = 1;
|
||||
|
||||
//--
|
||||
wspawn_args_t wspawn_args = { callback, args, core_id * tasks_per_core, fW, rW };
|
||||
wspawn_tasks_args_t wspawn_args = { callback, args, core_id * tasks_per_core, fW, rW };
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
|
||||
//--
|
||||
@@ -107,6 +128,135 @@ void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void spawn_kernel_callback() {
|
||||
vx_tmc(vx_num_threads());
|
||||
|
||||
int core_id = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
int NT = vx_num_threads();
|
||||
|
||||
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];
|
||||
|
||||
int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid);
|
||||
int tK = p_wspawn_args->N + (wid < p_wspawn_args->R);
|
||||
int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK);
|
||||
|
||||
int X = p_wspawn_args->ctx->num_groups[0];
|
||||
int Y = p_wspawn_args->ctx->num_groups[1];
|
||||
int XY = X * Y;
|
||||
|
||||
for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) {
|
||||
int k = p_wspawn_args->isXYpow2 ? (wg_id / XY) : (wg_id >> p_wspawn_args->log2XY);
|
||||
int wg_2d = wg_id - k * XY;
|
||||
int j = p_wspawn_args->isXpow2 ? (wg_2d / X) : (wg_2d >> p_wspawn_args->log2X);
|
||||
int i = wg_2d - j * X;
|
||||
|
||||
int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
|
||||
int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
|
||||
int gid2 = p_wspawn_args->ctx->global_offset[2] + k;
|
||||
|
||||
(p_wspawn_args->wg_func)(p_wspawn_args->args, p_wspawn_args->ctx, gid0, gid1, gid2);
|
||||
}
|
||||
|
||||
vx_tmc(0 == wid);
|
||||
}
|
||||
|
||||
void spawn_kernel_remaining_callback(int nthreads) {
|
||||
vx_tmc(nthreads);
|
||||
|
||||
int core_id = vx_core_id();
|
||||
int tid = vx_thread_gid();
|
||||
|
||||
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];
|
||||
|
||||
int wg_id = p_wspawn_args->offset + tid;
|
||||
|
||||
int X = p_wspawn_args->ctx->num_groups[0];
|
||||
int Y = p_wspawn_args->ctx->num_groups[1];
|
||||
int XY = X * Y;
|
||||
|
||||
int k = p_wspawn_args->isXYpow2 ? (wg_id / XY) : (wg_id >> p_wspawn_args->log2XY);
|
||||
int wg_2d = wg_id - k * XY;
|
||||
int j = p_wspawn_args->isXpow2 ? (wg_2d / X) : (wg_2d >> p_wspawn_args->log2X);
|
||||
int i = wg_2d - j * X;
|
||||
|
||||
int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
|
||||
int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
|
||||
int gid2 = p_wspawn_args->ctx->global_offset[2] + k;
|
||||
|
||||
(p_wspawn_args->wg_func)(p_wspawn_args->args, p_wspawn_args->ctx, gid0, gid1, gid2);
|
||||
|
||||
vx_tmc(1);
|
||||
}
|
||||
|
||||
void vx_spawn_kernel(struct context_t * ctx, pfn_workgroup_func wg_func, const void * args) {
|
||||
// total number of WGs
|
||||
int X = ctx->num_groups[0];
|
||||
int Y = ctx->num_groups[1];
|
||||
int Z = ctx->num_groups[2];
|
||||
int XY = X * Y;
|
||||
int Q = XY * Z;
|
||||
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
int core_id = vx_core_id();
|
||||
if (core_id >= NUM_CORES_MAX)
|
||||
return;
|
||||
|
||||
// calculate necessary active cores
|
||||
int WT = NW * NT;
|
||||
int nC = (Q > WT) ? (Q / WT) : 1;
|
||||
int nc = MIN(nC, NC);
|
||||
if (core_id >= nc)
|
||||
return; // terminate extra cores
|
||||
|
||||
// number of workgroups per core
|
||||
int wgs_per_core = Q / nc;
|
||||
int wgs_per_core0 = wgs_per_core;
|
||||
if (core_id == (NC-1)) {
|
||||
int QC_r = Q - (nc * wgs_per_core0);
|
||||
wgs_per_core0 += QC_r; // last core executes remaining WGs
|
||||
}
|
||||
|
||||
// number of workgroups per warp
|
||||
int nW = wgs_per_core0 / NT; // total warps per core
|
||||
int rT = wgs_per_core0 - (nW * NT); // remaining threads
|
||||
int fW = (nW >= NW) ? (nW / NW) : 0; // full warps iterations
|
||||
int rW = (fW != 0) ? (nW - fW * NW) : 0; // reamining full warps
|
||||
if (0 == fW)
|
||||
fW = 1;
|
||||
|
||||
// fast path handling
|
||||
char isXYpow2 = is_log2(XY);
|
||||
char isXpow2 = is_log2(X);
|
||||
char log2XY = fast_log2(XY);
|
||||
char log2X = fast_log2(X);
|
||||
|
||||
//--
|
||||
wspawn_kernel_args_t wspawn_args = { ctx, wg_func, args, core_id * wgs_per_core, fW, rW, isXYpow2, isXpow2, log2XY, log2X };
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
|
||||
//--
|
||||
if (nW >= 1) {
|
||||
int nw = MIN(nW, NW);
|
||||
vx_wspawn(nw, (unsigned)&spawn_kernel_callback);
|
||||
spawn_kernel_callback();
|
||||
}
|
||||
|
||||
//--
|
||||
if (rT != 0) {
|
||||
wspawn_args.offset = wgs_per_core0 - rT;
|
||||
spawn_kernel_remaining_callback(rT);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -8,12 +8,12 @@ _start:
|
||||
# execute stack initialization on all warps
|
||||
la a1, vx_set_sp
|
||||
csrr a0, CSR_NW # get num warps
|
||||
.word 0x00b5106b # wspawn a0, a1
|
||||
.insn s 0x6b, 1, a1, 0(a0) # wspawn a0, a1
|
||||
jal vx_set_sp
|
||||
|
||||
# return back to single thread execution
|
||||
li a0, 1
|
||||
.word 0x0005006b # tmc a0
|
||||
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
||||
|
||||
# Clear the bss segment
|
||||
la a0, _edata
|
||||
@@ -44,15 +44,15 @@ _start:
|
||||
_exit:
|
||||
# disable all threads in current warp
|
||||
li a0, 0
|
||||
.word 0x0005006b # tmc a0
|
||||
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
||||
|
||||
.section .text
|
||||
.type vx_set_sp, @function
|
||||
.global vx_set_sp
|
||||
vx_set_sp:
|
||||
# activate all threads
|
||||
csrr a0, CSR_NT # get num threads
|
||||
.word 0x0005006b # set thread mask
|
||||
csrr a0, CSR_NT # get num threads
|
||||
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
||||
|
||||
# set global pointer register
|
||||
.option push
|
||||
@@ -76,7 +76,7 @@ vx_set_sp:
|
||||
csrr a3, CSR_LWID # get local wid
|
||||
beqz a3, RETURN
|
||||
li a0, 0
|
||||
.word 0x0005006b # tmc a0
|
||||
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
|
||||
RETURN:
|
||||
ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user