Compare commits
22 Commits
dde3602046
...
wu-blackwe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a15e5251e | ||
|
|
3f7ce1f1c9 | ||
|
|
f1aa1303d2 | ||
|
|
d6fbd447c3 | ||
|
|
ed16541c8e | ||
| 122a048ea6 | |||
|
|
9f4be1b8f7 | ||
| e7229dae27 | |||
| 8f7dba5920 | |||
| bcc566b621 | |||
|
|
71f713b9fc | ||
|
|
9847072eff | ||
|
|
f8c51669c1 | ||
|
|
17a9d31be5 | ||
|
|
238b942133 | ||
|
|
2c1ac4e938 | ||
|
|
9cdee597b6 | ||
|
|
6bdc6af607 | ||
|
|
b73147cd06 | ||
|
|
471f89e371 | ||
|
|
7e1fc54c97 | ||
|
|
50c8f1c410 |
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
**/*.o
|
||||||
|
.codex
|
||||||
|
**/*.elf
|
||||||
|
**/*.dump
|
||||||
|
**/*.a
|
||||||
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
PROJECT = blackwell_fp8_e4m3
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
VX_INCLUDES = fp8_common.hpp
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
|
|
||||||
|
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||||
|
cp $< $@
|
||||||
21
kernels/blackwell_fp8_e4m3/README.md
Normal file
21
kernels/blackwell_fp8_e4m3/README.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# blackwell_fp8_e4m3
|
||||||
|
|
||||||
|
Standalone FP8 E4M3 validation kernel for the Wu Blackwell BWGMMA branch.
|
||||||
|
|
||||||
|
This directory is the only kernel area used by the FP8 branch work. Existing
|
||||||
|
FP16 HGEMM, `wu_arch_cases`, and flash kernels are intentionally left unchanged.
|
||||||
|
|
||||||
|
The validation runs one tensor warp on a 16x16x32 tile:
|
||||||
|
|
||||||
|
- A is FP8 E4M3 1.0 (`0x38`)
|
||||||
|
- B is FP8 E4M3 2.0 (`0x40`)
|
||||||
|
- C is FP32 1.0 (`0x3f800000`)
|
||||||
|
- Expected output is FP32 65.0 (`0x42820000`)
|
||||||
|
- `VirgoBlackwellConfig` currently uses 4 core/memory lanes, so one
|
||||||
|
`tcgen05_cp/cb` fragment is 16 bytes.
|
||||||
|
|
||||||
|
Build:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make -C kernels/blackwell_fp8_e4m3  # run from the repository root
|
||||||
|
```
|
||||||
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#ifndef BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||||
|
#define BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
|
||||||
|
#define WU_FP8_E4M3_ZERO 0x00u
|
||||||
|
#define WU_FP8_E4M3_HALF 0x30u
|
||||||
|
#define WU_FP8_E4M3_ONE 0x38u
|
||||||
|
#define WU_FP8_E4M3_TWO 0x40u
|
||||||
|
|
||||||
|
#define WU_FP8_PACK4(a, b, c, d) \
|
||||||
|
((((uint32_t)(a) & 0xffu) << 0) | (((uint32_t)(b) & 0xffu) << 8) | \
|
||||||
|
(((uint32_t)(c) & 0xffu) << 16) | (((uint32_t)(d) & 0xffu) << 24))
|
||||||
|
|
||||||
|
#define WU_FP8_REP2(x) x, x
|
||||||
|
#define WU_FP8_REP4(x) WU_FP8_REP2(x), WU_FP8_REP2(x)
|
||||||
|
#define WU_FP8_REP8(x) WU_FP8_REP4(x), WU_FP8_REP4(x)
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cp_wait() {
|
||||||
|
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 6, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_bwgmma_fp8(uint32_t addr_tmem_c, uint32_t addr_tmem_a,
|
||||||
|
uint32_t addr_smem_b) {
|
||||||
|
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||||
|
"r"(addr_smem_b)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_bwgmma_wait() {
|
||||||
|
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#include "fp8_common.hpp"
|
||||||
|
#include "../wu_arch_cases/common_wu_min.h"
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define FP8_VALIDATION_BASE 0x7800u
|
||||||
|
|
||||||
|
#define FP8_M 16u
|
||||||
|
#define FP8_N 16u
|
||||||
|
#define FP8_K 32u
|
||||||
|
#define FP8_TILE_BYTES 1024u
|
||||||
|
#define FP8_FRAGMENT_BYTES 16u
|
||||||
|
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
|
||||||
|
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
|
||||||
|
#define FP8_OUT_WORDS (FP8_M * FP8_N)
|
||||||
|
#define FP8_EXPECTED 0x42820000u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||||
|
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
|
||||||
|
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
|
||||||
|
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||||
|
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
|
||||||
|
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
|
||||||
|
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||||
|
WU_FP8_REP4(0x3f800000u)};
|
||||||
|
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef WU_FP8_REP2
|
||||||
|
#undef WU_FP8_REP4
|
||||||
|
#undef WU_FP8_REP8
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"li x1, 0\n\t"
|
||||||
|
"li x2, %[tile_bytes]\n\t"
|
||||||
|
"la x6, g_fp8_a_frag\n\t"
|
||||||
|
"la x3, g_fp8_c_frag\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, %[frag_bytes]\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"la x3, g_fp8_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x1, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
||||||
|
"addi x7, x7, %[frag_bytes]\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 2b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"3: j 3b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||||
|
[done_base] "i"(FP8_VALIDATION_BASE),
|
||||||
|
[tile_bytes] "i"(FP8_TILE_BYTES),
|
||||||
|
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||||
|
g_fp8_out[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
|
||||||
|
const uint32_t row = frag * FP8_FRAGMENT_WORDS;
|
||||||
|
for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
|
||||||
|
smem_b[row + i] = g_fp8_b_frag[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||||
|
vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
|
||||||
|
wu_case_fail(0x09u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||||
|
if (g_fp8_out[i] != FP8_EXPECTED) {
|
||||||
|
g_aux[0] = i;
|
||||||
|
g_aux[1] = g_fp8_out[i];
|
||||||
|
wu_case_fail(0x20u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
7
kernels/blackwell_insts/Makefile
Normal file
7
kernels/blackwell_insts/Makefile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
PROJECT = blackwell_insts
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
1
kernels/blackwell_insts/args.bin
Normal file
1
kernels/blackwell_insts/args.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/blackwell_insts/input.a.bin
Normal file
1
kernels/blackwell_insts/input.a.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/blackwell_insts/input.b.bin
Normal file
1
kernels/blackwell_insts/input.b.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/blackwell_insts/input.c.bin
Normal file
1
kernels/blackwell_insts/input.c.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
192
kernels/blackwell_insts/kernel.cpp
Normal file
192
kernels/blackwell_insts/kernel.cpp
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
#include <vx_spawn.h>
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
|
||||||
|
#define BW_REP2(x) x, x
|
||||||
|
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||||
|
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||||
|
#define BW_REP16(x) BW_REP8(x), BW_REP8(x)
|
||||||
|
#define BW_REP32(x) BW_REP16(x), BW_REP16(x)
|
||||||
|
#define BW_REP64(x) BW_REP32(x), BW_REP32(x)
|
||||||
|
#define BW_REP128(x) BW_REP64(x), BW_REP64(x)
|
||||||
|
#define BW_REP256(x) BW_REP128(x), BW_REP128(x)
|
||||||
|
|
||||||
|
static volatile uint32_t g_a[256] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP256(0x3c003c00u)}; // two fp16 1.0 values
|
||||||
|
static volatile uint32_t g_b[256] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP256(0x40004000u)}; // two fp16 2.0 values
|
||||||
|
static volatile uint32_t g_c[256] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP256(0x3f800000u)}; // one fp32 1.0 value
|
||||||
|
static volatile uint32_t g_dst[256] __attribute__((aligned(32)));
|
||||||
|
static volatile uint32_t g_debug[16] __attribute__((aligned(32)));
|
||||||
|
static volatile uint32_t g_status __attribute__((aligned(4)));
|
||||||
|
|
||||||
|
#undef BW_REP2
|
||||||
|
#undef BW_REP4
|
||||||
|
#undef BW_REP8
|
||||||
|
#undef BW_REP16
|
||||||
|
#undef BW_REP32
|
||||||
|
#undef BW_REP64
|
||||||
|
#undef BW_REP128
|
||||||
|
#undef BW_REP256
|
||||||
|
|
||||||
|
struct kernel_arg_t {
|
||||||
|
volatile uint32_t *a;
|
||||||
|
volatile uint32_t *b;
|
||||||
|
volatile uint32_t *c;
|
||||||
|
volatile uint32_t *dst;
|
||||||
|
volatile uint32_t *debug;
|
||||||
|
volatile uint32_t *status;
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 6, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void tcgen05_cp_wait() {
|
||||||
|
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void bwgmma(uint32_t addr_tmem_c,
|
||||||
|
uint32_t addr_tmem_a,
|
||||||
|
uint32_t addr_smem_b) {
|
||||||
|
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||||
|
"r"(addr_smem_b)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void bwgmma_wait() {
|
||||||
|
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline float tcgen05_ld_f32(uint32_t addr_tmem) {
|
||||||
|
float value;
|
||||||
|
asm volatile(".insn r %1, 4, 0, %0, %2, x0"
|
||||||
|
: "=f"(value)
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem)
|
||||||
|
: "memory");
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void tcgen05_st_f32(uint32_t addr_tmem, float value) {
|
||||||
|
asm volatile(".insn r %0, 5, 0, %1, %2, x0"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "f"(value), "r"(addr_tmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the raw 32-bit representation of `value`.
//
// The bytes are copied instead of being read back through a union: reading a
// union member other than the one last written is undefined behaviour in
// C++, while a fixed-size byte copy between two 4-byte objects is well
// defined. The compiler builtin avoids needing an extra header.
static inline uint32_t f32_bits(float value) {
  uint32_t bits = 0;
  __builtin_memcpy(&bits, &value, sizeof bits);
  return bits;
}
|
||||||
|
|
||||||
|
extern "C" void vx_perf_dump() {}
|
||||||
|
|
||||||
|
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg)
|
||||||
|
__attribute__((convergent));
|
||||||
|
|
||||||
|
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||||
|
if (task_id != 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
volatile uint32_t *a = arg->a;
|
||||||
|
volatile uint32_t *b = arg->b;
|
||||||
|
volatile uint32_t *c = arg->c;
|
||||||
|
volatile uint32_t *dst = arg->dst;
|
||||||
|
volatile uint32_t *debug = arg->debug;
|
||||||
|
volatile uint32_t *status = arg->status;
|
||||||
|
|
||||||
|
const uint32_t tmem_a = 0x000;
|
||||||
|
const uint32_t tmem_c = 0x400;
|
||||||
|
const uint32_t tmem_st_scratch = 0x800;
|
||||||
|
volatile uint32_t *smem_b_ptr =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
const uint32_t smem_b = reinterpret_cast<uint32_t>(smem_b_ptr);
|
||||||
|
const uint32_t expected = 0x42820000u; // 65.0f
|
||||||
|
const uint32_t expected_st = 0x3f800000u; // 1.0f
|
||||||
|
|
||||||
|
if (status != nullptr)
|
||||||
|
status[0] = 0x100u;
|
||||||
|
|
||||||
|
for (int i = 0; i < 256; ++i)
|
||||||
|
smem_b_ptr[i] = b[i];
|
||||||
|
|
||||||
|
for (int frag = 0; frag < 32; ++frag) {
|
||||||
|
const uint32_t offset = static_cast<uint32_t>(frag * 32);
|
||||||
|
tcgen05_cp(tmem_a + offset,
|
||||||
|
reinterpret_cast<uint32_t>(&a[frag * 8]));
|
||||||
|
tcgen05_cp(tmem_c + offset,
|
||||||
|
reinterpret_cast<uint32_t>(&c[frag * 8]));
|
||||||
|
}
|
||||||
|
tcgen05_cp_wait();
|
||||||
|
|
||||||
|
const float st_value = 1.0f;
|
||||||
|
tcgen05_st_f32(tmem_st_scratch, st_value);
|
||||||
|
const uint32_t st_bits = f32_bits(tcgen05_ld_f32(tmem_st_scratch));
|
||||||
|
debug[0] = st_bits;
|
||||||
|
if (st_bits != expected_st) {
|
||||||
|
if (status != nullptr)
|
||||||
|
status[0] = 0xe002u;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bwgmma(tmem_c, tmem_a, smem_b);
|
||||||
|
bwgmma_wait();
|
||||||
|
|
||||||
|
const float ld_value = tcgen05_ld_f32(tmem_c);
|
||||||
|
const uint32_t ld_bits = f32_bits(ld_value);
|
||||||
|
debug[1] = ld_bits;
|
||||||
|
tcgen05_st_f32(tmem_st_scratch + 32, ld_value);
|
||||||
|
|
||||||
|
for (int frag = 0; frag < 32; ++frag) {
|
||||||
|
tcgen05_cb(tmem_c + static_cast<uint32_t>(frag * 32),
|
||||||
|
reinterpret_cast<uint32_t>(&dst[frag * 8]));
|
||||||
|
}
|
||||||
|
tcgen05_cp_wait();
|
||||||
|
|
||||||
|
if (ld_bits != expected) {
|
||||||
|
if (status != nullptr)
|
||||||
|
status[0] = 0xe001u;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 256; ++i) {
|
||||||
|
if (dst[i] != expected) {
|
||||||
|
if (status != nullptr)
|
||||||
|
status[0] = 0xe100u | static_cast<uint32_t>(i & 0xff);
|
||||||
|
debug[2] = static_cast<uint32_t>(i);
|
||||||
|
debug[3] = dst[i];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status != nullptr)
|
||||||
|
status[0] = 0x600du;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
kernel_arg_t arg = {g_a, g_b, g_c, g_dst, g_debug, &g_status};
|
||||||
|
|
||||||
|
vx_spawn_tasks_contiguous(1, reinterpret_cast<vx_spawn_tasks_cb>(kernel_body),
|
||||||
|
&arg);
|
||||||
|
return (g_status == 0x600du) ? 0 : 1;
|
||||||
|
}
|
||||||
7
kernels/blackwell_multi_tc/Makefile
Normal file
7
kernels/blackwell_multi_tc/Makefile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
PROJECT = blackwell_multi_tc
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
#include <vx_spawn.h>
|
||||||
|
|
||||||
|
#define RISCV_CUSTOM3 0x7B
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define MAX_CORES 4
|
||||||
|
#define MAX_WARPS 8
|
||||||
|
|
||||||
|
#define BW_REP2(x) x, x
|
||||||
|
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||||
|
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_a_row[8] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP8(0x3c003c00u)}; // two fp16 1.0 values per word
|
||||||
|
volatile uint32_t g_b_row[8] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP8(0x40004000u)}; // two fp16 2.0 values per word
|
||||||
|
volatile uint32_t g_c_row[8] __attribute__((aligned(32))) = {
|
||||||
|
BW_REP8(0x3f800000u)}; // one fp32 1.0 value per word
|
||||||
|
volatile uint32_t g_result[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
volatile uint32_t g_status[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef BW_REP2
|
||||||
|
#undef BW_REP4
|
||||||
|
#undef BW_REP8
|
||||||
|
|
||||||
|
static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void tcgen05_cp_wait() {
|
||||||
|
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void bwgmma(uint32_t addr_tmem_c,
|
||||||
|
uint32_t addr_tmem_a,
|
||||||
|
uint32_t addr_smem_b) {
|
||||||
|
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||||
|
"r"(addr_smem_b)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void bwgmma_wait() {
|
||||||
|
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline float tcgen05_ld_f32(uint32_t addr_tmem) {
|
||||||
|
float value;
|
||||||
|
asm volatile(".insn r %1, 4, 0, %0, %2, x0"
|
||||||
|
: "=f"(value)
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem)
|
||||||
|
: "memory");
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the raw 32-bit representation of `value`.
//
// The bytes are copied instead of being read back through a union: reading a
// union member other than the one last written is undefined behaviour in
// C++, while a fixed-size byte copy between two 4-byte objects is well
// defined. The compiler builtin avoids needing an extra header.
static inline uint32_t f32_bits(float value) {
  uint32_t bits = 0;
  __builtin_memcpy(&bits, &value, sizeof bits);
  return bits;
}
|
||||||
|
|
||||||
|
extern "C" void vx_perf_dump() {}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_kernel_entry() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x1, x5, 11\n\t" // tmem_a = wid * 0x800
|
||||||
|
"addi x2, x1, 1024\n\t" // tmem_c = tmem_a + 0x400
|
||||||
|
"la x6, g_a_row\n\t"
|
||||||
|
"la x3, g_c_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 32\n\t"
|
||||||
|
"li x4, 1024\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x5, x5, 2\n\t"
|
||||||
|
"la x6, g_status\n\t"
|
||||||
|
"add x6, x6, x5\n\t"
|
||||||
|
"li x7, 0x600d\n\t"
|
||||||
|
"sw x7, 0(x6)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"2: j 2b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
struct kernel_arg_t {
|
||||||
|
volatile uint32_t *a_row;
|
||||||
|
volatile uint32_t *b_row;
|
||||||
|
volatile uint32_t *c_row;
|
||||||
|
volatile uint32_t *result;
|
||||||
|
volatile uint32_t *status;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void __attribute__((convergent)) kernel_body(int task_id,
|
||||||
|
kernel_arg_t *__UNIFORM__ arg) {
|
||||||
|
const int warp_id = task_id / vx_num_threads();
|
||||||
|
const int lane_id = task_id % vx_num_threads();
|
||||||
|
if (lane_id != 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
const int num_warps = vx_num_warps();
|
||||||
|
if (warp_id >= num_warps)
|
||||||
|
return;
|
||||||
|
|
||||||
|
volatile uint32_t *a_row = arg->a_row;
|
||||||
|
volatile uint32_t *c_row = arg->c_row;
|
||||||
|
volatile uint32_t *result = arg->result;
|
||||||
|
volatile uint32_t *status = arg->status;
|
||||||
|
const int core_id = vx_core_id();
|
||||||
|
const int status_idx = core_id * MAX_WARPS + warp_id;
|
||||||
|
|
||||||
|
const uint32_t smem_b = DEV_SMEM_START_ADDR;
|
||||||
|
const uint32_t base = static_cast<uint32_t>(warp_id) * 0x800u;
|
||||||
|
const uint32_t tmem_a = base + 0x000u;
|
||||||
|
const uint32_t tmem_c = base + 0x400u;
|
||||||
|
|
||||||
|
status[status_idx] = 0x100u;
|
||||||
|
|
||||||
|
for (int frag = 0; frag < 32; ++frag) {
|
||||||
|
const uint32_t offset = static_cast<uint32_t>(frag * 32);
|
||||||
|
tcgen05_cp(tmem_a + offset, reinterpret_cast<uint32_t>(a_row));
|
||||||
|
tcgen05_cp(tmem_c + offset, reinterpret_cast<uint32_t>(c_row));
|
||||||
|
}
|
||||||
|
tcgen05_cp_wait();
|
||||||
|
|
||||||
|
status[status_idx] = 0x200u;
|
||||||
|
bwgmma(tmem_c, tmem_a, smem_b);
|
||||||
|
bwgmma_wait();
|
||||||
|
|
||||||
|
result[status_idx] = f32_bits(tcgen05_ld_f32(tmem_c));
|
||||||
|
status[status_idx] = (result[status_idx] == 0x42820000u) ? 0x600du : 0xe000u;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
const int core_id = vx_core_id();
|
||||||
|
if (core_id != 0 || vx_warp_id() != 0 || vx_thread_id() != 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b_ptr =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (int frag = 0; frag < 32; ++frag) {
|
||||||
|
const int row = frag * 8;
|
||||||
|
for (int i = 0; i < 8; ++i)
|
||||||
|
smem_b_ptr[row + i] = g_b_row[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < MAX_WARPS; ++i)
|
||||||
|
g_status[core_id * MAX_WARPS + i] = 0;
|
||||||
|
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_kernel_entry);
|
||||||
|
|
||||||
|
for (int spin = 0; spin < 100000; ++spin) {
|
||||||
|
int done = 1;
|
||||||
|
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i)
|
||||||
|
done &= (g_status[core_id * MAX_WARPS + i] == 0x600du);
|
||||||
|
if (done)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) {
|
||||||
|
if (g_status[core_id * MAX_WARPS + i] != 0x600du)
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
# Get the directory where this common.mk file is located
|
||||||
|
COMMON_MK_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
|
||||||
|
|
||||||
XLEN ?= 32
|
XLEN ?= 32
|
||||||
|
|
||||||
TOOLDIR ?= /opt
|
TOOLDIR ?= /opt
|
||||||
@@ -7,7 +10,7 @@ RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv64-gnu-toolchain
|
|||||||
VX_CFLAGS += -march=rv64imafd -mabi=lp64d
|
VX_CFLAGS += -march=rv64imafd -mabi=lp64d
|
||||||
STARTUP_ADDR ?= 0x180000000
|
STARTUP_ADDR ?= 0x180000000
|
||||||
else
|
else
|
||||||
RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv-gnu-toolchain
|
RISCV_TOOLCHAIN_PATH ?= $(realpath $(COMMON_MK_DIR)../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
|
||||||
VX_CFLAGS += -march=rv32imaf -mabi=ilp32f
|
VX_CFLAGS += -march=rv32imaf -mabi=ilp32f
|
||||||
STARTUP_ADDR ?= 0x80000000
|
STARTUP_ADDR ?= 0x80000000
|
||||||
endif
|
endif
|
||||||
@@ -18,7 +21,7 @@ RISCV_SYSROOT ?= $(RISCV_TOOLCHAIN_PATH)/$(RISCV_PREFIX)
|
|||||||
VORTEX_KN_PATH ?= $(realpath ../../lib)
|
VORTEX_KN_PATH ?= $(realpath ../../lib)
|
||||||
GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini)
|
GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini)
|
||||||
|
|
||||||
LLVM_VORTEX ?= $(TOOLDIR)/llvm-vortex
|
LLVM_VORTEX ?= $(realpath $(COMMON_MK_DIR)../../toolchain/llvm-r8)
|
||||||
|
|
||||||
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
|
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
|
||||||
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)
|
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)
|
||||||
|
|||||||
1
kernels/flash_attention/args.bin
Symbolic link
1
kernels/flash_attention/args.bin
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
args.seq1024.headdim64.bin
|
||||||
BIN
kernels/flash_attention/args.seq1024.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq1024.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq128.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq128.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq192.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq192.headdim64.bin
Normal file
Binary file not shown.
BIN
kernels/flash_attention/args.seq64.headdim64.bin
Normal file
BIN
kernels/flash_attention/args.seq64.headdim64.bin
Normal file
Binary file not shown.
45
kernels/flash_attention/compile_flash.sh
Executable file
45
kernels/flash_attention/compile_flash.sh
Executable file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
archs=("ampere" "virgo")
|
||||||
|
|
||||||
|
if [ -z "$TOOLDIR" ]; then
|
||||||
|
echo "error: \$TOOLDIR not set. Did you run source ci/toolchain_env.sh?"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Abort the script with status 1 unless "$1" is an existing regular file.
check_exists() {
    local path="$1"
    if [ -f "$path" ]; then
        return 0
    fi
    echo "error: looked for file $path that does not exist."
    exit 1
}
|
||||||
|
|
||||||
|
# generate operands
|
||||||
|
echo "generating flash_attn operands for seqlen 1024, headdim 64"
|
||||||
|
python3 flash_attn.py 1024 64 64
|
||||||
|
mv -v input.a.col.bin input.a.rand.fp32.seqlen1024headdim64.col.bin
|
||||||
|
mv -v input.a.row.bin input.a.rand.fp32.seqlen1024headdim64.row.bin
|
||||||
|
mv -v input.b.bin input.b.rand.fp32.seqlen1024headdim64.row.bin
|
||||||
|
mv -v input.c.bin input.c.rand.fp32.seqlen1024headdim64.row.bin
|
||||||
|
ln -sf input.a.rand.fp32.seqlen1024headdim64.row.bin input.a.bin
|
||||||
|
ln -sf input.b.rand.fp32.seqlen1024headdim64.row.bin input.b.bin
|
||||||
|
ln -sf input.c.rand.fp32.seqlen1024headdim64.row.bin input.c.bin
|
||||||
|
|
||||||
|
for arch in "${archs[@]}"; do
|
||||||
|
git checkout ae-flash-$arch
|
||||||
|
# git pull
|
||||||
|
|
||||||
|
# re-compile libvortexrt.a
|
||||||
|
pushd ../../lib
|
||||||
|
make
|
||||||
|
popd
|
||||||
|
|
||||||
|
echo "compiling flash_attn kernel for $arch with seqlen 1024, headdim 64"
|
||||||
|
|
||||||
|
# touch source file to force re-building, as the Makefile does not track
|
||||||
|
# binary changes
|
||||||
|
touch kernel.cpp
|
||||||
|
touch kernel.gemmini.cpp
|
||||||
|
|
||||||
|
make CONFIG=flash.$arch.seqlen1024.headdim64
|
||||||
|
done
|
||||||
159
kernels/flash_attention/flash_attn.py
Normal file
159
kernels/flash_attention/flash_attn.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def parse_mnk():
    """Parse the three positional command-line arguments as integers.

    Returns:
        A tuple ``(m, n, k)`` built from ``sys.argv[1:4]``.

    Exits with status 1 (after printing a usage line to stderr) unless exactly
    three arguments were supplied.
    """
    if len(sys.argv) != 4:
        print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
        sys.exit(1)
    m, n, k = (int(arg) for arg in sys.argv[1:4])
    return (m, n, k)
|
||||||
|
|
||||||
|
|
||||||
|
# Reorder array in a way that groups two adjacent elements along the column to
|
||||||
|
# be now adjacent along the row. This way, when the resulting fp16 array is
|
||||||
|
# read in column-major order with 32-bit granularity, the fp16 elements will be
|
||||||
|
# read in the same order as regular fp32 elements in column-major.
|
||||||
|
#
|
||||||
|
# For example:
|
||||||
|
# [[1 2]
|
||||||
|
# [3 4]
|
||||||
|
# [5 6]
|
||||||
|
# [7 8]]
|
||||||
|
# becomes
|
||||||
|
# [[1 3 2 4]
|
||||||
|
# [5 7 6 8]]
|
||||||
|
def pack_fp16_by_column(array):
    """Pair up vertically adjacent elements of a 2-D array.

    Args:
        array: 2-D array with an even number of rows.

    Returns:
        A 3-D array of shape ``(rows // 2, cols, 2)`` where
        ``result[i, j] == [array[2*i, j], array[2*i + 1, j]]``. Flattening the
        last two axes gives the row layout shown in the example above this
        function.
    """
    cols = array.shape[1]
    # Transpose so each original column is contiguous, split it into pairs,
    # then bring the pair index back to the leading axis.
    return array.transpose(1, 0).reshape(cols, -1, 2).transpose(1, 0, 2)
|
||||||
|
|
||||||
|
|
||||||
|
# Do the same as pack_fp16_by_column, but for every two elements along the row.
|
||||||
|
def pack_fp16_by_row(array):
    """Pair up horizontally adjacent elements of a 2-D array.

    Args:
        array: 2-D array with an even number of columns.

    Returns:
        A 3-D array of shape ``(rows, cols // 2, 2)`` where
        ``result[i, j] == [array[i, 2*j], array[i, 2*j + 1]]``.
    """
    rows = array.shape[0]
    return array.reshape(rows, -1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Generates operand files and golden intermediate results for a tiled
    # attention-style computation: S = Q @ K, P = exp2(S - rowmax),
    # O = O + P @ V.  Intermediates of the tile starting at `col_to_save`
    # are dumped to *.bin files in the current directory.
    seqlen, _, headdim = parse_mnk()

    rand = True
    if not rand:
        A_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
        B_array = np.arange(headdim * seqlen).reshape([headdim, seqlen])
        # C has shape [seqlen, headdim] (same as in the random branch), so the
        # range must hold seqlen * headdim elements; seqlen * seqlen only
        # reshapes successfully when seqlen == headdim.
        C_array = np.arange(seqlen * headdim).reshape([seqlen, headdim])
    else:
        # Fixed seed keeps the generated operands reproducible across runs.
        np.random.seed(0)
        A_array = np.random.rand(seqlen, headdim) - 0.5
        B_array = np.random.rand(headdim, seqlen) - 0.5
        C_array = np.random.rand(seqlen, headdim) - 0.5

    fp16 = False
    if fp16:
        A_packed = pack_fp16_by_row(A_array)
        AT_packed = A_packed.transpose([1, 0, 2])
        AT_array = AT_packed.reshape([-1, seqlen * 2])
        AT_array.astype('float16').tofile("input.a.col.bin")
        B_packed = pack_fp16_by_column(B_array)
        B_array = B_packed.reshape([-1, headdim * 2])
        B_array.astype('float16').tofile("input.b.row.bin")
    else:
        A_array.astype('float32').tofile("input.a.row.bin")
        AT_array = A_array.transpose([1, 0])
        AT_array.astype('float32').tofile("input.a.col.bin")
        B_array.astype('float32').tofile("input.b.bin")
        C_array.astype('float32').tofile("input.c.bin")

    assert((seqlen % 64) == 0)

    # Tile sizes along the rows (Br) and columns (Bc) of S.
    Br = 64
    Bc = Br

    # Running per-row statistics and the output accumulator.
    rowmax = np.zeros([Br])
    rowsum = np.zeros([Br])
    O = np.zeros([Br, headdim])

    def exp2(x):
        # Second-order polynomial 1 + x + x^2/2, used as a placeholder for exp.
        return (x**2) / 2.0 + x + 1.0

    full_S = A_array @ B_array
    full_S_T = full_S.transpose([1, 0])
    full_S.astype('float32').tofile("full_S.bin")

    # Only the tile whose starting column equals this value is dumped.
    col_to_save = 0

    for col in range(0, seqlen, Bc):
        print(f"tile iteration {col}~{col + Bc} ======================================")

        # FIXME: only work with the first 64 rows of Q for now
        Q_tile = A_array[0:64, :]
        K_tile = B_array[:, col:col+Bc]

        S = Q_tile @ K_tile
        if col == col_to_save:
            print('S_expected:')
            print(S)
            S.astype('float32').tofile("S_expected.bin")

        # generate rowmax result in online softmax
        rowmax_this = np.max(S, axis=1)
        rowmax_prev = rowmax.copy()
        rowmax = np.maximum(rowmax, rowmax_this)
        if col == col_to_save:
            rowmax.astype('float32').tofile("rowmax.bin")

        # subtract rowmax from each row by broadcasting
        # (placeholder for exp)
        x = S - rowmax[:, np.newaxis]
        P = exp2(x)
        if col == col_to_save:
            print('P_expected:')
            print(P)
            P.astype('float32').tofile("P_expected.bin")
            P.transpose([1, 0]).astype('float32').tofile("P_expected.col.bin")

        rowsum_this = np.sum(P, axis=1)
        # NOTE(review): textbook online softmax rescales the old sum by
        # exp(rowmax_prev - rowmax) (the updated running max); this uses
        # rowmax_this instead -- confirm it matches the kernel under test.
        x = rowmax_prev - rowmax_this
        rowsum = exp2(x) * rowsum + rowsum_this
        if col == col_to_save:
            rowsum.astype('float32').tofile("rowsum.bin")

        # NOTE(review): O is divided by the rescale factor here, whereas the
        # usual formulation multiplies by exp(rowmax_prev - rowmax) -- confirm.
        x = rowmax_prev - rowmax
        O = O / (exp2(x)[:, np.newaxis])
        if col == col_to_save:
            print('O_before_PV:')
            print(O)
            O.astype('float32').tofile("O_before_PV.bin")

        V = C_array[col:col+Bc, :]
        if col == col_to_save:
            V.astype('float32').tofile("V_expected.bin")
        O = O + P @ V
        if col == col_to_save:
            print('O_after_PV:')
            print(O)
            O.astype('float32').tofile("O_after_PV.bin")
|
||||||
9
kernels/hgemm_validation/Makefile
Normal file
9
kernels/hgemm_validation/Makefile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
PROJECT = hgemm_validation
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
|
|
||||||
|
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||||
|
cp $< $@
|
||||||
14
kernels/hgemm_validation/README.md
Normal file
14
kernels/hgemm_validation/README.md
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# hgemm_validation
|
||||||
|
|
||||||
|
Small HGEMM correctness validation for the 4-lane Blackwell tensor-core path.
|
||||||
|
|
||||||
|
The test runs one tensor warp on a single 16x16x32 HGEMM tile:
|
||||||
|
|
||||||
|
- A is fp16 1.0
|
||||||
|
- B is fp16 2.0
|
||||||
|
- C starts as fp32 1.0
|
||||||
|
- expected output is fp32 65.0 for all 16x16 C elements
|
||||||
|
|
||||||
|
Scalar warp 0 initializes the shared-memory B tile, spawns only the first tensor
|
||||||
|
warp, waits for completion, and checks all 256 output words copied back from
|
||||||
|
TMEM. Success reports `WU_CASE_PASS` through `tohost`.
|
||||||
119
kernels/hgemm_validation/kernel.cpp
Normal file
119
kernels/hgemm_validation/kernel.cpp
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#include "../wu_arch_cases/common_wu_min.h"
|
||||||
|
|
||||||
|
// Address wu_main() writes the B tile to and the worker passes to the
// matrix-multiply instruction.
#define DEV_SMEM_START_ADDR 0xff000000u
// Completion marker base: the worker stores (base | warp id) into g_seen[].
#define HGEMM_VALIDATION_BASE 0x7700u

// Tile dimensions (M x N output words; K is not referenced below).
#define HGEMM_M 16u
#define HGEMM_N 16u
#define HGEMM_K 32u
// Byte span covered by each of the worker's fragment loops.
#define HGEMM_TILE_BYTES 1024u
// Address step between successive fragment copies.
#define HGEMM_FRAGMENT_BYTES 16u
#define HGEMM_FRAGMENTS (HGEMM_TILE_BYTES / HGEMM_FRAGMENT_BYTES)
#define HGEMM_OUT_WORDS (HGEMM_M * HGEMM_N)
// Bit pattern every output word must equal (IEEE-754 single 65.0).
#define HGEMM_EXPECTED 0x42820000u

// Helpers that repeat an initializer 2 / 4 times; undefined again below.
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)

extern "C" {
// One 16-byte fragment per operand, each word holding the same pattern.
// 0x3c00 / 0x4000 are the half-precision encodings of 1.0 / 2.0 and
// 0x3f800000 is single-precision 1.0.
volatile uint32_t g_hgemm_a_frag[4] __attribute__((aligned(16))) = {
    BW_REP4(0x3c003c00u)};
volatile uint32_t g_hgemm_b_frag[4] __attribute__((aligned(16))) = {
    BW_REP4(0x40004000u)};
volatile uint32_t g_hgemm_c_frag[4] __attribute__((aligned(16))) = {
    BW_REP4(0x3f800000u)};
// Result buffer filled by the worker and checked by wu_main().
volatile uint32_t g_hgemm_out[HGEMM_OUT_WORDS] __attribute__((aligned(16)));
}

#undef BW_REP2
#undef BW_REP4
|
||||||
|
|
||||||
|
// Worker entry point spawned by wu_main() on one tensor warp.
//
// Written as a naked function in raw assembly: it uses x1..x7 directly as
// scratch registers and never returns.  Sequence:
//   1. copy the A and C fragments to offsets 0 and HGEMM_TILE_BYTES,
//      one HGEMM_FRAGMENT_BYTES step at a time (custom3 funct3=2),
//   2. wait for the copies (custom3 funct3=3),
//   3. issue the matrix multiply (funct3=0) and wait for it (funct3=1),
//   4. copy the result region out to g_hgemm_out (funct3=6) and wait,
//   5. store (HGEMM_VALIDATION_BASE | warp id) into g_seen[warp id],
//   6. issue the custom0 stop instruction and spin.
// NOTE(review): funct3 2/3/0/1 match the encodings emitted by the
// vx_tcgen05_cp / vx_tcgen05_cp_wait / vx_bwgmma / vx_bwgmma_wait
// intrinsics; funct3=6 is presumably the copy-back instruction -- confirm.
extern "C" void __attribute__((naked, noinline, used)) hgemm_validation_worker() {
  asm volatile(
      // x1 = destination base for A (0), x2 = destination base for C.
      "li x1, 0\n\t"
      "li x2, %[tile_bytes]\n\t"
      // x6 / x3 = addresses of the constant A and C source fragments.
      "la x6, g_hgemm_a_frag\n\t"
      "la x3, g_hgemm_c_frag\n\t"
      // x7 = running byte offset within the tile.
      "li x7, 0\n\t"
      "1:\n\t"
      // The same source fragment is copied to every offset of the tile.
      "add x4, x1, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
      "add x4, x2, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, %[frag_bytes]\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 1b\n\t"
      // Wait for the outstanding copies.
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Multiply: rd = x2 (C base), rs1 = x1 (A base), rs2 = x4 (B address).
      "li x4, %[smem_base]\n\t"
      ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
      ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
      // Copy the C region out to g_hgemm_out, fragment by fragment.
      // x1 is reused as the destination pointer from here on.
      "la x3, g_hgemm_out\n\t"
      "li x7, 0\n\t"
      "2:\n\t"
      "add x4, x2, x7\n\t"
      "add x1, x3, x7\n\t"
      ".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
      "addi x7, x7, %[frag_bytes]\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 2b\n\t"
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Publish completion: g_seen[wid] = done_base | wid.
      "csrr x5, %[csr_wid]\n\t"
      "slli x6, x5, 2\n\t"
      "la x7, g_seen\n\t"
      "add x7, x7, x6\n\t"
      "li x6, %[done_base]\n\t"
      "or x6, x6, x5\n\t"
      "sw x6, 0(x7)\n\t"
      // Stop this warp (presumably the all-zero thread-mask instruction),
      // then spin in case execution continues.
      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
      "3: j 3b\n\t"
      :
      : [csr_wid] "i"(VX_CSR_WARP_ID),
        [custom0] "i"(RISCV_CUSTOM0),
        [custom3] "i"(RISCV_CUSTOM3),
        [smem_base] "i"(DEV_SMEM_START_ADDR),
        [done_base] "i"(HGEMM_VALIDATION_BASE),
        [tile_bytes] "i"(HGEMM_TILE_BYTES),
        [frag_bytes] "i"(HGEMM_FRAGMENT_BYTES)
      : "memory");
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||||
|
g_hgemm_out[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t frag = 0; frag < HGEMM_FRAGMENTS; ++frag) {
|
||||||
|
const uint32_t row = frag * 4u;
|
||||||
|
for (uint32_t i = 0; i < 4u; ++i) {
|
||||||
|
smem_b[row + i] = g_hgemm_b_frag[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||||
|
vx_spawn_tensor(1u << tensor_wid, hgemm_validation_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(1u << tensor_wid, HGEMM_VALIDATION_BASE) != 0) {
|
||||||
|
wu_case_fail(0x09u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||||
|
if (g_hgemm_out[i] != HGEMM_EXPECTED) {
|
||||||
|
g_aux[0] = i;
|
||||||
|
g_aux[1] = g_hgemm_out[i];
|
||||||
|
wu_case_fail(0x20u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -1,5 +1,14 @@
|
|||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
|
# hopper and virgo have the same SIMT configurations
|
||||||
|
git checkout ae-hopper
|
||||||
|
# git pull
|
||||||
|
|
||||||
|
# re-compile libvortexrt.a
|
||||||
|
pushd ../../lib
|
||||||
|
make
|
||||||
|
popd
|
||||||
|
|
||||||
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
|
if [ ! -f input.a.rand01.fp16.m256n256k256.row.bin ]; then
|
||||||
echo "input binaries not found, generating operands"
|
echo "input binaries not found, generating operands"
|
||||||
python3 generate_operands.py
|
python3 generate_operands.py
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000
|
#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000
|
||||||
#define DEV_SMEM_START_ADDR 0xff000000
|
#define DEV_SMEM_START_ADDR 0xff000000
|
||||||
|
|
||||||
typedef struct {
|
typedef struct __attribute__((packed)) {
|
||||||
uint32_t dim_m;
|
uint32_t dim_m;
|
||||||
uint32_t dim_n;
|
uint32_t dim_n;
|
||||||
uint32_t dim_k;
|
uint32_t dim_k;
|
||||||
|
|||||||
@@ -41,12 +41,22 @@ check_exists() {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# generate operands
|
||||||
|
for dim in "${dims[@]}"; do
|
||||||
|
echo "generating operands for dim $dim"
|
||||||
|
python3 generate_operands.py $dim $dim $dim
|
||||||
|
mv -v input.a.col.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.col.swizzle_fp16.bin
|
||||||
|
mv -v input.a.row.bin input.a.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
|
||||||
|
mv -v input.b.row.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.bin
|
||||||
|
mv -v input.b.row.swizzled.bin input.b.rand01.fp16.m${dim}n${dim}k${dim}.row.swizzle_fp16.bin
|
||||||
|
done
|
||||||
|
|
||||||
for arch in "${archs[@]}"; do
|
for arch in "${archs[@]}"; do
|
||||||
git checkout ae-$arch
|
git checkout ae-$arch
|
||||||
|
# git pull
|
||||||
|
|
||||||
# re-compile libvortexrt.a
|
# re-compile libvortexrt.a
|
||||||
# FIXME after restructure
|
pushd ../../lib
|
||||||
pushd ../../libs
|
|
||||||
make
|
make
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
|||||||
116
kernels/sgemm_tcore/generate_operands.py
Normal file
116
kernels/sgemm_tcore/generate_operands.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def parse_mnk():
    """Read the three GEMM dimensions from the command line.

    Expects exactly three positional arguments (dimM, dimN, dimK).  Prints a
    usage message to stderr and exits with status 1 otherwise.

    Returns:
        Tuple ``(m, n, k)`` of ints.
    """
    dims = sys.argv[1:]
    if len(dims) != 3:
        print(f"usage: {sys.argv[0]} dimM dimN dimK", file=sys.stderr)
        sys.exit(1)
    m, n, k = (int(text) for text in dims)
    return (m, n, k)
|
||||||
|
|
||||||
|
|
||||||
|
# Reorder array in a way that groups two adjacent elements along the column to
|
||||||
|
# be now adjacent along the row. This way, when the resulting fp16 array is
|
||||||
|
# read in column-major order with 32-bit granularity, the fp16 elements will be
|
||||||
|
# read in the same order as regular fp32 elements in column-major.
|
||||||
|
#
|
||||||
|
# For example:
|
||||||
|
# [[1 2]
|
||||||
|
# [3 4]
|
||||||
|
# [5 6]
|
||||||
|
# [7 8]]
|
||||||
|
# becomes
|
||||||
|
# [[1 3 2 4]
|
||||||
|
# [5 7 6 8]]
|
||||||
|
def pack_fp16_by_column(array):
    """Pair up vertically adjacent elements of a 2-D array.

    Returns a 3-D array of shape ``(rows // 2, cols, 2)`` in which
    ``result[r, c]`` holds ``array[2*r, c]`` and ``array[2*r + 1, c]``.
    """
    n_cols = array.shape[1]
    by_column = np.transpose(array, (1, 0))
    pairs = by_column.reshape([n_cols, -1, 2])
    return np.transpose(pairs, (1, 0, 2))
|
||||||
|
|
||||||
|
|
||||||
|
# Do the same as pack_fp16_by_column, but for every two elements along the row.
|
||||||
|
def pack_fp16_by_row(array):
    """Pair up horizontally adjacent elements of a 2-D array.

    Returns a 3-D array of shape ``(rows, cols // 2, 2)`` in which
    ``result[r, c]`` holds ``array[r, 2*c]`` and ``array[r, 2*c + 1]``.
    """
    n_rows, _n_cols = array.shape[0], array.shape[1]
    return array.reshape([n_rows, -1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Generates GEMM operands (A: MxK, B: KxN, C: MxN), dumps them as C
    # initializer fragments, an .npz archive and raw binary files, and writes
    # the reference product A @ B.
    M, N, K = parse_mnk()

    rand = True
    if not rand:
        A_array = np.arange(M * K).reshape([M, K])
        B_array = np.arange(K * N).reshape([K, N])
        C_array = np.zeros([M, N])
    else:
        # Fixed seed keeps the generated operands reproducible across runs.
        np.random.seed(0)
        A_array = np.random.rand(M, K)
        B_array = np.random.rand(K, N)
        # C is M x N, as in the deterministic branch above (the previous
        # (N, K) shape only agreed with it when M == N == K).
        C_array = np.random.rand(M, N)

    # One text line per matrix row, each element written as a float literal.
    for header, matrix in (('a_matrix.h', A_array),
                           ('b_matrix.h', B_array),
                           ('c_matrix.h', C_array)):
        with open(header, 'w') as f:
            for i in range(matrix.shape[0]):
                for j in range(matrix.shape[1]):
                    f.write(f'{matrix[i,j]:f}f, ')
                f.write('\n')

    np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)

    fp16 = True
    if fp16:
        A_packed = pack_fp16_by_row(A_array)
        # The reshape does not reorder elements, so the file holds A in plain
        # row-major order; only the printed layout depends on it.
        A_swizzled = A_packed.reshape([-1, M * 2])
        A_swizzled.astype('float16').tofile("input.a.row.bin")
        AT_packed = A_packed.transpose([1, 0, 2])
        AT_swizzled = AT_packed.reshape([-1, M * 2])
        AT_swizzled.astype('float16').tofile("input.a.col.bin")
        print('A:')
        print(A_swizzled)
        print('AT:')
        print(AT_swizzled)
        B_array.astype('float16').tofile("input.b.row.bin")
        B_packed = pack_fp16_by_column(B_array)
        B_swizzled = B_packed.reshape([-1, N * 2])
        B_swizzled.astype('float16').tofile("input.b.row.swizzled.bin")
        print('B:')
        print(B_swizzled)
    else:
        A_array.astype('float32').tofile("input.a.row.bin")
        AT_array = A_array.transpose([1, 0])
        AT_array.astype('float32').tofile("input.a.col.bin")
        B_array.astype('float32').tofile("input.b.bin")
        C_array.astype('float32').tofile("input.c.bin")
        print('AT:')
        print(AT_array)
        print('B:')
        print(B_array)

    # Reference result; C is not added in.
    D_expected = A_array @ B_array
    D_expected.astype('float32').tofile("d_expected.bin")
    print('D_expected:')
    print(D_expected)
|
||||||
|
|
||||||
@@ -267,6 +267,34 @@ inline void vx_wgmma_wait() {
|
|||||||
asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=2 instruction with rs1 = addr_tmem and
// rs2 = addr_smem (rd is x0).  Issue-only: pair with vx_tcgen05_cp_wait().
// NOTE(review): presumably a copy between the two addresses -- confirm the
// direction against the hardware spec.
inline void vx_tcgen05_cp(const uint32_t addr_tmem, const uint32_t addr_smem) {
  asm volatile(".insn r %0, 2, 0, x0, %1, %2" ::"i"(RISCV_CUSTOM3), "r"(addr_tmem),
               "r"(addr_smem));
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=3 instruction with all-x0 operands; by its name,
// the wait counterpart of vx_tcgen05_cp().
inline void vx_tcgen05_cp_wait() {
  asm volatile (".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=0 instruction with rs1 = addr_tmem_a and
// rs2 = addr_smem_b; the rd field is hard-wired to x0.  Issue-only: pair
// with vx_bwgmma_wait().
inline void vx_bwgmma(const uint32_t addr_tmem_a, const uint32_t addr_smem_b) {
  asm volatile(".insn r %0, 0, 0, x0, %1, %2" ::"i"(RISCV_CUSTOM3), "r"(addr_tmem_a),
               "r"(addr_smem_b));
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=1 instruction with all-x0 operands; by its name,
// the wait counterpart of vx_bwgmma().
inline void vx_bwgmma_wait() {
  asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=4 instruction with rs1 = addr_tmem and the
// register holding rd_hint placed in the rd field (rs2 is x0).
// NOTE(review): rd_hint is declared as an input operand, so the compiler
// does not expect that register to be written -- confirm the instruction
// does not modify rd.
inline void vx_tcgen05_ld(const uint32_t addr_tmem, const uint32_t rd_hint) {
  asm volatile(".insn r %0, 4, 0, %1, %2, x0" ::"i"(RISCV_CUSTOM3), "r"(rd_hint),
               "r"(addr_tmem));
}
|
||||||
|
|
||||||
|
// Emits the custom3 funct3=5 instruction with rs1 = addr_tmem and the
// register holding rd_hint placed in the rd field (rs2 is x0).
// NOTE(review): rd_hint is declared as an input operand -- confirm the
// instruction only reads that register.
inline void vx_tcgen05_st(const uint32_t addr_tmem, const uint32_t rd_hint) {
  asm volatile(".insn r %0, 5, 0, %1, %2, x0" ::"i"(RISCV_CUSTOM3), "r"(rd_hint),
               "r"(addr_tmem));
}
|
||||||
|
|
||||||
// Remap logical row/col coordinate of a matrix element to a memory index that
|
// Remap logical row/col coordinate of a matrix element to a memory index that
|
||||||
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
||||||
template <bool use_dma, uint32_t dim_col>
|
template <bool use_dma, uint32_t dim_col>
|
||||||
@@ -1190,10 +1218,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
|
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | GEMMINI_CISC_SET_AB_STRIDE);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
GEMMINI_CISC_CMD_I(10);
|
GEMMINI_CISC_CMD_R((11 << 16) | (0 << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
@@ -1257,7 +1285,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
|
|
||||||
// block_k is even: opcode 11 (write to local_a_buf)
|
// block_k is even: opcode 11 (write to local_a_buf)
|
||||||
// block_k is odd: opcode 10 (write to local_a)
|
// block_k is odd: opcode 10 (write to local_a)
|
||||||
@@ -1266,8 +1294,9 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// the last iteration of the k-loop is prefetching for the first
|
// the last iteration of the k-loop is prefetching for the first
|
||||||
// iteration of the n-loop. The ping-pong indexing has to match for
|
// iteration of the n-loop. The ping-pong indexing has to match for
|
||||||
// the two loop end to connect.
|
// the two loop end to connect.
|
||||||
const uint32_t opcode = 11 - (block_k & 1);
|
const uint32_t a_hexadecile = 4 - ((block_k & 1) * 4);
|
||||||
GEMMINI_CISC_CMD_I(opcode);
|
const uint32_t b_hexadecile = a_hexadecile + 11;
|
||||||
|
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
|
||||||
// // TODO: branch is probably slow
|
// // TODO: branch is probably slow
|
||||||
// if (block_k & 1) {
|
// if (block_k & 1) {
|
||||||
// GEMMINI_CISC_CMD_I(12);
|
// GEMMINI_CISC_CMD_I(12);
|
||||||
|
|||||||
26
kernels/wu_arch/Makefile
Normal file
26
kernels/wu_arch/Makefile
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
PROJECT = wu_arch
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
WU_VARIANT_DUMPS = \
|
||||||
|
kernel.radiance.barriers.dump
|
||||||
|
|
||||||
|
all: kernel.radiance.dump $(WU_VARIANT_DUMPS)
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
|
|
||||||
|
kernel.radiance.barriers.dump: kernel.radiance.barriers.elf
|
||||||
|
$(VX_DP) -D $< > $@
|
||||||
|
|
||||||
|
kernel.radiance.barriers.elf: $(VX_SRCS) $(VX_INCLUDES) $(BINFILES)
|
||||||
|
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -DWU_RUN_DOMAIN_BARRIERS -o $@
|
||||||
|
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
|
||||||
|
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
|
||||||
|
$(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@
|
||||||
|
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
|
||||||
|
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
|
||||||
|
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
|
||||||
|
$(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true
|
||||||
|
$(OBJCOPY) --update-section .args=args.bin $@ || true
|
||||||
1
kernels/wu_arch/args.bin
Normal file
1
kernels/wu_arch/args.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/wu_arch/input.a.bin
Normal file
1
kernels/wu_arch/input.a.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/wu_arch/input.b.bin
Normal file
1
kernels/wu_arch/input.b.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
1
kernels/wu_arch/input.c.bin
Normal file
1
kernels/wu_arch/input.c.bin
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0
|
||||||
173
kernels/wu_arch/kernel.cpp
Normal file
173
kernels/wu_arch/kernel.cpp
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define MAX_WARPS 8
|
||||||
|
#define MINIMAL_INIT_WORDS 4
|
||||||
|
#define WU_SCALAR_SPIN 32
|
||||||
|
#define WU_TENSOR_SPIN 32
|
||||||
|
#define WU_WAIT_SPIN 8192
|
||||||
|
#define WU_STATUS_DONE 0x600du
|
||||||
|
#define WU_STATUS_SCALAR_BASE 0x5100u
|
||||||
|
#define WU_STATUS_TENSOR_BASE 0x7100u
|
||||||
|
#define WU_BARRIER_SCALAR 0u
|
||||||
|
#define WU_BARRIER_MASKED 1u
|
||||||
|
#define WU_BARRIER_TENSOR 2u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_status[MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
volatile uint32_t g_scalar_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
volatile uint32_t g_tensor_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
volatile uint32_t g_spin_sink[MAX_WARPS] __attribute__((aligned(32)));
|
||||||
|
extern volatile uint64_t tohost;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void vx_perf_dump() {}
|
||||||
|
|
||||||
|
// Publishes an exit code through the `tohost` word.  The value written is
// (exit_code << 1) | 1, bracketed by full read/write fences.
static inline void wu_report_tohost(uint32_t exit_code) {
  const uint64_t payload = (static_cast<uint64_t>(exit_code) << 1) | 1u;

  asm volatile("fence rw, rw" ::: "memory");
  tohost = payload;
  asm volatile("fence rw, rw" ::: "memory");
}
|
||||||
|
|
||||||
|
extern "C" int wu_main();
|
||||||
|
|
||||||
|
// Program entry point, placed in the .init section.
//
// Loads gp (with relaxation disabled for that instruction), then reads the
// core-id CSR.  Core 0 sets sp to STACK_BASE_ADDR, calls wu_main() and
// copies its return value into gp; every other core skips straight to the
// stop sequence.  All paths end by issuing the custom0 funct3=0 instruction
// with all-x0 operands and spinning.
// NOTE(review): that instruction is presumably the "disable all threads"
// encoding used to halt the warp -- confirm.
extern "C" void __attribute__((naked, section(".init"), used)) _start() {
  asm volatile(
      ".option push\n\t"
      ".option norelax\n\t"
      "la gp, __global_pointer\n\t"
      ".option pop\n\t"
      // Only core 0 runs the test body.
      "csrr t0, %[csr_core]\n\t"
      "bnez t0, 2f\n\t"
      "li sp, %[stack_base]\n\t"
      "call wu_main\n\t"
      // Keep wu_main's return value in gp after the call.
      "mv gp, a0\n\t"
      "2:\n\t"
      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
      "1: j 1b\n\t"
      :
      : [csr_core] "i"(VX_CSR_CORE_ID),
        [stack_base] "i"(STACK_BASE_ADDR),
        [custom0] "i"(RISCV_CUSTOM0)
      : "memory");
}
|
||||||
|
|
||||||
|
extern "C" void scalar_worker() {
|
||||||
|
const uint32_t wid = static_cast<uint32_t>(vx_warp_id());
|
||||||
|
const uint32_t tid = static_cast<uint32_t>(vx_thread_id());
|
||||||
|
volatile uint32_t mix = wid + 1u;
|
||||||
|
|
||||||
|
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||||
|
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||||
|
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||||
|
mix = (mix << 1) ^ (i + wid);
|
||||||
|
|
||||||
|
if (tid == 0 && wid < MAX_WARPS) {
|
||||||
|
g_spin_sink[wid] = mix;
|
||||||
|
g_scalar_seen[wid] = WU_STATUS_SCALAR_BASE | wid;
|
||||||
|
}
|
||||||
|
|
||||||
|
vx_tmc_zero();
|
||||||
|
while (1) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for spawned tensor warps, written as a naked function in raw
// assembly using x1, x2 and x5..x7 as scratch.  Never returns.
//
// Sequence: read the warp id; when WU_RUN_DOMAIN_BARRIERS is defined, issue
// a custom0 funct3=4 instruction with the tensor barrier id (tagged with the
// tensor domain) and NUM_TENSOR_WARPS; issue custom3 funct3=3; count down
// WU_TENSOR_SPIN iterations; store (WU_STATUS_TENSOR_BASE | warp id) into
// g_tensor_seen[warp id]; stop the warp.
// NOTE(review): custom0 funct3=4 is presumably the barrier instruction and
// custom0 funct3=0 with x0 operands the warp-stop instruction -- confirm.
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
  asm volatile(
      "csrr x5, %[csr_wid]\n\t"
#ifdef WU_RUN_DOMAIN_BARRIERS
      // x1 = barrier id | (domain << shift), x2 = participant count.
      "li x1, (%[bar_tensor] | (%[domain_tensor] << %[domain_shift]))\n\t"
      "li x2, %[num_tensor]\n\t"
      ".insn r %[custom0], 4, 0, x0, x1, x2\n\t"
#endif
      // Same encoding as the copy-wait intrinsic (custom3 funct3=3).
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Simple countdown delay.
      "li x7, %[tensor_spin]\n\t"
      "1:\n\t"
      "addi x7, x7, -1\n\t"
      "bnez x7, 1b\n\t"
      // x6 = &g_tensor_seen[wid]  (x5 temporarily holds wid * 4).
      "slli x5, x5, 2\n\t"
      "la x6, g_tensor_seen\n\t"
      "add x6, x6, x5\n\t"
      "li x7, %[tensor_base]\n\t"
      // Restore x5 to the plain warp id before OR-ing it into the marker.
      "srli x5, x5, 2\n\t"
      "or x7, x7, x5\n\t"
      "sw x7, 0(x6)\n\t"
      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
      "2: j 2b\n\t"
      :
      : [csr_wid] "i"(VX_CSR_WARP_ID),
        [custom0] "i"(RISCV_CUSTOM0),
        [custom3] "i"(RISCV_CUSTOM3),
        [bar_tensor] "i"(WU_BARRIER_TENSOR),
        [domain_tensor] "i"(VX_BARRIER_DOMAIN_TENSOR),
        [domain_shift] "i"(VX_BARRIER_DOMAIN_SHIFT),
        [num_tensor] "i"(NUM_TENSOR_WARPS),
        [tensor_spin] "i"(WU_TENSOR_SPIN),
        [tensor_base] "i"(WU_STATUS_TENSOR_BASE)
      : "memory");
}
|
||||||
|
|
||||||
|
static void init_state() {
|
||||||
|
g_status[0] = 0;
|
||||||
|
for (uint32_t i = 0; i < MAX_WARPS; ++i) {
|
||||||
|
g_scalar_seen[i] = 0;
|
||||||
|
g_tensor_seen[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
volatile uint32_t *smem =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t i = 0; i < MINIMAL_INIT_WORDS; ++i)
|
||||||
|
smem[i] = 0x100u + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int wait_for_wu_completion() {
|
||||||
|
for (uint32_t spin = 0; spin < WU_WAIT_SPIN; ++spin) {
|
||||||
|
uint32_t done = 1;
|
||||||
|
for (uint32_t wid = 0; wid < NUM_SCALAR_WARPS; ++wid)
|
||||||
|
done &= (g_scalar_seen[wid] == (WU_STATUS_SCALAR_BASE | wid));
|
||||||
|
for (uint32_t wid = NUM_SCALAR_WARPS; wid < NUM_WARPS; ++wid)
|
||||||
|
done &= (g_tensor_seen[wid] == (WU_STATUS_TENSOR_BASE | wid));
|
||||||
|
if (done)
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0)
|
||||||
|
return 0;
|
||||||
|
if (vx_thread_id() != 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
init_state();
|
||||||
|
|
||||||
|
const uint32_t other_scalar_warps = vx_scalar_warp_mask() & ~1u;
|
||||||
|
if (other_scalar_warps != 0)
|
||||||
|
vx_spawn_scalar(other_scalar_warps, scalar_worker);
|
||||||
|
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||||
|
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||||
|
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
volatile uint32_t mix = 1;
|
||||||
|
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||||
|
mix = (mix << 1) ^ (i + 3u);
|
||||||
|
g_spin_sink[0] = mix;
|
||||||
|
g_scalar_seen[0] = WU_STATUS_SCALAR_BASE;
|
||||||
|
|
||||||
|
if (wait_for_wu_completion() != 0) {
|
||||||
|
g_status[0] = 0xe001u;
|
||||||
|
wu_report_tohost(1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_status[0] = WU_STATUS_DONE;
|
||||||
|
wu_report_tohost(0);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
58
kernels/wu_arch_cases/Makefile
Normal file
58
kernels/wu_arch_cases/Makefile
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
CASES := \
|
||||||
|
case00_boot_scalar \
|
||||||
|
case01_scalar_spawn \
|
||||||
|
case02_tensor_spawn_stop \
|
||||||
|
case03_dual_fetch_issue \
|
||||||
|
case04_scalar_barrier \
|
||||||
|
case05_tensor_barrier \
|
||||||
|
case06_masked_barrier \
|
||||||
|
case07_tensor_csr_tmc \
|
||||||
|
case08_tensor_lsu_optional \
|
||||||
|
case09_scalar_tmem_ldst \
|
||||||
|
case10_tensor_scalar_tmem_handoff \
|
||||||
|
case11_scalar_tmem_softmax_stage \
|
||||||
|
case12_flash_pv_accum \
|
||||||
|
case12_1_scalar_tmem_cb_probe \
|
||||||
|
case12_2_flash_pv_p_probe \
|
||||||
|
case12_3_scalar_tmem_lane_store \
|
||||||
|
case13_flash_pv_two_warps \
|
||||||
|
case14_flash_pv_k64 \
|
||||||
|
case15_flash_softmax_pv_stage \
|
||||||
|
case16_flash_full_pipeline \
|
||||||
|
case17_flash_exp_softmax_probe \
|
||||||
|
case18_scalar_fexp \
|
||||||
|
case20_flash_bwd_fused \
|
||||||
|
case21_moe_gating \
|
||||||
|
case22_gemm_silu \
|
||||||
|
case23_softmax_only \
|
||||||
|
case24_flash_sw_pipeline
|
||||||
|
|
||||||
|
SMOKE_CASES := \
|
||||||
|
case00_boot_scalar \
|
||||||
|
case01_scalar_spawn \
|
||||||
|
case02_tensor_spawn_stop \
|
||||||
|
case03_dual_fetch_issue
|
||||||
|
|
||||||
|
BARRIER_CASES := \
|
||||||
|
case04_scalar_barrier \
|
||||||
|
case05_tensor_barrier \
|
||||||
|
case06_masked_barrier
|
||||||
|
|
||||||
|
.PHONY: all smoke barriers full clean clean-all $(CASES)
|
||||||
|
|
||||||
|
all: full
|
||||||
|
|
||||||
|
smoke: $(SMOKE_CASES)
|
||||||
|
|
||||||
|
barriers: $(BARRIER_CASES)
|
||||||
|
|
||||||
|
full: $(CASES)
|
||||||
|
|
||||||
|
$(CASES):
|
||||||
|
$(MAKE) -C $@
|
||||||
|
|
||||||
|
clean:
|
||||||
|
set -e; for dir in $(CASES); do $(MAKE) -C $$dir clean; done
|
||||||
|
|
||||||
|
clean-all:
|
||||||
|
set -e; for dir in $(CASES); do $(MAKE) -C $$dir clean-all; done
|
||||||
72
kernels/wu_arch_cases/README.md
Normal file
72
kernels/wu_arch_cases/README.md
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Wu Architecture Staged Cases
|
||||||
|
|
||||||
|
This directory contains small bare-metal kernels for incremental Wu architecture testing. The original `kernels/wu_arch` kernel is useful as an integrated test, but it combines scalar spawning, tensor spawning, barriers, tensor control, and memory behavior in one large workload. These cases isolate those surfaces so failures can be reproduced faster under Verilator.
|
||||||
|
|
||||||
|
## Case List
|
||||||
|
|
||||||
|
- `case00_boot_scalar`: minimal scalar boot, status writes, and pass marker.
|
||||||
|
- `case01_scalar_spawn`: scalar warp spawning without tensor warps or barriers.
|
||||||
|
- `case02_tensor_spawn_stop`: tensor warp spawn, marker store, and stop.
|
||||||
|
- `case03_dual_fetch_issue`: scalar and tensor warps active together to exercise split scheduling and issue.
|
||||||
|
- `case04_scalar_barrier`: scalar-domain barrier release.
|
||||||
|
- `case05_tensor_barrier`: tensor-domain barrier through tensor control.
|
||||||
|
- `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps.
|
||||||
|
- `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior.
|
||||||
|
- `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower.
|
||||||
|
- `case09_scalar_tmem_ldst`: scalar warp direct TMEM store/load path for the banked TMEM softmax mechanism.
|
||||||
|
- `case10_tensor_scalar_tmem_handoff`: tensor BWGMMA result in TMEM C observed by scalar TMEM loads.
|
||||||
|
- `case11_scalar_tmem_softmax_stage`: scalar TMEM transform written back for tensor-side copy-out.
|
||||||
|
- `case12_flash_pv_accum`: one tensor warp consumes scalar-written `P` in TMEM A for `O = O + P @ V`.
|
||||||
|
- `case12_1_scalar_tmem_cb_probe`: scalar-written TMEM A rows copied back by tensor `tcgen05_cb`.
|
||||||
|
- `case12_2_flash_pv_p_probe`: case12 P-write diagnostic using the same scalar fill path.
|
||||||
|
- `case12_3_scalar_tmem_lane_store`: scalar TMEM store lane-coalesced fragment write diagnostic.
|
||||||
|
- `case13_flash_pv_two_warps`: both tensor warps consume scalar-written `P` tiles for two row blocks.
|
||||||
|
- `case14_flash_pv_k64`: two consecutive `K=32` BWGMMA steps accumulate into one PV output.
|
||||||
|
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
|
||||||
|
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
|
||||||
|
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
|
||||||
|
- `case18_scalar_fexp`: scalar `FEXP.S` numerical probe.
|
||||||
|
- `case20_flash_bwd_fused`: FlashAttention backward-style fused 5xMMA plus scalar softmax/dsoftmax handoff.
|
||||||
|
- `case21_moe_gating`: scalar `softmax -> Top-K -> scatter` MoE gating pipeline.
|
||||||
|
- `case22_gemm_silu`: tensor GEMM followed by scalar SiLU activation.
|
||||||
|
- `case23_softmax_only`: scalar-only stable softmax probe.
|
||||||
|
- `case24_flash_sw_pipeline`: four-iteration ping-pong FlashAttention-style software pipeline.
|
||||||
|
|
||||||
|
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.
|
||||||
|
|
||||||
|
## Debug Notes
|
||||||
|
|
||||||
|
- `TMC_DEBUG_NOTES.md`: reusable notes for diagnosing lane-mask/TMC operand bugs, including the `case12_3_scalar_tmem_lane_store` failure where a lane0-only register value was consumed under an all-lane mask.
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
Use the suite Makefile from this directory:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make smoke -j4 LLVM_VORTEX=/home/hexu/dse/wu/virgo-artifact-full/toolchain/llvm-vortex2 RISCV_TOOLCHAIN_PATH=/home/hexu/dse/wu/virgo-artifact-full/chipyard/.conda-env/riscv-tools RISCV_PREFIX=riscv64-unknown-elf
|
||||||
|
make barriers -j4 LLVM_VORTEX=/home/hexu/dse/wu/virgo-artifact-full/toolchain/llvm-vortex2 RISCV_TOOLCHAIN_PATH=/home/hexu/dse/wu/virgo-artifact-full/chipyard/.conda-env/riscv-tools RISCV_PREFIX=riscv64-unknown-elf
|
||||||
|
make full -j4 LLVM_VORTEX=/home/hexu/dse/wu/virgo-artifact-full/toolchain/llvm-vortex2 RISCV_TOOLCHAIN_PATH=/home/hexu/dse/wu/virgo-artifact-full/chipyard/.conda-env/riscv-tools RISCV_PREFIX=riscv64-unknown-elf
|
||||||
|
```
|
||||||
|
|
||||||
|
`smoke` builds the boot/spawn/dual-issue cases. `barriers` builds the barrier-focused cases. `full` builds all cases.
|
||||||
|
|
||||||
|
## Verilator Run Notes
|
||||||
|
|
||||||
|
For RTL simulation, use the same simulator setup as the main Virgo artifact, but run these ELFs one at a time:
|
||||||
|
|
||||||
|
- `VM_PARALLEL_BUILDS=1`
|
||||||
|
- `LOADMEM=1`, so `SimDRAM::memory_init()` preloads the ELF instead of relying on slow runtime SimTSI writes.
|
||||||
|
- `CCACHE_DIR=/tmp/ccache` when ccache is enabled in the sandbox.
|
||||||
|
- Use `/home/hexu/dse/firtool-1.62.0` for firtool and `/usr/local/bin/verilator` for Verilator.
|
||||||
|
- Keep system `gcc/g++` on `PATH`; do not use the `gcc/g++` injected by `chipyard/env.sh`.
|
||||||
|
- For generated Verilator C++ compilation, prefer `-O0 -fno-inline` to reduce compile time.
|
||||||
|
|
||||||
|
## Cleanup
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make clean-all
|
||||||
|
```
|
||||||
|
|
||||||
|
This removes kernel ELF/dump outputs and the generated placeholder input blobs in each case directory.
|
||||||
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
# TMC Operand Debug Notes
|
||||||
|
|
||||||
|
This note records a failure pattern found while debugging
|
||||||
|
`case12_3_scalar_tmem_lane_store`.
|
||||||
|
|
||||||
|
## Symptom
|
||||||
|
|
||||||
|
The simulator keeps producing trace output, but the kernel makes no forward
|
||||||
|
progress:
|
||||||
|
|
||||||
|
- a tensor worker repeatedly polls a ready flag such as `g_case_mem[0]`;
|
||||||
|
- the poll load always returns the old value;
|
||||||
|
- the scalar warp has already committed the work before the ready flag store;
|
||||||
|
- the expected ready flag store never appears in the LSU trace.
|
||||||
|
|
||||||
|
For `case12_3_scalar_tmem_lane_store`, the tensor worker was looping at
|
||||||
|
`PC=0x80000034..0x80000048`, repeatedly reading `g_case_mem[0]` at
|
||||||
|
`0x20000430` as `0`. The scalar warp committed the scalar TMEM store and the
|
||||||
|
following `TMC`, but never issued the next ready flag store at `PC=0x8000015c`.
|
||||||
|
|
||||||
|
## Root Cause Pattern
|
||||||
|
|
||||||
|
Vortex scalar registers are lane-local. A value computed while only lane 0 is
|
||||||
|
active is only known to be valid in lane 0. If that register is later consumed
|
||||||
|
by an instruction while multiple lanes are active, the inactive lanes may hold
|
||||||
|
stale or zero values.
|
||||||
|
|
||||||
|
This is especially easy to miss around `TMC`:
|
||||||
|
|
||||||
|
```asm
|
||||||
|
# Bad when xN was defined while only lane 0 was active.
|
||||||
|
vx_tmc all_lanes
|
||||||
|
...
|
||||||
|
vx_tmc xN
|
||||||
|
```
|
||||||
|
|
||||||
|
The failed `case12_3` sequence used a C operand for `1u` that the compiler
|
||||||
|
materialized before switching to all lanes. Runtime trace showed the source
|
||||||
|
register for the final `TMC(1)` as:
|
||||||
|
|
||||||
|
```text
|
||||||
|
rs1_data={0x0, 0x0, 0x0, 0x1}
|
||||||
|
```
|
||||||
|
|
||||||
|
After that `TMC`, the scalar warp did not fetch the ready flag store.
|
||||||
|
|
||||||
|
## Correct Pattern
|
||||||
|
|
||||||
|
When a value is consumed under an all-lane mask, define that value under the
|
||||||
|
same all-lane mask unless the instruction semantics explicitly use only lane 0.
|
||||||
|
|
||||||
|
For switching back to lane 0 from an all-lane region, keep the immediate
|
||||||
|
materialization and the `TMC` adjacent inside the all-lane region:
|
||||||
|
|
||||||
|
```asm
|
||||||
|
vx_tmc all_lanes
|
||||||
|
...
|
||||||
|
fence rw, rw
|
||||||
|
li t2, 1
|
||||||
|
vx_tmc t2
|
||||||
|
```
|
||||||
|
|
||||||
|
The expected trace at the final `TMC` is:
|
||||||
|
|
||||||
|
```text
|
||||||
|
rs1_data={0x1, 0x1, 0x1, 0x1}
|
||||||
|
```
|
||||||
|
|
||||||
|
The library helper `vx_tmc_one()` follows this pattern because it emits
|
||||||
|
`li a0, 1` and `vx_tmc a0` in the same volatile asm block. It is safe when
|
||||||
|
called while all lanes are active.
|
||||||
|
|
||||||
|
## Fast Log Checks
|
||||||
|
|
||||||
|
Use the simulation log to distinguish a simulator stall from a kernel-level
|
||||||
|
wait loop:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
stat -c '%s %y' chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||||
|
tail -n 80 chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||||
|
```
|
||||||
|
|
||||||
|
If the file is still growing but the tail repeatedly shows the same PC range,
|
||||||
|
the simulator is alive and the kernel is probably spinning.
|
||||||
|
|
||||||
|
For a ready flag handoff, check both the polling load and the producer store:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
rg -n "0x20000430|0x9900|wid=0, PC=0x8000014|wid=0, PC=0x8000015" \
|
||||||
|
chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||||
|
```
|
||||||
|
|
||||||
|
Interpretation:
|
||||||
|
|
||||||
|
- The load repeatedly returns `0` and no later `0x9900` store exists: the producer did not reach the ready flag store.
- The ready flag store exists but the consumer still reads old data: investigate memory ordering or address aliasing.
- The producer stops immediately after a `TMC`: inspect that `TMC` source register across all lanes.
|
||||||
|
|
||||||
|
## Dump Checks
|
||||||
|
|
||||||
|
Check that the final lane-mask narrowing defines `1` immediately before `TMC`
|
||||||
|
inside the all-lane region:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
rg -n "li\\s+.*1|vx_tmc" \
|
||||||
|
kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.radiance.dump
|
||||||
|
```
|
||||||
|
|
||||||
|
The fixed `case12_3` dump has:
|
||||||
|
|
||||||
|
```text
|
||||||
|
li t2, 1
|
||||||
|
vx_tmc t2
|
||||||
|
auipc a3, 1
|
||||||
|
sw a5, ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Case12-Case15 Audit
|
||||||
|
|
||||||
|
The following cases were checked after fixing `case12_3`:
|
||||||
|
|
||||||
|
- `case12_flash_pv_accum`
|
||||||
|
- `case12_2_flash_pv_p_probe`
|
||||||
|
- `case13_flash_pv_two_warps`
|
||||||
|
- `case14_flash_pv_k64`
|
||||||
|
- `case15_flash_softmax_pv_stage`
|
||||||
|
|
||||||
|
They switch to all lanes with `vx_tmc(wu_bw_all_lanes_mask())` and switch back
|
||||||
|
with `vx_tmc_one()`. Their dumps show the safe adjacent sequence:
|
||||||
|
|
||||||
|
```text
|
||||||
|
li a0, 1
|
||||||
|
vx_tmc a0
|
||||||
|
```
|
||||||
|
|
||||||
|
No analogous source change is needed for those cases.
|
||||||
|
|
||||||
|
## Scalar TMEM Fill Lane Coverage
|
||||||
|
|
||||||
|
`wu_bw_fill_tmem_tile()` must run with all fragment lanes active. A scalar TMEM store writes each active lane's source word into the matching TMEM fragment word:
|
||||||
|
|
||||||
|
```text
|
||||||
|
TMEM[addr].word[lane] = rs2_data[lane]
|
||||||
|
```
|
||||||
|
|
||||||
|
Therefore a fill loop executed with `tmask=0001` only initializes word 0 of
|
||||||
|
each 16-byte fragment. Word 1 and later can retain old TMEM data. In
|
||||||
|
`case12_1_scalar_tmem_cb_probe`, this showed up as a normal completion path
|
||||||
|
followed by verification failure `0x14`; `g_aux[0]` was `1`, and `g_aux[1]`
|
||||||
|
held the stale copied-back word.
|
||||||
|
|
||||||
|
The correct pattern is:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
```
|
||||||
|
|
||||||
|
The dump should show the fill loop bracketed by all-lane and lane-0 masks:
|
||||||
|
|
||||||
|
```text
|
||||||
|
li a5, 15
|
||||||
|
vx_tmc a5
|
||||||
|
...
|
||||||
|
vx_cmov zero, a0, a5, a2
|
||||||
|
...
|
||||||
|
li a0, 1
|
||||||
|
vx_tmc a0
|
||||||
|
```
|
||||||
19
kernels/wu_arch_cases/case.mk
Normal file
19
kernels/wu_arch_cases/case.mk
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Shared build fragment included by every case directory's Makefile.
# A case normally sets PROJECT before including this file; the values below
# are defaults that can still be overridden from the command line.
PROJECT ?= wu_arch_case
VX_SRCS = kernel.cpp
# Let kernels include headers that live one directory above the case.
VX_CFLAGS += -I..

VORTEX_KN_PATH       ?= $(realpath ../../../lib)
GEMMINI_SW_PATH      ?= $(realpath ../../../lib/gemmini)
LLVM_VORTEX          ?= $(realpath ../../../../toolchain/llvm-r8)
RISCV_TOOLCHAIN_PATH ?= $(realpath ../../../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
OPTS ?= -n1

# Placeholder input blobs; each one is a plain copy of ../zero.bin.
WU_CASE_INPUT_BLOBS := args.bin input.a.bin input.b.bin input.c.bin

include ../../common.mk

$(WU_CASE_INPUT_BLOBS): ../zero.bin
	cp $< $@

# Hook the blob removal into the clean-all target.
clean-all: clean-wu-case-inputs

.PHONY: clean-wu-case-inputs
clean-wu-case-inputs:
	rm -f $(WU_CASE_INPUT_BLOBS)
|
||||||
3
kernels/wu_arch_cases/case00_boot_scalar/Makefile
Normal file
3
kernels/wu_arch_cases/case00_boot_scalar/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case00_boot_scalar
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
15
kernels/wu_arch_cases/case00_boot_scalar/README.md
Normal file
15
kernels/wu_arch_cases/case00_boot_scalar/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Case 00: Boot Scalar
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify the minimal Wu bare-metal entry path: core 0, warp 0, thread 0 reaches `wu_main`, can write the shared status arrays, and can terminate with `WU_CASE_PASS`.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Minimal `_start` path from `common_wu_min.h`
|
||||||
|
- Scalar warp 0 fetch/decode/issue
|
||||||
|
- Scalar ALU/store path for status writes
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
`g_status[0] == WU_CASE_PASS` and `g_seen[0] == WU_CASE_SCALAR_BASE`.
|
||||||
13
kernels/wu_arch_cases/case00_boot_scalar/kernel.cpp
Normal file
13
kernels/wu_arch_cases/case00_boot_scalar/kernel.cpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
g_seen[0] = WU_CASE_SCALAR_BASE;
|
||||||
|
g_aux[0] = static_cast<uint32_t>(vx_num_warps());
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case01_scalar_spawn/Makefile
Normal file
3
kernels/wu_arch_cases/case01_scalar_spawn/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case01_scalar_spawn
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
16
kernels/wu_arch_cases/case01_scalar_spawn/README.md
Normal file
16
kernels/wu_arch_cases/case01_scalar_spawn/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Case 01: Scalar Spawn
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify scalar-domain warp spawning without tensor warps, barriers, or shared tensor resources.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Scalar scheduler output
|
||||||
|
- Scalar fetch/decode/issue path
|
||||||
|
- Scalar `WSPAWN` mask path
|
||||||
|
- Scalar store path for per-warp completion markers
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
Every scalar warp writes `WU_CASE_SCALAR_BASE | wid` to `g_seen[wid]`, and warp 0 writes `WU_CASE_PASS` to `g_status[0]`.
|
||||||
54
kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp
Normal file
54
kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#define WU_START_BRANCH_TO_MAIN 1
|
||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
extern "C" void scalar_worker();
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
wu_stop_warp();
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
const uint32_t spawn_mask = wu_scalar_mask_without_warp0();
|
||||||
|
if (spawn_mask != 0) {
|
||||||
|
vx_spawn_scalar(spawn_mask, scalar_worker);
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||||
|
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x01u);
|
||||||
|
wu_stop_warp();
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
wu_stop_warp();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void scalar_worker_body();
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, used)) scalar_worker() {
|
||||||
|
asm volatile(
|
||||||
|
".option push\n\t"
|
||||||
|
".option norelax\n\t"
|
||||||
|
"la gp, __global_pointer\n\t"
|
||||||
|
".option pop\n\t"
|
||||||
|
"li sp, %[stack_base]\n\t"
|
||||||
|
"csrr t0, %[csr_hart]\n\t"
|
||||||
|
"slli t1, t0, %[stack_log2]\n\t"
|
||||||
|
"slli t2, t0, 4\n\t"
|
||||||
|
"add t1, t1, t2\n\t"
|
||||||
|
"sub sp, sp, t1\n\t"
|
||||||
|
"j scalar_worker_body\n\t"
|
||||||
|
:
|
||||||
|
: [csr_hart] "i"(VX_CSR_MHARTID),
|
||||||
|
[stack_base] "i"(STACK_BASE_ADDR),
|
||||||
|
[stack_log2] "i"(STACK_LOG2_SIZE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Body run by each spawned scalar warp after the trampoline: delay by an
// amount derived from the warp id, write the scalar marker, then stop.
extern "C" void scalar_worker_body() {
  const auto warp_id = wu_wid();
  wu_short_delay(warp_id);

  wu_mark_seen(WU_CASE_SCALAR_BASE);
  wu_stop_warp();
}
|
||||||
3
kernels/wu_arch_cases/case02_tensor_spawn_stop/Makefile
Normal file
3
kernels/wu_arch_cases/case02_tensor_spawn_stop/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case02_tensor_spawn_stop
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
16
kernels/wu_arch_cases/case02_tensor_spawn_stop/README.md
Normal file
16
kernels/wu_arch_cases/case02_tensor_spawn_stop/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Case 02: Tensor Spawn Stop
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify tensor warps can be spawned, scheduled, issued, write a completion marker, and stop without barriers or tensor LSU stress.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Tensor scheduler output
|
||||||
|
- Tensor fetch/decode/issue path
|
||||||
|
- Tensor ALU/store for a minimal marker
|
||||||
|
- Tensor-domain stop via `TMC zero`
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
Every tensor warp writes `WU_CASE_TENSOR_BASE | wid` to `g_seen[wid]`, and warp 0 writes `WU_CASE_PASS` to `g_status[0]`.
|
||||||
37
kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp
Normal file
37
kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#define WU_CASE_WAIT_SPIN 1024u
|
||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"1: j 1b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[tensor_base] "i"(WU_CASE_TENSOR_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, WU_CASE_TENSOR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x02u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case03_dual_fetch_issue/Makefile
Normal file
3
kernels/wu_arch_cases/case03_dual_fetch_issue/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case03_dual_fetch_issue
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
16
kernels/wu_arch_cases/case03_dual_fetch_issue/README.md
Normal file
16
kernels/wu_arch_cases/case03_dual_fetch_issue/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Case 03: Dual Fetch Issue
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify scalar and tensor domains can be active together and both make forward progress through the shared fetch path and split issue paths.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Shared fetch arbitration between scalar and tensor schedule streams
|
||||||
|
- Scalar decode/issue domain
|
||||||
|
- Tensor decode/issue domain
|
||||||
|
- Completion markers from both warp classes
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
All scalar warps write `WU_CASE_SCALAR_BASE | wid`, all tensor warps write `WU_CASE_TENSOR_BASE | wid`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
59
kernels/wu_arch_cases/case03_dual_fetch_issue/kernel.cpp
Normal file
59
kernels/wu_arch_cases/case03_dual_fetch_issue/kernel.cpp
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
// Scalar-warp worker for case03: delay by an amount derived from the warp id,
// write the scalar marker, then stop the warp.
extern "C" void scalar_worker() {
  const auto warp_id = wu_wid();
  wu_short_delay(warp_id);

  wu_mark_seen(WU_CASE_SCALAR_BASE);
  wu_stop_warp();
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"li x6, %[spin]\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"addi x6, x6, -1\n\t"
|
||||||
|
"bnez x6, 1b\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"2: j 2b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[spin] "i"(WU_CASE_SHORT_SPIN),
|
||||||
|
[tensor_base] "i"(WU_CASE_TENSOR_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
const uint32_t scalar_mask = wu_scalar_mask_without_warp0();
|
||||||
|
if (scalar_mask != 0) {
|
||||||
|
vx_spawn_scalar(scalar_mask, scalar_worker);
|
||||||
|
}
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
wu_short_delay(0);
|
||||||
|
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x31u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, WU_CASE_TENSOR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x32u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case04_scalar_barrier/Makefile
Normal file
3
kernels/wu_arch_cases/case04_scalar_barrier/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case04_scalar_barrier
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
15
kernels/wu_arch_cases/case04_scalar_barrier/README.md
Normal file
15
kernels/wu_arch_cases/case04_scalar_barrier/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Case 04: Scalar Barrier
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify scalar-domain `BAR` synchronizes only scalar warps and releases them correctly.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Scalar WCTL barrier path
|
||||||
|
- Scheduler scalar barrier mask handling
|
||||||
|
- Scalar wakeup after barrier release
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
All scalar warps pass `vx_barrier_scalar`, write `WU_CASE_SCALAR_BASE | wid`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
33
kernels/wu_arch_cases/case04_scalar_barrier/kernel.cpp
Normal file
33
kernels/wu_arch_cases/case04_scalar_barrier/kernel.cpp
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
#define CASE04_BARRIER_ID 0u
|
||||||
|
|
||||||
|
// Scalar-warp worker for case04: join the scalar barrier shared with warp 0,
// then write the scalar marker and stop.
extern "C" void scalar_worker() {
  // Same barrier id and warp count as the call in wu_main.
  vx_barrier_scalar(CASE04_BARRIER_ID, NUM_SCALAR_WARPS);

  wu_mark_seen(WU_CASE_SCALAR_BASE);
  wu_stop_warp();
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
const uint32_t scalar_mask = wu_scalar_mask_without_warp0();
|
||||||
|
if (scalar_mask != 0) {
|
||||||
|
vx_spawn_scalar(scalar_mask, scalar_worker);
|
||||||
|
}
|
||||||
|
|
||||||
|
vx_barrier_scalar(CASE04_BARRIER_ID, NUM_SCALAR_WARPS);
|
||||||
|
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x04u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case05_tensor_barrier/Makefile
Normal file
3
kernels/wu_arch_cases/case05_tensor_barrier/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case05_tensor_barrier
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
15
kernels/wu_arch_cases/case05_tensor_barrier/README.md
Normal file
15
kernels/wu_arch_cases/case05_tensor_barrier/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Case 05: Tensor Barrier
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify tensor-domain `BAR` is handled by tensor control and releases tensor warps without relying on scalar SFU dispatch.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Tensor control barrier decode
|
||||||
|
- Tensor warp-control merge into scheduler
|
||||||
|
- Scheduler tensor-domain barrier mask handling
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
All tensor warps pass the tensor-domain barrier, write `WU_CASE_TENSOR_BASE | wid`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
45
kernels/wu_arch_cases/case05_tensor_barrier/kernel.cpp
Normal file
45
kernels/wu_arch_cases/case05_tensor_barrier/kernel.cpp
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
#define CASE05_BARRIER_ID 1u
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"li x1, (%[bar_id] | (%[domain_tensor] << %[domain_shift]))\n\t"
|
||||||
|
"li x2, %[num_tensor]\n\t"
|
||||||
|
".insn r %[custom0], 4, 0, x0, x1, x2\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"1: j 1b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[bar_id] "i"(CASE05_BARRIER_ID),
|
||||||
|
[domain_tensor] "i"(VX_BARRIER_DOMAIN_TENSOR),
|
||||||
|
[domain_shift] "i"(VX_BARRIER_DOMAIN_SHIFT),
|
||||||
|
[num_tensor] "i"(NUM_TENSOR_WARPS),
|
||||||
|
[tensor_base] "i"(WU_CASE_TENSOR_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, WU_CASE_TENSOR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x05u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case06_masked_barrier/Makefile
Normal file
3
kernels/wu_arch_cases/case06_masked_barrier/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case06_masked_barrier
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
16
kernels/wu_arch_cases/case06_masked_barrier/README.md
Normal file
16
kernels/wu_arch_cases/case06_masked_barrier/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Case 06: Masked Barrier
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify `BAR_MASK` can synchronize an explicit mixed mask containing scalar warp 0 and all tensor warps.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Scalar-side masked barrier issue
|
||||||
|
- Tensor-side masked barrier issue through tensor control
|
||||||
|
- Scheduler explicit barrier mask release
|
||||||
|
- Scalar/tensor warp-control merge when both domains participate in one barrier
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
Scalar warp 0 and all tensor warps pass the same `BAR_MASK`; tensor warps write `WU_CASE_TENSOR_BASE | wid`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
47
kernels/wu_arch_cases/case06_masked_barrier/kernel.cpp
Normal file
47
kernels/wu_arch_cases/case06_masked_barrier/kernel.cpp
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
#define CASE06_BARRIER_ID 2u
|
||||||
|
#define CASE06_BARRIER_MASK (vx_tensor_warp_mask() | 1u)
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"li x1, %[bar_id]\n\t"
|
||||||
|
"li x2, %[mask]\n\t"
|
||||||
|
".insn r %[custom0], 7, 0, x0, x1, x2\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"1: j 1b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[bar_id] "i"(CASE06_BARRIER_ID),
|
||||||
|
[mask] "i"(((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS | 1u),
|
||||||
|
[tensor_base] "i"(WU_CASE_TENSOR_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
vx_barrier_mask(CASE06_BARRIER_ID, CASE06_BARRIER_MASK);
|
||||||
|
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(vx_tensor_warp_mask(), WU_CASE_TENSOR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x06u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case07_tensor_csr_tmc/Makefile
Normal file
3
kernels/wu_arch_cases/case07_tensor_csr_tmc/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case07_tensor_csr_tmc
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
15
kernels/wu_arch_cases/case07_tensor_csr_tmc/README.md
Normal file
15
kernels/wu_arch_cases/case07_tensor_csr_tmc/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Case 07: Tensor CSR TMC
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify tensor control handles legal tensor-domain CSR reads and `TMC` operations without involving barrier behavior.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Tensor CSRRS path for `VX_CSR_WARP_ID`
|
||||||
|
- Tensor TMC path setting a single active lane
|
||||||
|
- Tensor control completion and tensor-domain stop
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
Every tensor warp writes `WU_CASE_TENSOR_CSR_BASE | wid` to `g_seen[wid]`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
38
kernels/wu_arch_cases/case07_tensor_csr_tmc/kernel.cpp
Normal file
38
kernels/wu_arch_cases/case07_tensor_csr_tmc/kernel.cpp
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"li x6, 1\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x6, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_csr_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"1: j 1b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[tensor_csr_base] "i"(WU_CASE_TENSOR_CSR_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, WU_CASE_TENSOR_CSR_BASE) != 0) {
|
||||||
|
wu_case_fail(0x07u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case08_tensor_lsu_optional
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
15
kernels/wu_arch_cases/case08_tensor_lsu_optional/README.md
Normal file
15
kernels/wu_arch_cases/case08_tensor_lsu_optional/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Case 08: Tensor LSU Optional
|
||||||
|
|
||||||
|
## Test Objective
|
||||||
|
|
||||||
|
Verify tensor-domain LSU can store and reload a small per-warp value. This is intentionally last because memory hierarchy interaction is slower and has a larger debug surface.
|
||||||
|
|
||||||
|
## RTL Surface Covered
|
||||||
|
|
||||||
|
- Tensor LSU dispatch
|
||||||
|
- Tensor LSU response/writeback path
|
||||||
|
- Shared memory hierarchy merge after tensor-domain issue
|
||||||
|
|
||||||
|
## Expected Result
|
||||||
|
|
||||||
|
Every tensor warp stores and reloads `WU_CASE_TENSOR_LSU_BASE | wid`, writes that value to `g_seen[wid]`, and warp 0 writes `WU_CASE_PASS`.
|
||||||
42
kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp
Normal file
42
kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_case_mem\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[tensor_lsu_base]\n\t"
|
||||||
|
"or x5, x6, x5\n\t"
|
||||||
|
"sw x5, 0(x7)\n\t"
|
||||||
|
"lw x5, 0(x7)\n\t"
|
||||||
|
"sub x6, x5, x6\n\t"
|
||||||
|
"slli x6, x6, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"sw x5, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"1: j 1b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[tensor_lsu_base] "i"(WU_CASE_TENSOR_LSU_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, WU_CASE_TENSOR_LSU_BASE) != 0) {
|
||||||
|
wu_case_fail(0x08u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case09_scalar_tmem_ldst
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# case09_scalar_tmem_ldst
|
||||||
|
|
||||||
|
Smoke test for scalar warp direct TMEM access.
|
||||||
|
|
||||||
|
The leader scalar lane writes a lane-distinct word vector to a TMEM row with
|
||||||
|
`CUSTOM1 func7=0x30 func3=1`, reads it back with `CUSTOM1 func7=0x30 func3=0`,
|
||||||
|
and reports pass only if the lane 0 value matches.
|
||||||
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#include "common_wu_min.h"
|
||||||
|
|
||||||
|
static inline uint32_t wu_scalar_tmem_ld(uint32_t addr) {
|
||||||
|
uint32_t value;
|
||||||
|
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||||
|
: [value] "=r"(value)
|
||||||
|
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr)
|
||||||
|
: "memory");
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_scalar_tmem_st(uint32_t addr, uint32_t value) {
|
||||||
|
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||||
|
:
|
||||||
|
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr), [value] "r"(value)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
const uint32_t addr = 0x24u;
|
||||||
|
const uint32_t expected = 0x5a090000u | wu_tid();
|
||||||
|
wu_scalar_tmem_st(addr, expected);
|
||||||
|
const uint32_t observed = wu_scalar_tmem_ld(addr);
|
||||||
|
|
||||||
|
if (observed != expected) {
|
||||||
|
g_aux[0] = observed;
|
||||||
|
wu_case_fail(0x09u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case10_tensor_scalar_tmem_handoff
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# case10_tensor_scalar_tmem_handoff
|
||||||
|
|
||||||
|
Validates the tensor-to-scalar TMEM handoff needed by FlashAttention.
|
||||||
|
|
||||||
|
The tensor warp initializes TMEM A/C, runs one BWGMMA, and leaves the result in
|
||||||
|
TMEM C. The scalar leader waits for the tensor warp and then reads several C
|
||||||
|
fragments through the scalar TMEM load instruction. Each lane is expected to
|
||||||
|
observe the HGEMM value `0x42820000`.
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
#include "../common_wu_min.h"
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define WU_CASE_TMEM_HANDOFF_BASE 0x7700u
|
||||||
|
#define WU_TMEM_TILE_BYTES 1024u
|
||||||
|
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||||
|
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||||
|
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||||
|
|
||||||
|
static_assert(NUM_TENSOR_WARPS >= 1, "case10 requires at least one tensor warp");
|
||||||
|
static_assert(NUM_THREADS == 4, "case10 expects the 4-lane Blackwell tensor core");
|
||||||
|
|
||||||
|
#define BW_REP2(x) x, x
|
||||||
|
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case10_a_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x3c003c00u)};
|
||||||
|
volatile uint32_t g_case10_b_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x40004000u)};
|
||||||
|
volatile uint32_t g_case10_c_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x3f800000u)};
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef BW_REP2
|
||||||
|
#undef BW_REP4
|
||||||
|
|
||||||
|
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||||
|
uint32_t value;
|
||||||
|
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||||
|
: [value] "=r"(value)
|
||||||
|
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||||
|
: "memory");
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case10_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, 1024\n\t"
|
||||||
|
"la x6, g_case10_a_row\n\t"
|
||||||
|
"la x3, g_case10_c_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[handoff_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"2: j 2b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||||
|
[handoff_base] "i"(WU_CASE_TMEM_HANDOFF_BASE),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[tile_bytes] "i"(WU_TMEM_TILE_BYTES)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||||
|
smem_b[i] = g_case10_b_row[i & 3u];
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case10_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE_TMEM_HANDOFF_BASE) != 0) {
|
||||||
|
wu_case_fail(0x10u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||||
|
for (uint32_t i = 0; i < 8; ++i) {
|
||||||
|
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base + i);
|
||||||
|
if (observed != WU_TMEM_EXPECTED_HGEMM) {
|
||||||
|
g_aux[0] = i;
|
||||||
|
g_aux[1] = observed;
|
||||||
|
wu_case_fail(0x11u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = wu_arch_case11_scalar_tmem_softmax_stage
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# case11_scalar_tmem_softmax_stage
|
||||||
|
|
||||||
|
Validates the FlashAttention softmax-stage TMEM path.
|
||||||
|
|
||||||
|
The tensor warp writes HGEMM results into TMEM C. The scalar warp reads TMEM C,
|
||||||
|
performs a deterministic lane-wise transform, and writes all four lanes back
|
||||||
|
into one TMEM A-region fragment. The tensor warp then uses TCGEN05_CB to copy
|
||||||
|
that TMEM fragment to global memory so scalar code can verify that tensor-side
|
||||||
|
TMEM reads observe the scalar writeback with the correct lane ordering and
|
||||||
|
write mask.
|
||||||
@@ -0,0 +1,200 @@
|
|||||||
|
#include "../common_wu_min.h"
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define WU_CASE_TMEM_STAGE_HGEMM_BASE 0x7800u
|
||||||
|
#define WU_CASE_TMEM_STAGE_VERIFY_BASE 0x7900u
|
||||||
|
#define WU_TMEM_TILE_BYTES 1024u
|
||||||
|
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||||
|
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||||
|
#define WU_TMEM_SCALAR_WRITE_BYTE_ADDR 512u
|
||||||
|
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||||
|
#define WU_TMEM_STAGE_VALUE_BASE 0x42820001u
|
||||||
|
|
||||||
|
static_assert(NUM_TENSOR_WARPS >= 1, "case11 requires at least one tensor warp");
|
||||||
|
static_assert(NUM_THREADS == 4, "case11 expects the 4-lane Blackwell tensor core");
|
||||||
|
|
||||||
|
#define BW_REP2(x) x, x
|
||||||
|
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case11_a_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x3c003c00u)};
|
||||||
|
volatile uint32_t g_case11_b_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x40004000u)};
|
||||||
|
volatile uint32_t g_case11_c_row[4] __attribute__((aligned(16))) = {
|
||||||
|
BW_REP4(0x3f800000u)};
|
||||||
|
volatile uint32_t g_case11_out[4] __attribute__((aligned(16)));
|
||||||
|
volatile uint32_t g_case11_scalar_seen[4] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef BW_REP2
|
||||||
|
#undef BW_REP4
|
||||||
|
|
||||||
|
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||||
|
uint32_t value;
|
||||||
|
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||||
|
: [value] "=r"(value)
|
||||||
|
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||||
|
: "memory");
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
|
||||||
|
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||||
|
:
|
||||||
|
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
|
||||||
|
[value] "r"(value)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case11_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, 1024\n\t"
|
||||||
|
"la x6, g_case11_a_row\n\t"
|
||||||
|
"la x3, g_case11_c_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[hgemm_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[verify_base]\n\t"
|
||||||
|
"bne x7, x4, 3b\n\t"
|
||||||
|
"addi x4, x1, %[write_byte_addr]\n\t"
|
||||||
|
"la x6, g_case11_out\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[verify_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"2: j 2b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||||
|
[hgemm_base] "i"(WU_CASE_TMEM_STAGE_HGEMM_BASE),
|
||||||
|
[verify_base] "i"(WU_CASE_TMEM_STAGE_VERIFY_BASE),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[tile_bytes] "i"(WU_TMEM_TILE_BYTES),
|
||||||
|
[write_byte_addr] "i"(WU_TMEM_SCALAR_WRITE_BYTE_ADDR)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < 4; ++i) {
|
||||||
|
g_case11_out[i] = 0;
|
||||||
|
g_case11_scalar_seen[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||||
|
smem_b[i] = g_case11_b_row[i & 3u];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case11_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||||
|
WU_CASE_TMEM_STAGE_HGEMM_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x20u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||||
|
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base);
|
||||||
|
g_case11_scalar_seen[tid] = observed;
|
||||||
|
|
||||||
|
const uint32_t write_frag =
|
||||||
|
WU_TMEM_SCALAR_WRITE_BYTE_ADDR / WU_TMEM_FRAGMENT_BYTES;
|
||||||
|
wu_scalar_tmem_st(write_frag, observed + 1u + tid);
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
if (g_case11_scalar_seen[0] != WU_TMEM_EXPECTED_HGEMM) {
|
||||||
|
g_aux[0] = 0;
|
||||||
|
g_aux[1] = g_case11_scalar_seen[0];
|
||||||
|
g_case_mem[1] = 0x21u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE_TMEM_STAGE_VERIFY_BASE;
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||||
|
WU_CASE_TMEM_STAGE_VERIFY_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x22u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
if (g_case11_out[0] != WU_TMEM_STAGE_VALUE_BASE) {
|
||||||
|
g_aux[0] = 0;
|
||||||
|
g_aux[1] = g_case11_out[0];
|
||||||
|
wu_case_fail(0x23u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case12_1_scalar_tmem_cb_probe
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# case12.1 scalar TMEM CB probe
|
||||||
|
|
||||||
|
Validates the scalar-to-tensor TMEM handoff without BWGMMA. Scalar warp 0 fills
|
||||||
|
tensor block 0 TMEM A rows `0..63` with packed fp16 `1.0`. Tensor warp
|
||||||
|
`NUM_SCALAR_WARPS` waits for that fill, then uses `tcgen05_cb` to copy the same
|
||||||
|
TMEM A rows back to global memory.
|
||||||
|
|
||||||
|
The scalar leader verifies all 256 copied words are `0x3c003c00`. A mismatch
|
||||||
|
means scalar TMEM stores are not visible to the tensor TMEM read/copy path at
|
||||||
|
the expected row addresses.
|
||||||
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE12_1_INIT_BASE 0x8d00u
|
||||||
|
#define WU_CASE12_1_DONE_BASE 0x8e00u
|
||||||
|
#define WU_CASE12_1_COPY_READY 0x8f00u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case12_1_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_1_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[copy_ready]\n\t"
|
||||||
|
"bne x7, x4, 1b\n\t"
|
||||||
|
"la x3, g_case12_1_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 2b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"3: j 3b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[init_base] "i"(WU_CASE12_1_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE12_1_DONE_BASE),
|
||||||
|
[copy_ready] "i"(WU_CASE12_1_COPY_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||||
|
g_case12_1_out[i] = 0;
|
||||||
|
}
|
||||||
|
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_1_worker);
|
||||||
|
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_INIT_BASE) !=
|
||||||
|
0) {
|
||||||
|
g_case_mem[1] = 0x12u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE12_1_COPY_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_DONE_BASE) !=
|
||||||
|
0) {
|
||||||
|
g_case_mem[1] = 0x13u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case12_1_out, WU_BW_OUT_WORDS,
|
||||||
|
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||||
|
if (bad != WU_BW_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x14u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case12_2_flash_pv_p_probe
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# case12_2_flash_pv_p_probe
|
||||||
|
|
||||||
|
Diagnostic for `case12_flash_pv_accum`.
|
||||||
|
|
||||||
|
Scalar warp 0 fills tensor block 0 TMEM A with packed fp16 `1.0`, using the
|
||||||
|
same `wu_bw_fill_tmem_tile()` path as case12. Tensor warp `NUM_SCALAR_WARPS`
|
||||||
|
then copies the full TMEM A tile back to global memory with `tcgen05_cb`.
|
||||||
|
|
||||||
|
The scalar leader verifies all copied words are `0x3c003c00`. A mismatch means
|
||||||
|
case12's scalar-written P tile is not reaching the tensor TMEM read/copy path as
|
||||||
|
expected.
|
||||||
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE12_2_INIT_BASE 0x9600u
|
||||||
|
#define WU_CASE12_2_DONE_BASE 0x9700u
|
||||||
|
#define WU_CASE12_2_COPY_READY 0x9800u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case12_2_out[WU_BW_OUT_WORDS]
|
||||||
|
__attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used))
|
||||||
|
tensor_case12_2_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[copy_ready]\n\t"
|
||||||
|
"bne x7, x4, 1b\n\t"
|
||||||
|
"la x3, g_case12_2_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 2b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"3: j 3b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[init_base] "i"(WU_CASE12_2_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE12_2_DONE_BASE),
|
||||||
|
[copy_ready] "i"(WU_CASE12_2_COPY_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||||
|
g_case12_2_out[i] = 0;
|
||||||
|
}
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case12_2_worker);
|
||||||
|
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_2_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x51u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE12_2_COPY_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(tensor_mask, WU_CASE12_2_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x52u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case12_2_out, WU_BW_OUT_WORDS,
|
||||||
|
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||||
|
if (bad != WU_BW_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x53u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case12_3_scalar_tmem_lane_store
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
# case12_3_scalar_tmem_lane_store
|
||||||
|
|
||||||
|
Validates scalar TMEM store lane-coalesced semantics.
|
||||||
|
|
||||||
|
One scalar warp writes a single 16-byte TMEM fragment. Each active scalar lane supplies one 32-bit word:
|
||||||
|
|
||||||
|
```text
|
||||||
|
word0 = lane0.rs2
|
||||||
|
word1 = lane1.rs2
|
||||||
|
word2 = lane2.rs2
|
||||||
|
word3 = lane3.rs2
|
||||||
|
```
|
||||||
|
|
||||||
|
The tensor copy-back path then copies that fragment to memory. This catches implementations that only write lane0 data or broadcast one scalar value.
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
// Value the scalar side writes to g_case_mem[0] to release the copy worker.
#define WU_CASE12_3_COPY_READY 0x9900u
// Tag (| wid) the copy worker writes to g_seen after the copy-back.
#define WU_CASE12_3_DONE_BASE 0x9a00u

extern "C" {
// Destination of the single copied-back fragment; one word per scalar lane.
volatile uint32_t g_case12_3_out[4]
    __attribute__((aligned(16)));
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used))
|
||||||
|
tensor_case12_3_copy_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[copy_ready]\n\t"
|
||||||
|
"bne x7, x4, 1b\n\t"
|
||||||
|
"la x6, g_case12_3_out\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x1, x6\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"2: j 2b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[copy_ready] "i"(WU_CASE12_3_COPY_READY),
|
||||||
|
[done_base] "i"(WU_CASE12_3_DONE_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < 4; ++i) {
|
||||||
|
g_case12_3_out[i] = 0;
|
||||||
|
}
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case12_3_copy_worker);
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
const uint32_t frag_addr =
|
||||||
|
wu_bw_tmem_a_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||||
|
asm volatile(
|
||||||
|
".insn r %[custom0], 0, 0, x0, %[all_lanes], x0\n\t"
|
||||||
|
"csrr t0, %[csr_tid]\n\t"
|
||||||
|
"lui t1, 0x5a120\n\t"
|
||||||
|
"or t1, t1, t0\n\t"
|
||||||
|
".insn r %[custom1], 1, 0x30, x0, %[addr], t1\n\t"
|
||||||
|
"fence rw, rw\n\t"
|
||||||
|
"li t2, 1\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, t2, x0"
|
||||||
|
:
|
||||||
|
: [custom0] "i"(RISCV_CUSTOM0), [custom1] "i"(RISCV_CUSTOM1),
|
||||||
|
[csr_tid] "i"(VX_CSR_THREAD_ID), [addr] "r"(frag_addr),
|
||||||
|
[all_lanes] "r"(wu_bw_all_lanes_mask())
|
||||||
|
: "t0", "t1", "t2", "memory");
|
||||||
|
|
||||||
|
g_case_mem[0] = WU_CASE12_3_COPY_READY;
|
||||||
|
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_3_DONE_BASE) != 0) {
|
||||||
|
wu_case_fail(0x51u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t lane = 0; lane < NUM_THREADS; ++lane) {
|
||||||
|
const uint32_t expected = 0x5a120000u | lane;
|
||||||
|
if (g_case12_3_out[lane] != expected) {
|
||||||
|
g_aux[0] = lane;
|
||||||
|
g_aux[1] = g_case12_3_out[lane];
|
||||||
|
wu_case_fail(0x52u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case12_flash_pv_accum
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# case12_flash_pv_accum
|
||||||
|
|
||||||
|
Validates the first Wu Blackwell FlashAttention PV substage.
|
||||||
|
|
||||||
|
Scalar warp 0 writes a full `16x32` fp16 `P` tile (every element `1.0`) into
TMEM A. One tensor warp uses BWGMMA with a `32x16` fp16 `V` tile (every element
`2.0`) in SMEM and a fp32 TMEM C tile preloaded with `3.0`. The expected result
is `O = 3.0 + P @ V = 3.0 + 32 * (1.0 * 2.0) = 67.0` for every output element.
|
||||||
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE12_INIT_BASE 0x8a00u
|
||||||
|
#define WU_CASE12_DONE_BASE 0x8b00u
|
||||||
|
#define WU_CASE12_P_READY 0x8c00u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case12_o_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
|
||||||
|
WU_BW_FP32_THREE};
|
||||||
|
volatile uint32_t g_case12_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, %[c_offset]\n\t"
|
||||||
|
"la x3, g_case12_o_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[p_ready]\n\t"
|
||||||
|
"bne x7, x4, 2b\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"la x3, g_case12_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 3b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"4: j 4b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
[init_base] "i"(WU_CASE12_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE12_DONE_BASE),
|
||||||
|
[p_ready] "i"(WU_CASE12_P_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||||
|
g_case12_out[i] = 0;
|
||||||
|
}
|
||||||
|
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_TWO_PACKED);
|
||||||
|
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_worker);
|
||||||
|
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x12u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE12_P_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x13u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case12_out, WU_BW_OUT_WORDS,
|
||||||
|
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||||
|
if (bad != WU_BW_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x14u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case13_flash_pv_two_warps
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# case13_flash_pv_two_warps
|
||||||
|
|
||||||
|
Validates the Wu Blackwell FlashAttention PV substage across both tensor warps.
|
||||||
|
|
||||||
|
Scalar warp 0 writes one `P` tile into each tensor warp's TMEM A partition.
|
||||||
|
Both tensor warps consume the same SMEM `V` tile, accumulate into their own TMEM
|
||||||
|
C tiles, copy out contiguous row blocks, and stop. Every fp32 output is expected
|
||||||
|
to be `67.0`.
|
||||||
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE13_INIT_BASE 0x8d00u
|
||||||
|
#define WU_CASE13_DONE_BASE 0x8e00u
|
||||||
|
#define WU_CASE13_P_READY 0x8f00u
|
||||||
|
#define WU_CASE13_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case13_o_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
|
||||||
|
WU_BW_FP32_THREE};
|
||||||
|
volatile uint32_t g_case13_out[WU_CASE13_OUT_WORDS]
|
||||||
|
__attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case13_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, %[c_offset]\n\t"
|
||||||
|
"la x3, g_case13_o_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[p_ready]\n\t"
|
||||||
|
"bne x7, x4, 2b\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x7, x6, 10\n\t"
|
||||||
|
"la x3, g_case13_out\n\t"
|
||||||
|
"add x3, x3, x7\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 3b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"4: j 4b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
[init_base] "i"(WU_CASE13_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE13_DONE_BASE),
|
||||||
|
[p_ready] "i"(WU_CASE13_P_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_CASE13_OUT_WORDS; ++i) {
|
||||||
|
g_case13_out[i] = 0;
|
||||||
|
}
|
||||||
|
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_TWO_PACKED);
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case13_worker);
|
||||||
|
if (wu_wait_seen_mask(tensor_mask, WU_CASE13_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x21u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE13_P_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(tensor_mask, WU_CASE13_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x22u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case13_out, WU_CASE13_OUT_WORDS,
|
||||||
|
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||||
|
if (bad != WU_CASE13_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x23u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case14_flash_pv_k64
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# case14_flash_pv_k64
|
||||||
|
|
||||||
|
Validates PV accumulation over two Blackwell `K=32` BWGMMA steps.
|
||||||
|
|
||||||
|
Both tensor warps run two consecutive BWGMMA operations against the same TMEM C
|
||||||
|
tile, modeling a `K=64` FlashAttention PV accumulation. With `P=1.0`,
|
||||||
|
`V=2.0`, and `O_init=5.0`, every fp32 output is expected to be `133.0`.
|
||||||
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE14_INIT_BASE 0x9000u
|
||||||
|
#define WU_CASE14_DONE_BASE 0x9100u
|
||||||
|
#define WU_CASE14_P_READY 0x9200u
|
||||||
|
#define WU_CASE14_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case14_o_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE};
|
||||||
|
volatile uint32_t g_case14_out[WU_CASE14_OUT_WORDS]
|
||||||
|
__attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case14_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, %[c_offset]\n\t"
|
||||||
|
"la x3, g_case14_o_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[p_ready]\n\t"
|
||||||
|
"bne x7, x4, 2b\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x7, x6, 10\n\t"
|
||||||
|
"la x3, g_case14_out\n\t"
|
||||||
|
"add x3, x3, x7\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 3b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"4: j 4b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
[init_base] "i"(WU_CASE14_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE14_DONE_BASE),
|
||||||
|
[p_ready] "i"(WU_CASE14_P_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_CASE14_OUT_WORDS; ++i) {
|
||||||
|
g_case14_out[i] = 0;
|
||||||
|
}
|
||||||
|
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_TWO_PACKED);
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case14_worker);
|
||||||
|
if (wu_wait_seen_mask(tensor_mask, WU_CASE14_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x31u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE14_P_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(tensor_mask, WU_CASE14_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x32u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case14_out, WU_CASE14_OUT_WORDS,
|
||||||
|
WU_BW_FP32_ONE_THIRTY_THREE, &bad_actual);
|
||||||
|
if (bad != WU_CASE14_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x33u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case15_flash_softmax_pv_stage
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# case15_flash_softmax_pv_stage
|
||||||
|
|
||||||
|
Validates the scalar softmax-stage handoff into the Blackwell PV GEMM.
|
||||||
|
|
||||||
|
The tensor warp initializes TMEM C with a score/O value of `1.0`. Scalar warp 0
|
||||||
|
reads that value through the scalar TMEM load path, writes a full fp16 `P=1.0`
|
||||||
|
tile into TMEM A, and releases the tensor warp. The tensor warp then runs PV
|
||||||
|
against `V=2.0`; every fp32 output is expected to be `65.0`.
|
||||||
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE15_INIT_BASE 0x9300u
|
||||||
|
#define WU_CASE15_DONE_BASE 0x9400u
|
||||||
|
#define WU_CASE15_P_READY 0x9500u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case15_score_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE};
|
||||||
|
volatile uint32_t g_case15_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||||
|
volatile uint32_t g_case15_scalar_seen[4] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case15_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, %[c_offset]\n\t"
|
||||||
|
"la x3, g_case15_score_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[p_ready]\n\t"
|
||||||
|
"bne x7, x4, 2b\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"la x3, g_case15_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 3b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"4: j 4b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
[init_base] "i"(WU_CASE15_INIT_BASE),
|
||||||
|
[done_base] "i"(WU_CASE15_DONE_BASE),
|
||||||
|
[p_ready] "i"(WU_CASE15_P_READY)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||||
|
g_case15_out[i] = 0;
|
||||||
|
}
|
||||||
|
for (uint32_t i = 0; i < 4; ++i) {
|
||||||
|
g_case15_scalar_seen[i] = 0;
|
||||||
|
}
|
||||||
|
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_TWO_PACKED);
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case15_worker);
|
||||||
|
if (wu_wait_seen_mask(tensor_mask, WU_CASE15_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x41u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
const uint32_t c_frag =
|
||||||
|
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||||
|
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||||
|
g_case15_scalar_seen[tid] = observed;
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||||
|
g_case15_scalar_seen[0] != WU_BW_FP32_ONE) {
|
||||||
|
g_aux[0] = 0;
|
||||||
|
g_aux[1] = g_case15_scalar_seen[0];
|
||||||
|
g_case_mem[1] = 0x42u;
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
g_case_mem[0] = WU_CASE15_P_READY;
|
||||||
|
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(tensor_mask, WU_CASE15_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x43u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case15_out, WU_BW_OUT_WORDS,
|
||||||
|
WU_BW_FP32_SIXTY_FIVE, &bad_actual);
|
||||||
|
if (bad != WU_BW_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x44u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
PROJECT = case16_flash_full_pipeline
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
16
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
16
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# case16_flash_full_pipeline
|
||||||
|
|
||||||
|
Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
|
||||||
|
Blackwell path.
|
||||||
|
|
||||||
|
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
|
||||||
|
and `O_init=0.0`, producing a constant fp32 score of `32.0`. The scalar warp
|
||||||
|
reads the score row through scalar TMEM loads, scans the row maximum and
|
||||||
|
normalization denominator with scalar-only `FEXP.S`, converts each probability
|
||||||
|
to packed fp16, writes `P` into TMEM A, refills SMEM with `V=2.0`, and releases
|
||||||
|
the tensor warp. The tensor warp reloads TMEM C with `O=0.0`, then computes
|
||||||
|
`O = P @ V`.
|
||||||
|
|
||||||
|
Every output word is expected to be fp32 `2.0`. This case covers the staged
|
||||||
|
`QK -> scalar softmax handoff -> PV` loop without using the legacy
|
||||||
|
`kernels/flash_attention` implementation.
|
||||||
259
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
259
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
#include "../common_wu_blackwell_fa.h"
|
||||||
|
|
||||||
|
#define WU_CASE16_INIT_BASE 0x9600u
|
||||||
|
#define WU_CASE16_P_READY 0x9700u
|
||||||
|
#define WU_CASE16_DONE_BASE 0x9800u
|
||||||
|
|
||||||
|
#define WU_BW_FP32_ZERO 0x00000000u
|
||||||
|
#define WU_BW_FP32_TWO 0x40000000u
|
||||||
|
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
|
||||||
|
#define WU_CASE16_ROW_N 32u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||||
|
WU_BW_FP16_ONE_PACKED};
|
||||||
|
volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
|
||||||
|
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
|
||||||
|
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||||
|
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
|
||||||
|
volatile uint32_t g_case16_p_bits[4] __attribute__((aligned(16)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reinterprets a 32-bit pattern as an fp32 value (no numeric conversion).
static inline float wu_case16_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
|
||||||
|
|
||||||
|
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t wu_case16_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||||
|
|
||||||
|
static inline uint16_t wu_case16_f32_to_f16_positive(float value) {
|
||||||
|
const uint32_t bits = wu_case16_f32_to_bits(value);
|
||||||
|
const uint32_t exp = (bits >> 23) & 0xffu;
|
||||||
|
uint32_t mant = bits & 0x7fffffu;
|
||||||
|
|
||||||
|
if (exp == 0 || value <= 0.0f) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (exp >= 143u) {
|
||||||
|
return 0x7c00u;
|
||||||
|
}
|
||||||
|
if (exp <= 112u) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t half_exp = exp - 112u;
|
||||||
|
mant += 0x1000u;
|
||||||
|
if (mant & 0x800000u) {
|
||||||
|
mant = 0;
|
||||||
|
++half_exp;
|
||||||
|
}
|
||||||
|
if (half_exp >= 31u) {
|
||||||
|
return 0x7c00u;
|
||||||
|
}
|
||||||
|
return static_cast<uint16_t>((half_exp << 10) | (mant >> 13));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline uint32_t wu_case16_pack_f16x2(float value) {
|
||||||
|
const uint32_t h = wu_case16_f32_to_f16_positive(value);
|
||||||
|
return h | (h << 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_case16_softmax_tmem_row_to_p(uint32_t score_frag_base,
|
||||||
|
uint32_t p_byte_base) {
|
||||||
|
float row_max = wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||||
|
for (uint32_t i = 1; i < WU_CASE16_ROW_N; ++i) {
|
||||||
|
const float score =
|
||||||
|
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||||
|
row_max = score > row_max ? score : row_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
float denom = 0.0f;
|
||||||
|
for (uint32_t i = 0; i < WU_CASE16_ROW_N; ++i) {
|
||||||
|
const float score =
|
||||||
|
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||||
|
denom += wu_fexp_s(score - row_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||||
|
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||||
|
const uint32_t row_idx = frag % WU_CASE16_ROW_N;
|
||||||
|
const float score =
|
||||||
|
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||||
|
const float p = wu_fexp_s(score - row_max) / denom;
|
||||||
|
if (frag == 0) {
|
||||||
|
g_case16_p_bits[wu_tid()] = wu_case16_f32_to_bits(p);
|
||||||
|
}
|
||||||
|
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case16_pack_f16x2(p));
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||||
|
"slli x1, x1, 11\n\t"
|
||||||
|
"addi x2, x1, %[c_offset]\n\t"
|
||||||
|
"la x3, g_case16_q_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
"la x3, g_case16_zero_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 2b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[init_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
"3:\n\t"
|
||||||
|
"la x6, g_case_mem\n\t"
|
||||||
|
"lw x7, 0(x6)\n\t"
|
||||||
|
"li x4, %[p_ready]\n\t"
|
||||||
|
"bne x7, x4, 3b\n\t"
|
||||||
|
"la x3, g_case16_zero_row\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"4:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 4b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"la x3, g_case16_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"5:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x6, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||||
|
"addi x7, x7, 16\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 5b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"6: j 6b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||||
|
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||||
|
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||||
|
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
[init_base] "i"(WU_CASE16_INIT_BASE),
|
||||||
|
[p_ready] "i"(WU_CASE16_P_READY),
|
||||||
|
[done_base] "i"(WU_CASE16_DONE_BASE)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for the case-16 check. Only core 0 / warp 0 runs the body; every
// other core or warp returns 0 at once. Within that warp, the thread whose
// wu_tid() is 0 performs setup, the handshake with the spawned worker and the
// final verification; all threads take part in the TMEM read and (while no
// error is recorded) the softmax step.
//
// Error codes are recorded in g_case_mem[1]:
//   0x61 - wu_wait_seen_mask() for WU_CASE16_INIT_BASE returned non-zero
//   0x62 - thread 0's TMEM read was not WU_BW_FP32_THIRTY_TWO
//   0x63 - wu_wait_seen_mask() for WU_CASE16_DONE_BASE returned non-zero
//   0x64 - g_case16_out did not verify against WU_BW_FP32_TWO
// On 0x62 / 0x64, g_aux[0] and g_aux[1] carry extra diagnostic values.
//
// Returns 1 from thread 0 when an error code was recorded (after calling
// wu_case_fail); returns 0 in every other case.
extern "C" int wu_main() {
|
||||||
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tid = wu_tid();
|
||||||
|
// Single-bit mask at bit index NUM_SCALAR_WARPS; passed to vx_spawn_tensor and
// to both wu_wait_seen_mask calls below.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||||
|
if (tid == 0) {
|
||||||
|
wu_case_reset();
|
||||||
|
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||||
|
g_case16_out[i] = 0;
|
||||||
|
}
|
||||||
|
// NOTE(review): the bound 4 is hard-coded; presumably it matches the lane
// count / array sizes declared elsewhere - confirm.
for (uint32_t i = 0; i < 4; ++i) {
|
||||||
|
g_case16_scalar_seen[i] = 0;
|
||||||
|
g_case16_p_bits[i] = 0;
|
||||||
|
}
|
||||||
|
// First shared-memory fill, using the WU_BW_FP16_ONE_PACKED pattern, done
// before the worker is spawned.
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_ONE_PACKED);
|
||||||
|
vx_spawn_tensor(tensor_mask, tensor_case16_worker);
|
||||||
|
// A non-zero result from the INIT wait is treated as failure 0x61.
if (wu_wait_seen_mask(tensor_mask, WU_CASE16_INIT_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x61u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
// Every thread reads the same fragment index (C byte base 0 divided by the
// fragment size) and records what it observed in its own slot.
const uint32_t c_frag =
|
||||||
|
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||||
|
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||||
|
g_case16_scalar_seen[tid] = observed;
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
// Only thread 0's observation is checked, and only if no earlier error exists.
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||||
|
g_case16_scalar_seen[0] != WU_BW_FP32_THIRTY_TWO) {
|
||||||
|
g_aux[0] = 0;
|
||||||
|
g_aux[1] = g_case16_scalar_seen[0];
|
||||||
|
g_case_mem[1] = 0x62u;
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
// Softmax step is skipped once any error code has been recorded. The thread
// mask is set to wu_bw_all_lanes_mask() for the call and reset with
// vx_tmc_one() afterwards.
if (g_case_mem[1] == 0) {
|
||||||
|
vx_tmc(wu_bw_all_lanes_mask());
|
||||||
|
wu_case16_softmax_tmem_row_to_p(c_frag, wu_bw_tmem_a_byte_base(0));
|
||||||
|
vx_tmc_one();
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
// Second shared-memory fill, now with the WU_BW_FP16_TWO_PACKED pattern.
wu_bw_fill_smem_tile(
|
||||||
|
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||||
|
WU_BW_FP16_TWO_PACKED);
|
||||||
|
}
|
||||||
|
asm volatile("fence rw, rw" ::: "memory");
|
||||||
|
// P_READY is written unconditionally, even when an error is already recorded.
// NOTE(review): presumably this releases a worker polling g_case_mem[0] so it
// does not spin forever - confirm against tensor_case16_worker.
g_case_mem[0] = WU_CASE16_P_READY;
|
||||||
|
// The DONE wait is only attempted when no error is pending (short-circuit).
if (g_case_mem[1] == 0 &&
|
||||||
|
wu_wait_seen_mask(tensor_mask, WU_CASE16_DONE_BASE) != 0) {
|
||||||
|
g_case_mem[1] = 0x63u;
|
||||||
|
}
|
||||||
|
if (g_case_mem[1] == 0) {
|
||||||
|
volatile uint32_t bad_actual = 0;
|
||||||
|
const uint32_t bad =
|
||||||
|
wu_bw_verify_constant(g_case16_out, WU_BW_OUT_WORDS,
|
||||||
|
WU_BW_FP32_TWO, &bad_actual);
|
||||||
|
// Any return other than WU_BW_OUT_WORDS is treated as a mismatch.
// NOTE(review): `bad` looks like the index of the first mismatching word and
// `bad_actual` its value - confirm against wu_bw_verify_constant.
if (bad != WU_BW_OUT_WORDS) {
|
||||||
|
g_aux[0] = bad;
|
||||||
|
g_aux[1] = bad_actual;
|
||||||
|
g_case_mem[1] = 0x64u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Report the recorded code (if any) and return 1 on failure.
if (g_case_mem[1] != 0) {
|
||||||
|
wu_case_fail(g_case_mem[1]);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
wu_case_pass();
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
# Build description for the case17_flash_exp_softmax_probe case: sets the
# project name and includes ../case.mk.
# NOTE(review): ../case.mk presumably supplies the shared build rules - confirm.
PROJECT = case17_flash_exp_softmax_probe
|
||||||
|
|
||||||
|
include ../case.mk
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user