feat: add flash pipeline kernel cases

This commit is contained in:
Zhongdi LUO
2026-07-02 07:24:59 +00:00
parent d6fbd447c3
commit f1aa1303d2
28 changed files with 1290 additions and 25 deletions

View File

@@ -19,7 +19,13 @@ CASES := \
case14_flash_pv_k64 \
case15_flash_softmax_pv_stage \
case16_flash_full_pipeline \
case17_flash_exp_softmax_probe
case17_flash_exp_softmax_probe \
case18_scalar_fexp \
case20_flash_bwd_fused \
case21_moe_gating \
case22_gemm_silu \
case23_softmax_only \
case24_flash_sw_pipeline
SMOKE_CASES := \
case00_boot_scalar \

View File

@@ -25,6 +25,12 @@ This directory contains small bare-metal kernels for incremental Wu architecture
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
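- `case18_scalar_fexp`: scalar `FEXP.S` numerical probe.
- `case19_tensor_fexp_illegal`: negative test; a tensor warp executing `FEXP.S` is expected to trip the illegal-instruction path (not part of the default `CASES` list).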
- `case20_flash_bwd_fused`: FlashAttention backward-style fused 5xMMA plus scalar softmax/dsoftmax handoff.
- `case21_moe_gating`: scalar `softmax -> Top-K -> scatter` MoE gating pipeline.
- `case22_gemm_silu`: tensor GEMM followed by scalar SiLU activation.
- `case23_softmax_only`: scalar-only stable softmax probe.
- `case24_flash_sw_pipeline`: four-iteration ping-pong FlashAttention-style software pipeline.
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.

View File

@@ -4,10 +4,12 @@ Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
Blackwell path.
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads
the score through scalar TMEM load, records it, writes the uniform softmax result
`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp.
The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`.
and `O_init=0.0`, producing a constant fp32 score of `32.0`. The scalar warp
reads the score row through scalar TMEM loads, scans the row maximum and
normalization denominator with scalar-only `FEXP.S`, converts each probability
to packed fp16, writes `P` into TMEM A, refills SMEM with `V=2.0`, and releases
the tensor warp. The tensor warp reloads TMEM C with `O=0.0`, then computes
`O = P @ V`.
Every output word is expected to be fp32 `2.0`. This case covers the staged
`QK -> scalar softmax handoff -> PV` loop without using the legacy

View File

@@ -4,10 +4,10 @@
#define WU_CASE16_P_READY 0x9700u
#define WU_CASE16_DONE_BASE 0x9800u
#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u
#define WU_BW_FP32_ZERO 0x00000000u
#define WU_BW_FP32_TWO 0x40000000u
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
#define WU_CASE16_ROW_N 32u
extern "C" {
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
@@ -17,6 +17,85 @@ volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
volatile uint32_t g_case16_p_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a raw 32-bit word as an fp32 value (bit copy, no numeric
// conversion).
static inline float wu_case16_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t wu_case16_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Converts a positive fp32 value to IEEE binary16 bits.
//
// Mapping:
//   - NaN                          -> quiet half NaN (0x7e00); previously NaN
//                                     was silently reported as +Inf;
//   - zero, negative, fp32-subnormal inputs -> +0 (0x0000);
//   - values >= 2^16 (incl. +Inf)  -> +Inf (0x7c00);
//   - values <  2^-14              -> +0 (half subnormals are flushed);
//   - everything else              -> normal half, mantissa rounded to 10
//                                     bits with ties rounded up.
static inline uint16_t wu_case16_f32_to_f16_positive(float value) {
  // The bit pattern is punned locally so the converter is self-contained.
  union {
    float f;
    uint32_t u;
  } pun = {value};
  const uint32_t bits = pun.u;
  const uint32_t exp = (bits >> 23) & 0xffu;
  uint32_t mant = bits & 0x7fffffu;

  if (exp == 0xffu && mant != 0) {
    return 0x7e00u;  // propagate NaN instead of turning it into +Inf
  }
  if (exp == 0 || value <= 0.0f) {
    return 0;
  }
  if (exp >= 143u) {
    return 0x7c00u;
  }
  if (exp <= 112u) {
    return 0;
  }

  // Rebias 127 -> 15, then round by adding half of the dropped 13 bits.
  uint32_t half_exp = exp - 112u;
  mant += 0x1000u;
  if (mant & 0x800000u) {
    // Rounding carried out of the mantissa: bump the exponent.
    mant = 0;
    ++half_exp;
  }
  if (half_exp >= 31u) {
    return 0x7c00u;
  }
  return (uint16_t)((half_exp << 10) | (mant >> 13));
}
// Converts `value` to fp16 and replicates it into both 16-bit halves of a
// 32-bit word.
static inline uint32_t wu_case16_pack_f16x2(float value) {
  const uint32_t half_bits = wu_case16_f32_to_f16_positive(value);
  return (half_bits << 16) | half_bits;
}
// Runs a numerically stable softmax over one score row held in TMEM and
// writes the probabilities back to TMEM as packed fp16 pairs.
//
// score_frag_base: fragment index of the first score word; WU_CASE16_ROW_N
//                  consecutive words are read through scalar TMEM loads.
// p_byte_base:     byte base of the destination tile; it is divided by
//                  WU_BW_TMEM_FRAGMENT_BYTES to obtain a fragment index.
//
// All WU_BW_TMEM_FRAGMENTS destination fragments are written; fragment `frag`
// receives the probability of row entry `frag % WU_CASE16_ROW_N`. The fp32
// probability of entry 0 is also recorded in g_case16_p_bits[lane] for
// host-side inspection. A full read/write fence is issued before returning.
static inline void wu_case16_softmax_tmem_row_to_p(uint32_t score_frag_base,
                                                   uint32_t p_byte_base) {
  // Pass 1: row maximum, so every exponent argument is <= 0.
  float row_max = wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
  for (uint32_t i = 1; i < WU_CASE16_ROW_N; ++i) {
    const float score =
        wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
    row_max = score > row_max ? score : row_max;
  }

  // Pass 2: normalization denominator.
  float denom = 0.0f;
  for (uint32_t i = 0; i < WU_CASE16_ROW_N; ++i) {
    const float score =
        wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
    denom += wu_fexp_s(score - row_max);
  }

  // Pass 3: normalize and store packed fp16 probabilities.
  const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
  const uint32_t p_bits_slots = static_cast<uint32_t>(
      sizeof(g_case16_p_bits) / sizeof(g_case16_p_bits[0]));
  for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
    const uint32_t row_idx = frag % WU_CASE16_ROW_N;
    const float score =
        wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
    const float p = wu_fexp_s(score - row_max) / denom;
    if (frag == 0) {
      // The debug array has a fixed size while this runs with every lane
      // enabled; lanes beyond the array must not write past its end.
      const uint32_t lane = wu_tid();
      if (lane < p_bits_slots) {
        g_case16_p_bits[lane] = wu_case16_f32_to_bits(p);
      }
    }
    wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case16_pack_f16x2(p));
  }
  asm volatile("fence rw, rw" ::: "memory");
}
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
@@ -114,6 +193,7 @@ extern "C" int wu_main() {
}
for (uint32_t i = 0; i < 4; ++i) {
g_case16_scalar_seen[i] = 0;
g_case16_p_bits[i] = 0;
}
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
@@ -141,8 +221,7 @@ extern "C" int wu_main() {
if (g_case_mem[1] == 0) {
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0),
WU_BW_FP16_ONE_OVER_32_PACKED);
wu_case16_softmax_tmem_row_to_p(c_frag, wu_bw_tmem_a_byte_base(0));
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");

View File

@@ -1,9 +1,8 @@
# case17_flash_exp_softmax_probe
Validates that the current scalar Wu path can execute the `e^x` work needed by
non-uniform FlashAttention softmax. The current ISA/RTL does not expose a
dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to
approximate exp.
Validates that the scalar Wu path can execute the `e^x` work needed by
non-uniform FlashAttention softmax through the custom scalar-only `FEXP.S`
instruction.
The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically
stable softmax computes `exp(score - row_max)`, so the exp inputs are
@@ -12,5 +11,5 @@ inputs are loaded from volatile memory so the compiler cannot fold the result
into constants.
This is intentionally separate from the tensor PV path. If this case fails, the
problem is in scalar fp32 arithmetic, exp approximation, or normalization rather
than TMEM handoff or BWGMMA.
problem is in scalar fp32 `FEXP.S` execution or normalization rather than TMEM
handoff or BWGMMA.

View File

@@ -26,15 +26,6 @@ static inline float wu_case17_absf(float value) {
return value < 0.0f ? -value : value;
}
static inline float wu_case17_exp_neg_ln2_to_0(float x) {
const float x2 = x * x;
const float x3 = x2 * x;
const float x4 = x2 * x2;
const float x5 = x4 * x;
return 1.0f + x + (0.5f * x2) + (0.1666666716f * x3) +
(0.0416666679f * x4) + (0.0083333338f * x5);
}
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
@@ -51,8 +42,8 @@ extern "C" int wu_main() {
const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]);
const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]);
const float row_max = score0 > score1 ? score0 : score1;
const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max);
const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max);
const float e0 = wu_fexp_s(score0 - row_max);
const float e1 = wu_fexp_s(score1 - row_max);
const float inv_sum = 1.0f / (e0 + e1);
const float p0 = e0 * inv_sum;
const float p1 = e1 * inv_sum;

View File

@@ -0,0 +1,3 @@
PROJECT = case18_scalar_fexp
include ../case.mk

View File

@@ -0,0 +1,5 @@
# case18_scalar_fexp
Verifies scalar-warp execution of the custom `FEXP.S` instruction. The test
checks representative fp32 inputs used by softmax-style code paths and confirms
the result is close to `expf`.

View File

@@ -0,0 +1,68 @@
#include "../common_wu_min.h"
extern "C" {
// Raw fp32 bit patterns of the four FEXP.S results, cleared and then written
// by thread 0 in wu_main; copied to g_aux when the case fails.
volatile uint32_t g_case18_out_bits[4] __attribute__((aligned(16)));
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Absolute value for fp32 without pulling in libm.
static inline float absf_local(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// case18 entry point: evaluates FEXP.S on four fixed inputs and compares the
// results against reference values of e^x.
//
// Only core 0 / warp 0 takes part. Thread 0 owns all bookkeeping: it resets
// the case state, publishes the result bit patterns, and reports pass/fail
// (failure code 0x18, with the four result words copied into g_aux).
// Returns 1 on thread 0 when the check fails, 0 otherwise.
extern "C" int wu_main() {
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }

  const uint32_t lane = wu_tid();
  if (lane == 0) {
    wu_case_reset();
    for (uint32_t slot = 0; slot < 4; ++slot) {
      g_case18_out_bits[slot] = 0;
    }
  }
  asm volatile("fence rw, rw" ::: "memory");

  // Probe points: 0, 1, -ln(2) and -10.
  const float inputs[4] = {0.0f, 1.0f, -0.6931471805599453f, -10.0f};
  float results[4];
  for (uint32_t slot = 0; slot < 4; ++slot) {
    results[slot] = wu_fexp_s(inputs[slot]);
  }

  if (lane == 0) {
    for (uint32_t slot = 0; slot < 4; ++slot) {
      g_case18_out_bits[slot] = f32_to_bits(results[slot]);
    }
  }
  asm volatile("fence rw, rw" ::: "memory");

  if (lane != 0) {
    return 0;
  }

  // Reference e^x values and the absolute error allowed for each probe.
  const float expected[4] = {1.0f, 2.7182817459f, 0.5f, 0.00004539993f};
  const float tolerance[4] = {0.00001f, 0.0002f, 0.00001f, 0.000001f};
  bool out_of_tolerance = false;
  for (uint32_t slot = 0; slot < 4; ++slot) {
    if (absf_local(results[slot] - expected[slot]) > tolerance[slot]) {
      out_of_tolerance = true;
    }
  }

  if (out_of_tolerance) {
    for (uint32_t slot = 0; slot < 4; ++slot) {
      g_aux[slot] = g_case18_out_bits[slot];
    }
    wu_case_fail(0x18u);
    return 1;
  }
  wu_case_pass();
  return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case19_tensor_fexp_illegal
include ../case.mk

View File

@@ -0,0 +1,5 @@
# case19_tensor_fexp_illegal
Negative test for `FEXP.S`: tensor warps must not execute this scalar FPU
instruction. Running this case is expected to trip the existing tensor-FPU
illegal-instruction path in decode/dispatch rather than complete normally.

View File

@@ -0,0 +1,25 @@
#include "../common_wu_min.h"
// Tensor-warp entry point for the negative FEXP.S test. The function is naked
// and never returns: it issues one custom1 instruction on a tensor warp and
// then parks in an infinite loop if execution continues.
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
asm volatile(
// f1 = the bit pattern of integer zero, i.e. +0.0f.
"fmv.w.x f1, x0\n\t"
// custom1 R-type, funct3=2, funct7=0x30, rd=f2, rs1=f1. Presumably the
// FEXP.S encoding this case is named after -- TODO confirm against the ISA
// definition.
".insn r %[custom1], 2, 0x30, f2, f1, x0\n\t"
// Only reached if the instruction above did not stop the warp.
// NOTE(review): custom0 funct3=0 with x0 operands looks like a thread-mask
// write of zero (warp halt) -- confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
// Fallback park loop.
"1: j 1b\n\t"
:
: [custom0] "i"(RISCV_CUSTOM0),
[custom1] "i"(RISCV_CUSTOM1)
: "memory");
}
// case19 entry point: the leader resets the case state, launches
// tensor_worker on the tensor warps, and then unconditionally records failure
// code 0x19 and returns 1. Non-leader threads return 0 immediately.
// NOTE(review): per the case README the run is expected to be stopped by the
// illegal-instruction path rather than complete normally -- confirm how the
// harness orders that against the wu_case_fail call below.
extern "C" int wu_main() {
  if (wu_is_leader()) {
    wu_case_reset();
    vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
    wu_case_fail(0x19u);
    return 1;
  }
  return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case20_flash_bwd_fused
include ../case.mk

View File

@@ -0,0 +1,19 @@
# case20_flash_bwd_fused
FlashAttention backward-style fused pipeline smoke test.
The tensor warp performs one score MMA, then waits for the scalar warp to run
softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back
to TMEM A using signed fp16 values. The tensor warp then performs four more
MMA steps, for five MMA operations total in this case.
This case verifies:
- tensor warp MMA sequencing around a scalar TMEM handoff;
- scalar-only `FEXP.S` use for stable softmax;
- dsoftmax shape `dS = P * (dP - sum(P * dP))`;
- signed scalar TMEM stores feeding later tensor MMA operations.
The input score row is uniform, so `P = 1/32`. The synthetic upstream gradient
uses two buckets, `dP = 0` for row entries 0..15 and `dP = 2` for row entries
16..31, so `sum(P * dP) = 1` and the exact dS values are `-1/32` for row
entries 0..15 and `+1/32` for row entries 16..31.

View File

@@ -0,0 +1,289 @@
#include "../common_wu_blackwell_fa.h"
#define WU_CASE20_SCORE_READY 0xa000u
#define WU_CASE20_DSOFTMAX_READY 0xa100u
#define WU_CASE20_DONE_BASE 0xa200u
#define WU_CASE20_ROW_N 32u
#define WU_CASE20_FP32_ZERO 0x00000000u
#define WU_CASE20_FP32_THIRTY_TWO 0x42000000u
extern "C" {
// One 16-byte row of WU_BW_FP16_ONE_PACKED words; the tensor worker copies it
// repeatedly to fill the tile at its slot base (the Q operand, per the README).
volatile uint32_t g_case20_q_row[4] __attribute__((aligned(16))) = {
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
WU_BW_FP16_ONE_PACKED};
// One 16-byte row of fp32 zeros; the tensor worker uses it to clear the C tile
// before the score MMA and again before the post-handoff MMAs.
volatile uint32_t g_case20_zero_row[4] __attribute__((aligned(16))) = {
WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO,
WU_CASE20_FP32_ZERO};
// Destination of the tensor worker's final C-tile dump; wu_main pre-fills it
// with 0xffffffff and later expects every word to be fp32 zero.
volatile uint32_t g_case20_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
// [0] holds the first score word observed by thread 0; the other entries are
// only cleared.
volatile uint32_t g_case20_score_bits[4] __attribute__((aligned(16)));
// [0] / [1] hold the fp32 dS values computed for row entries 0 and 16.
volatile uint32_t g_case20_dsoftmax_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a raw 32-bit word as an fp32 value (bit copy, no numeric
// conversion).
static inline float wu_case20_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t wu_case20_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Absolute value for fp32 without pulling in libm.
static inline float wu_case20_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// Converts an fp32 value to signed IEEE binary16 bits.
//
// Mapping (the fp32 sign bit is always carried over):
//   - NaN                          -> quiet half NaN (sign | 0x7e00);
//                                     previously NaN was reported as Inf;
//   - zero and fp32 subnormals     -> signed zero;
//   - |value| >= 2^16 (incl. Inf)  -> signed Inf (sign | 0x7c00);
//   - |value| <  2^-14             -> signed zero (half subnormals flushed);
//   - everything else              -> normal half, mantissa rounded to 10
//                                     bits with ties rounded away from zero.
static inline uint16_t wu_case20_f32_to_f16(float value) {
  // The bit pattern is punned locally so the converter is self-contained.
  union {
    float f;
    uint32_t u;
  } pun = {value};
  const uint32_t bits = pun.u;
  const uint32_t sign = (bits >> 16) & 0x8000u;
  const uint32_t exp = (bits >> 23) & 0xffu;
  uint32_t mant = bits & 0x7fffffu;

  if (exp == 0xffu && mant != 0) {
    return (uint16_t)(sign | 0x7e00u);  // propagate NaN instead of Inf
  }
  if ((bits & 0x7fffffffu) == 0 || exp == 0) {
    return (uint16_t)sign;
  }
  if (exp >= 143u) {
    return (uint16_t)(sign | 0x7c00u);
  }
  if (exp <= 112u) {
    return (uint16_t)sign;
  }

  // Rebias 127 -> 15, then round by adding half of the dropped 13 bits.
  uint32_t half_exp = exp - 112u;
  mant += 0x1000u;
  if (mant & 0x800000u) {
    // Rounding carried out of the mantissa: bump the exponent.
    mant = 0;
    ++half_exp;
  }
  if (half_exp >= 31u) {
    return (uint16_t)(sign | 0x7c00u);
  }
  return (uint16_t)(sign | (half_exp << 10) | (mant >> 13));
}
// Converts `value` to signed fp16 and replicates it into both 16-bit halves
// of a 32-bit word.
static inline uint32_t wu_case20_pack_f16x2(float value) {
  const uint32_t half_bits = wu_case20_f32_to_f16(value);
  return (half_bits << 16) | half_bits;
}
// Synthetic upstream gradient dP for one row entry: 0.0 for entries 0..15,
// 2.0 for every entry from 16 upwards.
static inline float wu_case20_dp(uint32_t row_idx) {
  if (row_idx < 16u) {
    return 0.0f;
  }
  return 2.0f;
}
// Computes softmax followed by dsoftmax for one score row held in TMEM and
// writes dS = P * (dP - sum(P * dP)) back to TMEM as packed signed fp16.
//
// score_frag_base: fragment index of the first score word; WU_CASE20_ROW_N
//                  consecutive words are read through scalar TMEM loads.
// ds_byte_base:    byte base of the destination tile; it is divided by
//                  WU_BW_TMEM_FRAGMENT_BYTES to obtain a fragment index.
//
// All WU_BW_TMEM_FRAGMENTS destination fragments are written; fragment `frag`
// receives dS of row entry `frag % WU_CASE20_ROW_N`. Thread 0 additionally
// records the fp32 dS of entries 0 and 16 in g_case20_dsoftmax_bits[0..1].
// A full read/write fence is issued before returning.
static inline void wu_case20_dsoftmax_tmem_row(uint32_t score_frag_base,
                                               uint32_t ds_byte_base) {
  // Pass 1: row maximum, so every exponent argument is <= 0.
  float peak = wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
  for (uint32_t col = 1; col < WU_CASE20_ROW_N; ++col) {
    const float s =
        wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + col));
    if (s > peak) {
      peak = s;
    }
  }

  // Pass 2: softmax denominator.
  float exp_sum = 0.0f;
  for (uint32_t col = 0; col < WU_CASE20_ROW_N; ++col) {
    const float s =
        wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + col));
    exp_sum += wu_fexp_s(s - peak);
  }

  // Pass 3: sum(P * dP), the correction term shared by the whole row.
  float p_dot_dp = 0.0f;
  for (uint32_t col = 0; col < WU_CASE20_ROW_N; ++col) {
    const float s =
        wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + col));
    const float prob = wu_fexp_s(s - peak) / exp_sum;
    p_dot_dp += prob * wu_case20_dp(col);
  }

  // Pass 4: per-fragment dS, stored as packed fp16.
  const uint32_t out_frag_base = ds_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
  for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
    const uint32_t col = frag % WU_CASE20_ROW_N;
    const float s =
        wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + col));
    const float prob = wu_fexp_s(s - peak) / exp_sum;
    const float grad = prob * (wu_case20_dp(col) - p_dot_dp);
    // Keep one sample from each dP bucket for the host-side check.
    if (wu_tid() == 0 && col == 0) {
      g_case20_dsoftmax_bits[0] = wu_case20_f32_to_bits(grad);
    }
    if (wu_tid() == 0 && col == 16u) {
      g_case20_dsoftmax_bits[1] = wu_case20_f32_to_bits(grad);
    }
    wu_bw_scalar_tmem_st(out_frag_base + frag, wu_case20_pack_f16x2(grad));
  }
  asm volatile("fence rw, rw" ::: "memory");
}
// Tensor-warp side of case20. Naked, never returns; x1..x7 (including ra/sp)
// are used as plain scratch registers.
//
// Sequence: fill the slot's first tile from g_case20_q_row, clear the C tile,
// issue one MMA, publish SCORE_READY in g_seen[warp], spin until
// g_case_mem[0] == DSOFTMAX_READY, clear C again, issue four more MMAs, dump
// the C tile to g_case20_out, publish DONE in g_seen[warp], then park.
//
// NOTE(review): the meaning of the custom3 funct3 values is inferred from how
// they are used here (2: 16-byte memory->TMEM row load, 3: fence/drain,
// 0: MMA issue with rd=C base, rs1=A base, rs2=SMEM base, 1: MMA wait,
// 6: 16-byte TMEM->memory store) -- confirm against the ISA definition.
extern "C" void __attribute__((naked, noinline, used)) tensor_case20_worker() {
asm volatile(
// x1 = (warp_id - NUM_SCALAR_WARPS) << 11: per-slot byte base (2 KiB stride).
// x2 = x1 + WU_BW_TMEM_C_BYTE_OFFSET: byte base of the slot's C tile.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Loop 1: copy g_case20_q_row into every 16-byte row of the tile at x1.
"la x3, g_case20_q_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// Loop 2: copy g_case20_zero_row into every 16-byte row of the C tile at x2.
"la x3, g_case20_zero_row\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
// First (score) MMA: funct3=3, then funct3=0 with x2/x1/SMEM base, then
// funct3=1.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = WU_CASE20_SCORE_READY | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[score_ready]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until the scalar warp writes WU_CASE20_DSOFTMAX_READY to g_case_mem[0].
"3:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[dsoftmax_ready]\n\t"
"bne x7, x4, 3b\n\t"
// Loop 4: clear the C tile again before the post-handoff MMAs.
"la x3, g_case20_zero_row\n\t"
"li x7, 0\n\t"
"4:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 4b\n\t"
// Four back-to-back funct3=0 / funct3=1 pairs with identical operands
// (MMA operations 2..5 of the case).
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Loop 5: funct3=6 per 16-byte row, C tile (x2 + offset) -> g_case20_out + offset.
"la x3, g_case20_out\n\t"
"li x7, 0\n\t"
"5:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 5b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = WU_CASE20_DONE_BASE | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// custom0 funct3=0 with x0 operands (presumably halts the warp -- confirm),
// followed by a fallback park loop.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"6: j 6b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[score_ready] "i"(WU_CASE20_SCORE_READY),
[dsoftmax_ready] "i"(WU_CASE20_DSOFTMAX_READY),
[done_base] "i"(WU_CASE20_DONE_BASE)
: "memory");
}
// Scalar-side driver for case20; only core 0 / warp 0 takes part.
//
// Thread 0 resets state, fills SMEM, spawns the tensor worker and waits for
// the score; the dsoftmax row is then computed with all lanes enabled; thread
// 0 checks the recorded dS samples, releases the tensor worker, waits for
// completion and verifies that every output word is fp32 zero.
//
// The first error is latched in g_case_mem[1]:
//   0x41 score wait failed, 0x42 score word is not fp32 32.0,
//   0x43 dS sample out of tolerance, 0x44 done wait failed,
//   0x45 output word mismatch.
// Returns 1 on thread 0 when an error was latched, 0 otherwise.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// Bit NUM_SCALAR_WARPS selects a single warp -- presumably the first tensor
// warp; confirm against the warp numbering.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
if (tid == 0) {
wu_case_reset();
// Poison the output so untouched words cannot pass the zero check.
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case20_out[i] = 0xffffffffu;
}
for (uint32_t i = 0; i < 4; ++i) {
g_case20_score_bits[i] = 0;
g_case20_dsoftmax_bits[i] = 0;
}
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_ONE_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case20_worker);
// A non-zero return is treated as failure (presumably a timeout -- confirm).
if (wu_wait_seen_mask(tensor_mask, WU_CASE20_SCORE_READY) != 0) {
g_case_mem[1] = 0x41u;
}
}
asm volatile("fence rw, rw" ::: "memory");
const uint32_t c_frag =
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
// Every thread reads the first score word; only thread 0 records and checks it.
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
if (tid == 0) {
g_case20_score_bits[0] = observed;
if (g_case_mem[1] == 0 && observed != WU_CASE20_FP32_THIRTY_TWO) {
g_aux[0] = observed;
g_case_mem[1] = 0x42u;
}
}
asm volatile("fence rw, rw" ::: "memory");
// Skip the scalar stage entirely once an error has been latched.
if (g_case_mem[1] == 0) {
vx_tmc(wu_bw_all_lanes_mask());
wu_case20_dsoftmax_tmem_row(c_frag, wu_bw_tmem_a_byte_base(0));
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
if (g_case_mem[1] == 0) {
// Expected samples: -1/32 for row entry 0, +1/32 for row entry 16.
const float neg = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[0]);
const float pos = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[1]);
if (wu_case20_absf(neg + 0.03125f) > 0.0002f ||
wu_case20_absf(pos - 0.03125f) > 0.0002f) {
g_aux[0] = g_case20_dsoftmax_bits[0];
g_aux[1] = g_case20_dsoftmax_bits[1];
g_case_mem[1] = 0x43u;
}
}
asm volatile("fence rw, rw" ::: "memory");
// The tensor worker is released even after an error so it does not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE20_DSOFTMAX_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE20_DONE_BASE) != 0) {
g_case_mem[1] = 0x44u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any result other than WU_BW_OUT_WORDS is treated as a mismatch
// (presumably the index of the first bad word -- confirm).
const uint32_t bad =
wu_bw_verify_constant(g_case20_out, WU_BW_OUT_WORDS,
WU_CASE20_FP32_ZERO, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x45u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case21_moe_gating
include ../case.mk

View File

@@ -0,0 +1,10 @@
# case21_moe_gating
MoE gating scalar pipeline test.
This case runs `softmax -> Top-K -> scatter` on scalar warp 0 using `FEXP.S`.
The logits are `log(1), log(2), log(4), log(8)`, so the expected probabilities
are `1/15, 2/15, 4/15, 8/15`. Top-2 should select experts 3 and 2, then scatter
the token id and weight into the selected expert slots.
No tensor warp is spawned in this case.

View File

@@ -0,0 +1,128 @@
#include "../common_wu_min.h"
#define WU_CASE21_TOKEN_ID 0x21u
#define WU_CASE21_EMPTY 0xffffffffu
extern "C" {
// Gating logits as fp32 bit patterns: ln(1), ln(2), ln(4), ln(8). Kept in
// volatile memory so the compiler cannot fold the softmax into constants.
// ln(8) = 3 * ln(2) rounds to 0x40051592; the previous 0x40051d8f encoded
// ~2.07993 and skewed the reference distribution away from 1/15..8/15.
volatile uint32_t g_case21_logits_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f317218u, 0x3fb17218u, 0x40051592u};
// Softmax probabilities as fp32 bit patterns, published by thread 0.
volatile uint32_t g_case21_prob_bits[4] __attribute__((aligned(16)));
// Top-2 expert indices: [0] is the largest probability, [1] the runner-up.
// Both start as WU_CASE21_EMPTY.
volatile uint32_t g_case21_top_idx[2] __attribute__((aligned(16)));
// Per-expert token slot: WU_CASE21_TOKEN_ID for the two selected experts,
// WU_CASE21_EMPTY for the rest.
volatile uint32_t g_case21_expert_token[4] __attribute__((aligned(16)));
// Per-expert weight slot: probability bits for the two selected experts,
// 0 for the rest.
volatile uint32_t g_case21_expert_weight_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a raw 32-bit word as an fp32 value (bit copy, no numeric
// conversion).
static inline float wu_case21_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t wu_case21_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Absolute value for fp32 without pulling in libm.
static inline float wu_case21_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// case21 entry point: softmax -> Top-2 -> scatter on core 0 / warp 0.
//
// Every thread of the warp computes the softmax and the Top-2 selection;
// thread 0 alone resets state, publishes the results, performs the scatter
// and reports pass/fail.
//
// Failure codes (thread 0 returns 1):
//   0x21 a probability is outside the tolerance (g_aux: index, prob bits);
//   0x22 wrong Top-2 indices or token scatter (g_aux: top0, top1, tokens);
//   0x23 scattered weights do not match the published probabilities or an
//        unselected weight slot was touched (g_aux: the four weight words).
// Returns 0 otherwise.
extern "C" int wu_main() {
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }

  const uint32_t tid = wu_tid();
  if (tid == 0) {
    wu_case_reset();
    for (uint32_t i = 0; i < 4; ++i) {
      g_case21_prob_bits[i] = 0;
      g_case21_expert_token[i] = WU_CASE21_EMPTY;
      g_case21_expert_weight_bits[i] = 0;
    }
    g_case21_top_idx[0] = WU_CASE21_EMPTY;
    g_case21_top_idx[1] = WU_CASE21_EMPTY;
  }
  asm volatile("fence rw, rw" ::: "memory");

  // Stable softmax: subtract the row maximum before exponentiating.
  float logits[4];
  float row_max = wu_case21_bits_to_f32(g_case21_logits_bits[0]);
  for (uint32_t i = 0; i < 4; ++i) {
    logits[i] = wu_case21_bits_to_f32(g_case21_logits_bits[i]);
    row_max = logits[i] > row_max ? logits[i] : row_max;
  }
  float exp_values[4];
  float denom = 0.0f;
  for (uint32_t i = 0; i < 4; ++i) {
    exp_values[i] = wu_fexp_s(logits[i] - row_max);
    denom += exp_values[i];
  }
  float probs[4];
  for (uint32_t i = 0; i < 4; ++i) {
    probs[i] = exp_values[i] / denom;
  }

  // Top-2 by a single pass: top0 is the best so far, top1 the runner-up.
  uint32_t top0 = 0;
  uint32_t top1 = 1;
  if (probs[top1] > probs[top0]) {
    const uint32_t tmp = top0;
    top0 = top1;
    top1 = tmp;
  }
  for (uint32_t i = 2; i < 4; ++i) {
    if (probs[i] > probs[top0]) {
      top1 = top0;
      top0 = i;
    } else if (probs[i] > probs[top1]) {
      top1 = i;
    }
  }

  // Publish and scatter (thread 0 only).
  if (tid == 0) {
    for (uint32_t i = 0; i < 4; ++i) {
      g_case21_prob_bits[i] = wu_case21_f32_to_bits(probs[i]);
    }
    g_case21_top_idx[0] = top0;
    g_case21_top_idx[1] = top1;
    g_case21_expert_token[top0] = WU_CASE21_TOKEN_ID;
    g_case21_expert_token[top1] = WU_CASE21_TOKEN_ID;
    g_case21_expert_weight_bits[top0] = wu_case21_f32_to_bits(probs[top0]);
    g_case21_expert_weight_bits[top1] = wu_case21_f32_to_bits(probs[top1]);
  }
  asm volatile("fence rw, rw" ::: "memory");

  if (tid == 0) {
    // Reference distribution 1/15, 2/15, 4/15, 8/15.
    const float expected[4] = {0.0666666701f, 0.1333333403f,
                               0.2666666806f, 0.5333333611f};
    const float tolerance = 0.0015f;
    for (uint32_t i = 0; i < 4; ++i) {
      if (wu_case21_absf(probs[i] - expected[i]) > tolerance) {
        g_aux[0] = i;
        g_aux[1] = g_case21_prob_bits[i];
        wu_case_fail(0x21u);
        return 1;
      }
    }
    if (top0 != 3u || top1 != 2u ||
        g_case21_expert_token[3] != WU_CASE21_TOKEN_ID ||
        g_case21_expert_token[2] != WU_CASE21_TOKEN_ID ||
        g_case21_expert_token[0] != WU_CASE21_EMPTY ||
        g_case21_expert_token[1] != WU_CASE21_EMPTY) {
      g_aux[0] = top0;
      g_aux[1] = top1;
      g_aux[2] = g_case21_expert_token[3];
      g_aux[3] = g_case21_expert_token[2];
      wu_case_fail(0x22u);
      return 1;
    }
    // The weight half of the scatter was previously written but never
    // checked. Selected slots must carry the published probability bits and
    // unselected slots must still hold their reset value.
    if (g_case21_expert_weight_bits[3] != g_case21_prob_bits[3] ||
        g_case21_expert_weight_bits[2] != g_case21_prob_bits[2] ||
        g_case21_expert_weight_bits[0] != 0 ||
        g_case21_expert_weight_bits[1] != 0) {
      for (uint32_t i = 0; i < 4; ++i) {
        g_aux[i] = g_case21_expert_weight_bits[i];
      }
      wu_case_fail(0x23u);
      return 1;
    }
    wu_case_pass();
  }
  return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case22_gemm_silu
include ../case.mk

View File

@@ -0,0 +1,10 @@
# case22_gemm_silu
GEMM plus SiLU fusion smoke test.
The tensor warp computes a compact GEMM with fp16 `A = 0.125` and fp16 `B = 1`,
producing fp32 `C = 4`. Scalar warp 0 reads TMEM C and applies
`SiLU(x) = x / (1 + exp(-x))` using scalar-only `FEXP.S`.
This case verifies the common `matmul -> nonlinear activation` fusion path
without allowing tensor warp FPU execution.

View File

@@ -0,0 +1,160 @@
#include "../common_wu_blackwell_fa.h"
#define WU_CASE22_INIT_BASE 0xa400u
#define WU_CASE22_DONE_BASE 0xa500u
#define WU_CASE22_FP16_ONE_EIGHTH_PACKED 0x30003000u
#define WU_CASE22_FP32_FOUR 0x40800000u
extern "C" {
// One 16-byte row of packed fp16 0x3000 words (0.125 per the case README);
// the tensor worker copies it repeatedly to fill the tile at its slot base.
volatile uint32_t g_case22_a_row[4] __attribute__((aligned(16))) = {
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED,
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED};
// One 16-byte row of zeros used by the tensor worker to clear the C tile.
volatile uint32_t g_case22_zero_row[4] __attribute__((aligned(16))) = {
0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u};
// Destination of the tensor worker's C-tile dump; every word is expected to
// be fp32 4.0.
volatile uint32_t g_case22_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
// [0] holds the fp32 SiLU result recorded by thread 0; the other entries are
// only cleared.
volatile uint32_t g_case22_silu_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a raw 32-bit word as an fp32 value (bit copy, no numeric
// conversion).
static inline float wu_case22_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t wu_case22_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Absolute value for fp32 without pulling in libm.
static inline float wu_case22_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// Tensor-warp side of case22. Naked, never returns; x1..x7 (including ra/sp)
// are used as plain scratch registers.
//
// Sequence: fill the slot's first tile from g_case22_a_row, clear the C tile,
// issue one MMA, dump the C tile to g_case22_out, publish DONE in
// g_seen[warp], then park.
//
// NOTE(review): the meaning of the custom3 funct3 values is inferred from how
// they are used here (2: 16-byte memory->TMEM row load, 3: fence/drain,
// 0: MMA issue with rd=C base, rs1=A base, rs2=SMEM base, 1: MMA wait,
// 6: 16-byte TMEM->memory store) -- confirm against the ISA definition.
extern "C" void __attribute__((naked, noinline, used)) tensor_case22_worker() {
asm volatile(
// x1 = (warp_id - NUM_SCALAR_WARPS) << 11: per-slot byte base (2 KiB stride).
// x2 = x1 + WU_BW_TMEM_C_BYTE_OFFSET: byte base of the slot's C tile.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Loop 1: copy g_case22_a_row into every 16-byte row of the tile at x1.
"la x3, g_case22_a_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// Loop 2: copy g_case22_zero_row into every 16-byte row of the C tile at x2.
"la x3, g_case22_zero_row\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
// Single MMA: funct3=3, then funct3=0 with x2/x1/SMEM base, then funct3=1.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Loop 3: funct3=6 per 16-byte row, C tile (x2 + offset) -> g_case22_out + offset.
"la x3, g_case22_out\n\t"
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = WU_CASE22_DONE_BASE | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// custom0 funct3=0 with x0 operands (presumably halts the warp -- confirm),
// followed by a fallback park loop.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"4: j 4b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[done_base] "i"(WU_CASE22_DONE_BASE)
: "memory");
}
// Scalar-side driver for case22; only core 0 / warp 0 takes part.
//
// Thread 0 resets state, fills SMEM, spawns the tensor worker and waits for
// it to finish. Every thread then reads the first C word from TMEM and
// evaluates SiLU(x) = x / (1 + exp(-x)) with FEXP.S; thread 0 records the
// result and runs the checks.
//
// The first error is latched in g_case_mem[1]:
//   0x51 done wait failed, 0x52 C word is not fp32 4.0,
//   0x53 SiLU result out of tolerance, 0x54 output word mismatch.
// Returns 1 on thread 0 when an error was latched, 0 otherwise.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// Bit NUM_SCALAR_WARPS selects a single warp -- presumably the first tensor
// warp; confirm against the warp numbering.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case22_out[i] = 0;
}
for (uint32_t i = 0; i < 4; ++i) {
g_case22_silu_bits[i] = 0;
}
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_ONE_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case22_worker);
// A non-zero return is treated as failure (presumably a timeout -- confirm).
if (wu_wait_seen_mask(tensor_mask, WU_CASE22_DONE_BASE) != 0) {
g_case_mem[1] = 0x51u;
}
}
asm volatile("fence rw, rw" ::: "memory");
const uint32_t c_frag =
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
// The TMEM read and the SiLU evaluation run unconditionally, even when the
// wait above already latched an error.
const uint32_t observed_bits = wu_bw_scalar_tmem_ld(c_frag);
const float observed = wu_case22_bits_to_f32(observed_bits);
const float silu = observed / (1.0f + wu_fexp_s(-observed));
if (tid == 0) {
g_case22_silu_bits[0] = wu_case22_f32_to_bits(silu);
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
if (g_case_mem[1] == 0 && observed_bits != WU_CASE22_FP32_FOUR) {
g_aux[0] = observed_bits;
g_case_mem[1] = 0x52u;
}
if (g_case_mem[1] == 0) {
// Reference value of SiLU(4) = 4 / (1 + e^-4).
const float expected = 3.9280550480f;
if (wu_case22_absf(silu - expected) > 0.004f) {
g_aux[0] = g_case22_silu_bits[0];
g_case_mem[1] = 0x53u;
}
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any result other than WU_BW_OUT_WORDS is treated as a mismatch
// (presumably the index of the first bad word -- confirm).
const uint32_t bad =
wu_bw_verify_constant(g_case22_out, WU_BW_OUT_WORDS,
WU_CASE22_FP32_FOUR, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x54u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case23_softmax_only
include ../case.mk

View File

@@ -0,0 +1,9 @@
# case23_softmax_only
Scalar softmax-only test.
This case runs a stable 4-way softmax on scalar warp 0 using `FEXP.S`. The logits
are `log(1), log(3), log(5), log(7)`, giving expected probabilities
`1/16, 3/16, 5/16, 7/16`.
No tensor warp is spawned in this case.

View File

@@ -0,0 +1,83 @@
#include "../common_wu_min.h"
extern "C" {
// Softmax inputs as fp32 bit patterns: ln(1), ln(3), ln(5), ln(7). Kept in
// volatile memory so the compiler cannot fold the softmax into constants.
// ln(5) rounds to 0x3fce0210; the previous 0x3fcdf854 encoded ~1.60914 and
// skewed the reference distribution away from 1/16, 3/16, 5/16, 7/16.
volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f8c9f54u, 0x3fce0210u, 0x3ff91395u};
// Softmax probabilities as fp32 bit patterns, cleared and then published by
// thread 0 in wu_main.
volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a raw 32-bit word as an fp32 value (bit copy, no numeric
// conversion).
static inline float wu_case23_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
// Returns the raw 32-bit encoding of an fp32 value (bit copy, no numeric
// conversion).
static inline uint32_t wu_case23_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
// Absolute value for fp32 without pulling in libm.
static inline float wu_case23_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// case23 entry point: stable 4-way softmax on core 0 / warp 0 using FEXP.S.
//
// Every thread of the warp computes the softmax; thread 0 alone resets the
// case state, publishes the probabilities and reports pass/fail. On a
// mismatch it stores the failing index and its probability bits in g_aux,
// reports failure code 0x23 and returns 1. Returns 0 otherwise.
extern "C" int wu_main() {
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }

  const uint32_t lane = wu_tid();
  if (lane == 0) {
    wu_case_reset();
    for (uint32_t k = 0; k < 4; ++k) {
      g_case23_out_bits[k] = 0;
    }
  }
  asm volatile("fence rw, rw" ::: "memory");

  // Load the scores and track the row maximum in the same pass.
  float logits[4];
  float peak = wu_case23_bits_to_f32(g_case23_scores_bits[0]);
  for (uint32_t k = 0; k < 4; ++k) {
    logits[k] = wu_case23_bits_to_f32(g_case23_scores_bits[k]);
    if (logits[k] > peak) {
      peak = logits[k];
    }
  }

  // Exponentiate relative to the maximum, then normalize in place.
  float probs[4];
  float total = 0.0f;
  for (uint32_t k = 0; k < 4; ++k) {
    probs[k] = wu_fexp_s(logits[k] - peak);
    total += probs[k];
  }
  for (uint32_t k = 0; k < 4; ++k) {
    probs[k] = probs[k] / total;
  }

  if (lane == 0) {
    for (uint32_t k = 0; k < 4; ++k) {
      g_case23_out_bits[k] = wu_case23_f32_to_bits(probs[k]);
    }
  }
  asm volatile("fence rw, rw" ::: "memory");

  if (lane != 0) {
    return 0;
  }

  // Reference distribution 1/16, 3/16, 5/16, 7/16.
  const float reference[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f};
  const float max_abs_err = 0.0015f;
  for (uint32_t k = 0; k < 4; ++k) {
    if (wu_case23_absf(probs[k] - reference[k]) > max_abs_err) {
      g_aux[0] = k;
      g_aux[1] = g_case23_out_bits[k];
      wu_case_fail(0x23u);
      return 1;
    }
  }
  wu_case_pass();
  return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case24_flash_sw_pipeline
include ../case.mk

View File

@@ -0,0 +1,24 @@
# case24_flash_sw_pipeline
Software-pipelined FlashAttention-style multi-iteration case.
This case keeps `case16_flash_full_pipeline` as the single-tile
producer/consumer baseline and adds a four-iteration ping-pong pipeline:
```text
tensor warp 2 / slot 0: QK(0) -> wait P(0) -> PV(0) -> QK(2) -> wait P(2) -> PV(2)
tensor warp 3 / slot 1: QK(1) -> wait P(1) -> PV(1) -> QK(3) -> wait P(3) -> PV(3)
scalar warp 0: softmax(0) -> softmax(1) -> softmax(2) -> softmax(3)
```
Each tensor warp owns one TMEM slot. The tensor warp writes `S = Q @ K` into
TMEM C for its slot, marks `score_ready[iter]`, waits for scalar-generated
`P`, then computes `O = P @ V`. Scalar warp 0 waits on each score in order,
uses scalar-only `FEXP.S` for stable softmax, writes packed fp16 probabilities
back to the same slot's TMEM A, and marks `p_ready[iter]`.
The first version intentionally uses constant `Q`, `K`, and `V` so the expected
numeric result is simple: every score is fp32 `32.0`, every softmax row is
uniform `1/32`, and every output word is fp32 `1.0`. The test objective is the
multi-iteration overlap structure and per-slot handoff, not non-uniform
FlashAttention numerics.

View File

@@ -0,0 +1,318 @@
// Upper bound on polls performed by wu_case24_wait_status before it reports a
// timeout. Defined ahead of the include; presumably this overrides a default
// in the shared header -- TODO confirm.
#define WU_CASE_WAIT_SPIN 16384u
#include "../common_wu_blackwell_fa.h"
// Number of pipeline iterations; the tensor worker advances its iteration
// index by 2, so iterations alternate between two workers.
#define WU_CASE24_ITER_N 4u
// Number of score words reduced per softmax row.
#define WU_CASE24_ROW_N 32u
// Handshake tags: a status word is considered set when it equals (base | iter).
#define WU_CASE24_SCORE_READY_BASE 0xb000u
#define WU_CASE24_P_READY_BASE 0xb100u
#define WU_CASE24_DONE_BASE 0xb200u
// IEEE-754 binary32 bit patterns for 0.0 and 32.0.
#define WU_CASE24_FP32_ZERO 0x00000000u
#define WU_CASE24_FP32_THIRTY_TWO 0x42000000u
// Total output words: one WU_BW_OUT_WORDS-sized region per iteration.
#define WU_CASE24_OUT_WORDS (WU_CASE24_ITER_N * WU_BW_OUT_WORDS)
extern "C" {
// 16-byte source row of packed fp16 ones; the tensor worker copies it
// repeatedly to fill its TMEM A tile before the first MMA of an iteration.
volatile uint32_t g_case24_q_row[4] __attribute__((aligned(16))) = {
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
WU_BW_FP16_ONE_PACKED};
// 16-byte source row of fp32 zeros; used by the tensor worker to clear its
// TMEM C tile before each MMA.
volatile uint32_t g_case24_zero_row[4] __attribute__((aligned(16))) = {
WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO,
WU_CASE24_FP32_ZERO};
// Written by the tensor worker after the first MMA of an iteration; polled by
// wu_main.
volatile uint32_t g_case24_score_ready[WU_CASE24_ITER_N]
__attribute__((aligned(16)));
// Written by wu_main (thread 0) after the softmax for an iteration; polled by
// the tensor worker.
volatile uint32_t g_case24_p_ready[WU_CASE24_ITER_N]
__attribute__((aligned(16)));
// Written by the tensor worker after it has stored an iteration's output;
// polled by wu_main.
volatile uint32_t g_case24_done[WU_CASE24_ITER_N] __attribute__((aligned(16)));
// Debug capture: first TMEM C word observed by each thread, per iteration.
volatile uint32_t g_case24_score_bits[WU_CASE24_ITER_N * NUM_THREADS]
__attribute__((aligned(16)));
// Debug capture: fp32 bits of the probability computed for fragment 0, per
// thread and iteration.
volatile uint32_t g_case24_p_bits[WU_CASE24_ITER_N * NUM_THREADS]
__attribute__((aligned(16)));
// Set to 1 by wu_main when score_ready[1] is already published by the time
// the iteration-0 softmax has finished. Informational only; never checked.
volatile uint32_t g_case24_overlap_hint __attribute__((aligned(16)));
// Output buffer filled by the tensor worker and verified by wu_main against
// WU_BW_FP32_ONE.
volatile uint32_t g_case24_out[WU_CASE24_OUT_WORDS]
__attribute__((aligned(16)));
}
// Reinterprets a 32-bit pattern as a binary32 float without any numeric
// conversion (union-based type pun, as used throughout this file).
static inline float wu_case24_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float as_float;
  } pun;
  pun.raw = bits;
  return pun.as_float;
}
// Returns the raw binary32 bit pattern of `value` (union-based type pun, the
// inverse of wu_case24_bits_to_f32).
static inline uint32_t wu_case24_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
// Converts a positive binary32 value to a binary16 bit pattern.
//
// Simplified conversion intended for softmax probabilities:
//   - zero, negative, fp32-subnormal and anything whose fp32 biased exponent
//     is <= 112 (i.e. below the fp16 normal range) is flushed to 0;
//   - fp32 biased exponent >= 143 (including inf and NaN bit patterns)
//     saturates to fp16 +inf (0x7c00);
//   - the mantissa is rounded by adding half of the dropped 13-bit tail
//     (ties round up), with a carry bumping the exponent.
static inline uint16_t wu_case24_f32_to_f16_positive(float value) {
  const uint32_t raw = wu_case24_f32_to_bits(value);
  const uint32_t biased_exp = (raw >> 23) & 0xffu;

  // Too small to be a normal fp16 value, or not strictly positive.
  if (biased_exp <= 112u || value <= 0.0f) {
    return 0;
  }
  // Too large for fp16: saturate to +inf.
  if (biased_exp >= 143u) {
    return 0x7c00u;
  }

  // Re-bias the exponent from fp32 (127) to fp16 (15).
  uint32_t half_exp = biased_exp - 112u;
  // Round at bit 12, the most significant of the 13 bits being dropped.
  uint32_t frac = (raw & 0x7fffffu) + 0x1000u;
  if ((frac & 0x800000u) != 0u) {
    // Rounding carried out of the 23-bit mantissa.
    frac = 0;
    half_exp += 1u;
  }
  if (half_exp >= 31u) {
    return 0x7c00u;
  }
  return static_cast<uint16_t>((half_exp << 10) | (frac >> 13));
}
// Converts `value` to fp16 and replicates it into both 16-bit halves of a
// 32-bit word.
static inline uint32_t wu_case24_pack_f16x2(float value) {
  const uint32_t half = wu_case24_f32_to_f16_positive(value);
  return (half << 16) | half;
}
// Polls status[iter] until it equals (base | iter).
//
// Returns 0 as soon as the expected tag is observed, or 1 if it was not seen
// within WU_CASE_WAIT_SPIN polls.
static inline int wu_case24_wait_status(volatile uint32_t *status,
                                        uint32_t iter, uint32_t base) {
  const uint32_t want = base | iter;
  uint32_t polls_left = WU_CASE_WAIT_SPIN;
  while (polls_left != 0u) {
    if (status[iter] == want) {
      return 0;
    }
    --polls_left;
  }
  return 1;
}
// Computes a max-subtracted softmax over one row of scores held in TMEM and
// writes the probabilities back to TMEM as packed fp16 pairs.
//
//   iter            - pipeline iteration; only used to index g_case24_p_bits.
//   score_frag_base - fragment index of the first score word; the row is the
//                     WU_CASE24_ROW_N consecutive fragments starting there.
//   p_byte_base     - byte base of the destination tile; converted to a
//                     fragment index via WU_BW_TMEM_FRAGMENT_BYTES.
//
// Scores are re-read from TMEM in every pass rather than cached locally.
static inline void wu_case24_softmax_tmem_row_to_p(uint32_t iter,
uint32_t score_frag_base,
uint32_t p_byte_base) {
// Pass 1: row maximum, seeded with the first score.
float row_max = wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
for (uint32_t i = 1; i < WU_CASE24_ROW_N; ++i) {
const float score =
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
row_max = score > row_max ? score : row_max;
}
// Pass 2: denominator, sum of exp(score - row_max) over the row.
float denom = 0.0f;
for (uint32_t i = 0; i < WU_CASE24_ROW_N; ++i) {
const float score =
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
denom += wu_fexp_s(score - row_max);
}
// Pass 3: fill every fragment of the destination tile; fragment `frag` gets
// the probability of row element (frag % WU_CASE24_ROW_N).
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
const uint32_t row_idx = frag % WU_CASE24_ROW_N;
const float score =
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
const float p = wu_fexp_s(score - row_max) / denom;
if (frag == 0) {
// Record the fp32 probability of fragment 0 for this thread (debug aid).
g_case24_p_bits[iter * NUM_THREADS + wu_tid()] =
wu_case24_f32_to_bits(p);
}
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case24_pack_f16x2(p));
}
// Order the stores above before whatever the caller publishes next.
asm volatile("fence rw, rw" ::: "memory");
}
// Tensor-warp entry point for case24 (naked: no prologue/epilogue, never
// returns, and freely uses x1/x2/x8/x9 as scratch).
//
// Register roles:
//   x8 = slot       = warp id - NUM_SCALAR_WARPS
//   x1 = A tile base = slot << 11 (byte address in TMEM)
//   x2 = C tile base = x1 + WU_BW_TMEM_C_BYTE_OFFSET
//   x9 = current iteration; starts at slot and advances by 2
//
// Per iteration: fill A from g_case24_q_row, clear C, run the first MMA,
// publish score_ready[iter], spin until p_ready[iter] is published, clear C,
// run the second MMA, copy C to g_case24_out + (iter << 10), publish
// done[iter].
//
// NOTE(review): the meanings attached to the custom3 funct3 values below are
// inferred from operand usage in this file, not from an ISA reference --
// confirm against the Wu/Blackwell extension documentation.
// NOTE(review): the p_ready spin (label 4) has no timeout; if wu_main records
// a failure and never publishes p_ready, this warp spins there indefinitely.
// NOTE(review): the iteration stride of 2 assumes exactly two tensor warps
// are spawned -- confirm against vx_tensor_warp_mask().
extern "C" void __attribute__((naked, noinline, used)) tensor_case24_worker() {
asm volatile(
// slot, A/C tile bases, and the starting iteration.
"csrr x5, %[csr_wid]\n\t"
"addi x8, x5, -%[num_scalar_warps]\n\t"
"slli x1, x8, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
"mv x9, x8\n\t"
// 1: top of the per-iteration loop; exit once iter >= WU_CASE24_ITER_N.
"1:\n\t"
"li x10, %[iter_n]\n\t"
"bge x9, x10, 9f\n\t"
// 2: fill the A tile, 16 bytes per step, from the q row
// (custom3 funct3=2: rs1 = TMEM byte address, rs2 = memory address;
// presumably a memory -> TMEM load).
"la x3, g_case24_q_row\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
// 3: clear the C tile from the zero row.
"la x3, g_case24_zero_row\n\t"
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
// funct3=3: presumably waits for the TMEM transfers above to complete.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// First MMA: funct3=0 with rd = C base, rs1 = A base, rs2 = SMEM base,
// followed by funct3=1 (presumably commit / wait for the MMA).
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Publish score_ready[iter] = SCORE_READY_BASE | iter.
"slli x6, x9, 2\n\t"
"la x7, g_case24_score_ready\n\t"
"add x7, x7, x6\n\t"
"li x6, %[score_ready_base]\n\t"
"or x6, x6, x9\n\t"
"sw x6, 0(x7)\n\t"
// 4: spin until p_ready[iter] == P_READY_BASE | iter (no timeout).
"slli x6, x9, 2\n\t"
"la x7, g_case24_p_ready\n\t"
"add x7, x7, x6\n\t"
"li x4, %[p_ready_base]\n\t"
"or x4, x4, x9\n\t"
"4:\n\t"
"lw x6, 0(x7)\n\t"
"bne x6, x4, 4b\n\t"
// 5: clear the C tile again before the second MMA.
"la x3, g_case24_zero_row\n\t"
"li x7, 0\n\t"
"5:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 5b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// Second MMA with the same operands; A now holds what the scalar warp
// wrote for this slot.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// 6: copy the C tile to g_case24_out + iter * 1024 bytes
// (custom3 funct3=6: rs1 = TMEM byte address, rs2 = memory address;
// presumably a TMEM -> memory store).
// NOTE(review): assumes one iteration's output region is 1024 bytes and
// that WU_BW_TMEM_TILE_BYTES does not exceed it -- confirm.
"slli x6, x9, 10\n\t"
"la x3, g_case24_out\n\t"
"add x3, x3, x6\n\t"
"li x7, 0\n\t"
"6:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 6b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// Publish done[iter] = DONE_BASE | iter, then advance to iter + 2.
"slli x6, x9, 2\n\t"
"la x7, g_case24_done\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x9\n\t"
"sw x6, 0(x7)\n\t"
"addi x9, x9, 2\n\t"
"j 1b\n\t"
// 9: all iterations done. custom0 funct3=0 with x0 operands presumably
// clears the thread mask to retire the warp; the self-loop is a backstop.
"9:\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"10: j 10b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[iter_n] "i"(WU_CASE24_ITER_N),
[score_ready_base] "i"(WU_CASE24_SCORE_READY_BASE),
[p_ready_base] "i"(WU_CASE24_P_READY_BASE),
[done_base] "i"(WU_CASE24_DONE_BASE)
: "memory");
}
// Scalar-side driver for case24.
//
// Runs only on core 0 / warp 0. Thread 0 resets all shared state, fills the
// SMEM tile with packed fp16 ones and spawns the tensor workers. For each
// iteration the scalar warp then waits for the score, checks that the first
// score word is fp32 32.0, runs the softmax into the slot's A tile, and
// publishes p_ready. Finally thread 0 waits for every done flag and verifies
// that the whole output buffer equals WU_BW_FP32_ONE.
//
// Failure codes stored in g_case_mem[1] and passed to wu_case_fail:
//   0x81  score_ready[iter] timed out          (g_aux[0] = iter)
//   0x82  first score word was not fp32 32.0   (g_aux[0] = iter,
//                                               g_aux[1] = observed bits)
//   0x83  done[iter] timed out                 (g_aux[0] = iter)
//   0x84  output mismatch                      (g_aux[0] = index returned by
//                                               wu_bw_verify_constant,
//                                               g_aux[1] = offending word)
//
// Only the first failure is recorded: every check is skipped once
// g_case_mem[1] is non-zero, so later knock-on timeouts cannot replace the
// original diagnosis.
//
// Returns 1 from thread 0 on failure and 0 otherwise.
extern "C" int wu_main() {
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }
  const uint32_t tid = wu_tid();
  const uint32_t tensor_mask = vx_tensor_warp_mask();
  if (tid == 0) {
    wu_case_reset();
    for (uint32_t i = 0; i < WU_CASE24_ITER_N; ++i) {
      g_case24_score_ready[i] = 0;
      g_case24_p_ready[i] = 0;
      g_case24_done[i] = 0;
    }
    for (uint32_t i = 0; i < WU_CASE24_ITER_N * NUM_THREADS; ++i) {
      g_case24_score_bits[i] = 0;
      g_case24_p_bits[i] = 0;
    }
    for (uint32_t i = 0; i < WU_CASE24_OUT_WORDS; ++i) {
      g_case24_out[i] = 0;
    }
    g_case24_overlap_hint = 0;
    wu_bw_fill_smem_tile(
        reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
        WU_BW_FP16_ONE_PACKED);
    // Workers are started only after every flag and buffer has been cleared.
    vx_spawn_tensor(tensor_mask, tensor_case24_worker);
  }
  asm volatile("fence rw, rw" ::: "memory");
  for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
    // Skip the wait once a failure is recorded. After a failure p_ready is
    // no longer published, so the stalled worker never produces later scores;
    // waiting for them would only burn the spin budget and overwrite the
    // original failure code and g_aux[0] with a secondary 0x81 timeout.
    // This mirrors the guard on the done-flag wait below.
    if (g_case_mem[1] == 0 &&
        wu_case24_wait_status(g_case24_score_ready, iter,
                              WU_CASE24_SCORE_READY_BASE) != 0) {
      if (tid == 0) {
        g_case_mem[1] = 0x81u;
        g_aux[0] = iter;
      }
    }
    asm volatile("fence rw, rw" ::: "memory");
    // Iterations alternate between the two TMEM slots.
    const uint32_t slot = iter & 1u;
    const uint32_t c_frag =
        wu_bw_tmem_c_byte_base(slot) / WU_BW_TMEM_FRAGMENT_BYTES;
    // The first score word is captured for debugging even after a failure.
    const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
    g_case24_score_bits[iter * NUM_THREADS + tid] = observed;
    asm volatile("fence rw, rw" ::: "memory");
    if (tid == 0 && g_case_mem[1] == 0 &&
        observed != WU_CASE24_FP32_THIRTY_TWO) {
      g_aux[0] = iter;
      g_aux[1] = observed;
      g_case_mem[1] = 0x82u;
    }
    asm volatile("fence rw, rw" ::: "memory");
    if (g_case_mem[1] == 0) {
      // Enable all lanes for the softmax, then drop back to a single lane.
      vx_tmc(wu_bw_all_lanes_mask());
      wu_case24_softmax_tmem_row_to_p(iter, c_frag,
                                      wu_bw_tmem_a_byte_base(slot));
      vx_tmc_one();
    }
    asm volatile("fence rw, rw" ::: "memory");
    if (tid == 0 && g_case_mem[1] == 0) {
      // Informational: the iteration-1 score was already published while
      // iteration 0 was still being processed here.
      if (iter == 0 &&
          g_case24_score_ready[1] == (WU_CASE24_SCORE_READY_BASE | 1u)) {
        g_case24_overlap_hint = 1;
      }
      g_case24_p_ready[iter] = WU_CASE24_P_READY_BASE | iter;
    }
    asm volatile("fence rw, rw" ::: "memory");
  }
  if (tid == 0) {
    for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
      if (g_case_mem[1] == 0 &&
          wu_case24_wait_status(g_case24_done, iter, WU_CASE24_DONE_BASE) !=
              0) {
        g_aux[0] = iter;
        g_case_mem[1] = 0x83u;
      }
    }
    if (g_case_mem[1] == 0) {
      volatile uint32_t bad_actual = 0;
      const uint32_t bad =
          wu_bw_verify_constant(g_case24_out, WU_CASE24_OUT_WORDS,
                                WU_BW_FP32_ONE, &bad_actual);
      // A return value equal to the word count is treated as "no mismatch".
      if (bad != WU_CASE24_OUT_WORDS) {
        g_aux[0] = bad;
        g_aux[1] = bad_actual;
        g_case_mem[1] = 0x84u;
      }
    }
    if (g_case_mem[1] != 0) {
      wu_case_fail(g_case_mem[1]);
      return 1;
    }
    wu_case_pass();
  }
  return 0;
}

View File

@@ -134,6 +134,14 @@ static inline int wu_is_leader() {
return vx_core_id() == 0 && vx_warp_id() == 0 && vx_thread_id() == 0;
}
// Issues the custom scalar exponential instruction on `value` and returns
// the result.
//
// Encoded as an R-type instruction on the CUSTOM1 opcode with funct3 = 2,
// funct7 = 0x30, rs1 = `value` (floating-point register), rs2 = x0 and a
// floating-point destination register.
// NOTE(review): callers in this commit treat the result as e^value (the
// README calls the instruction FEXP.S); accuracy and special-value behavior
// are not visible here -- confirm against the ISA documentation.
static inline float wu_fexp_s(float value) {
float result;
asm volatile(".insn r %[custom1], 2, 0x30, %[rd], %[rs1], x0"
: [rd] "=f"(result)
: [rs1] "f"(value), [custom1] "i"(RISCV_CUSTOM1));
return result;
}
static inline void wu_report_tohost(uint32_t exit_code) {
asm volatile("fence rw, rw" ::: "memory");
tohost = (static_cast<uint64_t>(exit_code) << 1) | 1u;