// Listing metadata: C++ source, 122 lines, 3.6 KiB.
#include "fp8_common.hpp"
#include "../wu_arch_cases/common_wu_min.h"
// ---- Addresses and completion tag -------------------------------------------

// Start address of the region wu_main() fills with copies of g_fp8_b_frag.
// The worker also loads this value into the rs2 register of its custom-3
// funct3=0 instruction.
#define DEV_SMEM_START_ADDR 0xff000000u

// Tag the worker ORs with its warp id before storing the result into g_seen;
// wu_main() passes the same tag to wu_wait_seen_mask().
#define FP8_VALIDATION_BASE 0x7800u

// ---- Problem shape ----------------------------------------------------------

// FP8_M * FP8_N gives the number of 32-bit result words (FP8_OUT_WORDS).
// FP8_K is not referenced anywhere else in the code shown here.
#define FP8_M 16u
#define FP8_N 16u
#define FP8_K 32u

// ---- Buffer geometry --------------------------------------------------------

// Byte span walked by both worker loops and filled by wu_main() at
// DEV_SMEM_START_ADDR.
#define FP8_TILE_BYTES 1024u

// Step, in bytes, of those loops; also the byte size of each g_fp8_*_frag array.
#define FP8_FRAGMENT_BYTES 16u

// 32-bit words per fragment (evaluates to 4; has type size_t because of sizeof).
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))

// Fragments per tile (evaluates to 64).
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)

// 32-bit words in g_fp8_out (evaluates to 256, i.e. FP8_TILE_BYTES bytes).
#define FP8_OUT_WORDS (FP8_M * FP8_N)

// ---- Expected result --------------------------------------------------------

// Value every word of g_fp8_out must hold for the case to pass. 0x42820000 is
// the IEEE-754 binary32 encoding of 65.0f.
// NOTE(review): presumably FP8_K * (1.0 * 2.0) + 1.0 given the names of the
// fragment initialisers -- confirm against the FP8 encoding macros.
#define FP8_EXPECTED 0x42820000u
extern "C" {
|
|
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
|
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
|
|
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
|
|
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
|
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
|
|
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
|
|
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
|
WU_FP8_REP4(0x3f800000u)};
|
|
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(16)));
|
|
}
|
|
|
|
#undef WU_FP8_REP2
|
|
#undef WU_FP8_REP4
|
|
#undef WU_FP8_REP8
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
|
|
asm volatile(
|
|
"li x1, 0\n\t"
|
|
"li x2, %[tile_bytes]\n\t"
|
|
"la x6, g_fp8_a_frag\n\t"
|
|
"la x3, g_fp8_c_frag\n\t"
|
|
"li x7, 0\n\t"
|
|
"1:\n\t"
|
|
"add x4, x1, x7\n\t"
|
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
|
"add x4, x2, x7\n\t"
|
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
|
"addi x7, x7, %[frag_bytes]\n\t"
|
|
"li x4, %[tile_bytes]\n\t"
|
|
"blt x7, x4, 1b\n\t"
|
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
|
"li x4, %[smem_base]\n\t"
|
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
|
"la x3, g_fp8_out\n\t"
|
|
"li x7, 0\n\t"
|
|
"2:\n\t"
|
|
"add x4, x2, x7\n\t"
|
|
"add x1, x3, x7\n\t"
|
|
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
|
"addi x7, x7, %[frag_bytes]\n\t"
|
|
"li x4, %[tile_bytes]\n\t"
|
|
"blt x7, x4, 2b\n\t"
|
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
|
"csrr x5, %[csr_wid]\n\t"
|
|
"slli x6, x5, 2\n\t"
|
|
"la x7, g_seen\n\t"
|
|
"add x7, x7, x6\n\t"
|
|
"li x6, %[done_base]\n\t"
|
|
"or x6, x6, x5\n\t"
|
|
"sw x6, 0(x7)\n\t"
|
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
|
"3: j 3b\n\t"
|
|
:
|
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
|
[custom0] "i"(RISCV_CUSTOM0),
|
|
[custom3] "i"(RISCV_CUSTOM3),
|
|
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
|
[done_base] "i"(FP8_VALIDATION_BASE),
|
|
[tile_bytes] "i"(FP8_TILE_BYTES),
|
|
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
|
|
: "memory");
|
|
}
extern "C" int wu_main() {
|
|
if (!wu_is_leader()) {
|
|
return 0;
|
|
}
|
|
|
|
wu_case_reset();
|
|
|
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
|
g_fp8_out[i] = 0;
|
|
}
|
|
|
|
volatile uint32_t *smem_b =
|
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
|
for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
|
|
const uint32_t row = frag * FP8_FRAGMENT_WORDS;
|
|
for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
|
|
smem_b[row + i] = g_fp8_b_frag[i];
|
|
}
|
|
}
|
|
|
|
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
|
vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
|
|
|
|
if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
|
|
wu_case_fail(0x09u);
|
|
return 1;
|
|
}
|
|
|
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
|
if (g_fp8_out[i] != FP8_EXPECTED) {
|
|
g_aux[0] = i;
|
|
g_aux[1] = g_fp8_out[i];
|
|
wu_case_fail(0x20u);
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
wu_case_pass();
|
|
return 0;
|
|
}