runtime instrinsics refactoring using RISC-V custom instruction assmebly directives

This commit is contained in:
Blaise Tine
2021-02-04 15:15:20 -05:00
parent a9f82bceae
commit b047f589d6
44 changed files with 90586 additions and 90486 deletions

View File

@@ -1,99 +0,0 @@
#include <VX_config.h>
.section .text
.type vx_wspawn, @function
.global vx_wspawn
vx_wspawn:
.word 0x00b5106b # wspawn a0(num_warps), a1(func_ptr)
ret
.type vx_tmc, @function
.global vx_tmc
vx_tmc:
.word 0x0005006b # tmc a0
ret
.type vx_barrier, @function
.global vx_barrier
vx_barrier:
.word 0x00b5406b # barrier a0(barrier_id), a1(num_warps)
ret
.type vx_split, @function
.global vx_split
vx_split:
.word 0x0005206b # split a0
ret
.type vx_join, @function
.global vx_join
vx_join:
.word 0x0000306b #join
ret
.type vx_warp_id, @function
.global vx_warp_id
vx_warp_id:
csrr a0, CSR_LWID
ret
.type vx_warp_gid, @function
.global vx_warp_gid
vx_warp_gid:
csrr a0, CSR_GWID
ret
.type vx_thread_id, @function
.global vx_thread_id
vx_thread_id:
csrr a0, CSR_WTID
ret
.type vx_thread_lid, @function
.global vx_thread_lid
vx_thread_lid:
csrr a0, CSR_LTID
ret
.type vx_thread_gid, @function
.global vx_thread_gid
vx_thread_gid:
csrr a0, CSR_GTID
ret
.type vx_core_id, @function
.global vx_core_id
vx_core_id:
csrr a0, CSR_GCID
ret
.type vx_num_threads, @function
.global vx_num_threads
vx_num_threads:
csrr a0, CSR_NT
ret
.type vx_num_warps, @function
.global vx_num_warps
vx_num_warps:
csrr a0, CSR_NW
ret
.type vx_num_cores, @function
.global vx_num_cores
vx_num_cores:
csrr a0, CSR_NC
ret
.type vx_num_cycles, @function
.global vx_num_cycles
vx_num_cycles:
csrr a0, CSR_CYCLE
ret
.type vx_num_instrs, @function
.global vx_num_instrs
vx_num_instrs:
csrr a0, CSR_INSTRET
ret

View File

@@ -12,13 +12,34 @@ extern "C" {
typedef struct {
pfn_callback callback;
void * args;
const void * args;
int offset;
int N;
int R;
} wspawn_args_t;
} wspawn_tasks_args_t;
wspawn_args_t* g_wspawn_args[NUM_CORES_MAX];
typedef struct {
struct context_t * ctx;
pfn_workgroup_func wg_func;
const void * args;
int offset;
int N;
int R;
char isXYpow2;
char isXpow2;
char log2XY;
char log2X;
} wspawn_kernel_args_t;
void* g_wspawn_args[NUM_CORES_MAX];
inline char is_log2(int x) {
return ((x & (x-1)) == 0);
}
inline int fast_log2(int x) {
return (*(int*)(&x)>>23) - 127;
}
void spawn_tasks_callback() {
vx_tmc(vx_num_threads());
@@ -28,7 +49,7 @@ void spawn_tasks_callback() {
int tid = vx_thread_id();
int NT = vx_num_threads();
wspawn_args_t* p_wspawn_args = g_wspawn_args[core_id];
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[core_id];
int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid);
int tK = p_wspawn_args->N + (wid < p_wspawn_args->R);
@@ -47,7 +68,7 @@ void spawn_remaining_tasks_callback(int nthreads) {
int core_id = vx_core_id();
int tid = vx_thread_gid();
wspawn_args_t* p_wspawn_args = g_wspawn_args[core_id];
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[core_id];
int task_id = p_wspawn_args->offset + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->args);
@@ -55,7 +76,7 @@ void spawn_remaining_tasks_callback(int nthreads) {
vx_tmc(1);
}
void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
void vx_spawn_tasks(int num_tasks, pfn_callback callback , const void * args) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
@@ -90,7 +111,7 @@ void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
fW = 1;
//--
wspawn_args_t wspawn_args = { callback, args, core_id * tasks_per_core, fW, rW };
wspawn_tasks_args_t wspawn_args = { callback, args, core_id * tasks_per_core, fW, rW };
g_wspawn_args[core_id] = &wspawn_args;
//--
@@ -107,6 +128,135 @@ void vx_spawn_tasks(int num_tasks, pfn_callback callback , void * args) {
}
}
///////////////////////////////////////////////////////////////////////////////
void spawn_kernel_callback() {
vx_tmc(vx_num_threads());
int core_id = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
int NT = vx_num_threads();
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];
int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid);
int tK = p_wspawn_args->N + (wid < p_wspawn_args->R);
int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK);
int X = p_wspawn_args->ctx->num_groups[0];
int Y = p_wspawn_args->ctx->num_groups[1];
int XY = X * Y;
for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) {
int k = p_wspawn_args->isXYpow2 ? (wg_id / XY) : (wg_id >> p_wspawn_args->log2XY);
int wg_2d = wg_id - k * XY;
int j = p_wspawn_args->isXpow2 ? (wg_2d / X) : (wg_2d >> p_wspawn_args->log2X);
int i = wg_2d - j * X;
int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
int gid2 = p_wspawn_args->ctx->global_offset[2] + k;
(p_wspawn_args->wg_func)(p_wspawn_args->args, p_wspawn_args->ctx, gid0, gid1, gid2);
}
vx_tmc(0 == wid);
}
void spawn_kernel_remaining_callback(int nthreads) {
vx_tmc(nthreads);
int core_id = vx_core_id();
int tid = vx_thread_gid();
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];
int wg_id = p_wspawn_args->offset + tid;
int X = p_wspawn_args->ctx->num_groups[0];
int Y = p_wspawn_args->ctx->num_groups[1];
int XY = X * Y;
int k = p_wspawn_args->isXYpow2 ? (wg_id / XY) : (wg_id >> p_wspawn_args->log2XY);
int wg_2d = wg_id - k * XY;
int j = p_wspawn_args->isXpow2 ? (wg_2d / X) : (wg_2d >> p_wspawn_args->log2X);
int i = wg_2d - j * X;
int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
int gid2 = p_wspawn_args->ctx->global_offset[2] + k;
(p_wspawn_args->wg_func)(p_wspawn_args->args, p_wspawn_args->ctx, gid0, gid1, gid2);
vx_tmc(1);
}
void vx_spawn_kernel(struct context_t * ctx, pfn_workgroup_func wg_func, const void * args) {
// total number of WGs
int X = ctx->num_groups[0];
int Y = ctx->num_groups[1];
int Z = ctx->num_groups[2];
int XY = X * Y;
int Q = XY * Z;
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NT = vx_num_threads();
// current core id
int core_id = vx_core_id();
if (core_id >= NUM_CORES_MAX)
return;
// calculate necessary active cores
int WT = NW * NT;
int nC = (Q > WT) ? (Q / WT) : 1;
int nc = MIN(nC, NC);
if (core_id >= nc)
return; // terminate extra cores
// number of workgroups per core
int wgs_per_core = Q / nc;
int wgs_per_core0 = wgs_per_core;
if (core_id == (NC-1)) {
int QC_r = Q - (nc * wgs_per_core0);
wgs_per_core0 += QC_r; // last core executes remaining WGs
}
// number of workgroups per warp
int nW = wgs_per_core0 / NT; // total warps per core
int rT = wgs_per_core0 - (nW * NT); // remaining threads
int fW = (nW >= NW) ? (nW / NW) : 0; // full warps iterations
int rW = (fW != 0) ? (nW - fW * NW) : 0; // reamining full warps
if (0 == fW)
fW = 1;
// fast path handling
char isXYpow2 = is_log2(XY);
char isXpow2 = is_log2(X);
char log2XY = fast_log2(XY);
char log2X = fast_log2(X);
//--
wspawn_kernel_args_t wspawn_args = { ctx, wg_func, args, core_id * wgs_per_core, fW, rW, isXYpow2, isXpow2, log2XY, log2X };
g_wspawn_args[core_id] = &wspawn_args;
//--
if (nW >= 1) {
int nw = MIN(nW, NW);
vx_wspawn(nw, (unsigned)&spawn_kernel_callback);
spawn_kernel_callback();
}
//--
if (rT != 0) {
wspawn_args.offset = wgs_per_core0 - rT;
spawn_kernel_remaining_callback(rT);
}
}
#ifdef __cplusplus
}
#endif

View File

@@ -8,12 +8,12 @@ _start:
# execute stack initialization on all warps
la a1, vx_set_sp
csrr a0, CSR_NW # get num warps
.word 0x00b5106b # wspawn a0, a1
.insn s 0x6b, 1, a1, 0(a0) # wspawn a0, a1
jal vx_set_sp
# return back to single thread execution
li a0, 1
.word 0x0005006b # tmc a0
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
# Clear the bss segment
la a0, _edata
@@ -44,15 +44,15 @@ _start:
_exit:
# disable all threads in current warp
li a0, 0
.word 0x0005006b # tmc a0
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
.section .text
.type vx_set_sp, @function
.global vx_set_sp
vx_set_sp:
# activate all threads
csrr a0, CSR_NT # get num threads
.word 0x0005006b # set thread mask
csrr a0, CSR_NT # get num threads
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
# set global pointer register
.option push
@@ -76,7 +76,7 @@ vx_set_sp:
csrr a3, CSR_LWID # get local wid
beqz a3, RETURN
li a0, 0
.word 0x0005006b # tmc a0
.insn s 0x6b, 0, x0, 0(a0) # tmc a0
RETURN:
ret