diff --git a/kernels/wu_arch_cases/Makefile b/kernels/wu_arch_cases/Makefile index 946c33b4..82eadab6 100644 --- a/kernels/wu_arch_cases/Makefile +++ b/kernels/wu_arch_cases/Makefile @@ -19,7 +19,13 @@ CASES := \ case14_flash_pv_k64 \ case15_flash_softmax_pv_stage \ case16_flash_full_pipeline \ - case17_flash_exp_softmax_probe + case17_flash_exp_softmax_probe \ + case18_scalar_fexp \ + case20_flash_bwd_fused \ + case21_moe_gating \ + case22_gemm_silu \ + case23_softmax_only \ + case24_flash_sw_pipeline SMOKE_CASES := \ case00_boot_scalar \ diff --git a/kernels/wu_arch_cases/README.md b/kernels/wu_arch_cases/README.md index 3b182bd7..839870fe 100644 --- a/kernels/wu_arch_cases/README.md +++ b/kernels/wu_arch_cases/README.md @@ -25,6 +25,12 @@ This directory contains small bare-metal kernels for incremental Wu architecture - `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV. - `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline. - `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention. +- `case18_scalar_fexp`: scalar `FEXP.S` numerical probe. +- `case20_flash_bwd_fused`: FlashAttention backward-style fused 5xMMA plus scalar softmax/dsoftmax handoff. +- `case21_moe_gating`: scalar `softmax -> Top-K -> scatter` MoE gating pipeline. +- `case22_gemm_silu`: tensor GEMM followed by scalar SiLU activation. +- `case23_softmax_only`: scalar-only stable softmax probe. +- `case24_flash_sw_pipeline`: four-iteration ping-pong FlashAttention-style software pipeline. Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker. diff --git a/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md b/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md index 58577a7e..cd7c17af 100644 --- a/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md +++ b/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md @@ -4,10 +4,12 @@ Validates a compact end-to-end FlashAttention-style pipeline on the current Wu Blackwell path. The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`, -and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads -the score through scalar TMEM load, records it, writes the uniform softmax result -`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp. -The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`. +and `O_init=0.0`, producing a constant fp32 score of `32.0`. The scalar warp +reads the score row through scalar TMEM loads, scans the row maximum and +normalization denominator with scalar-only `FEXP.S`, converts each probability +to packed fp16, writes `P` into TMEM A, refills SMEM with `V=2.0`, and releases +the tensor warp. The tensor warp reloads TMEM C with `O=0.0`, then computes +`O = P @ V`. Every output word is expected to be fp32 `2.0`. This case covers the staged `QK -> scalar softmax handoff -> PV` loop without using the legacy diff --git a/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp b/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp index d81deb5a..8c50ac6a 100644 --- a/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp +++ b/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp @@ -4,10 +4,10 @@ #define WU_CASE16_P_READY 0x9700u #define WU_CASE16_DONE_BASE 0x9800u -#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u #define WU_BW_FP32_ZERO 0x00000000u #define WU_BW_FP32_TWO 0x40000000u #define WU_BW_FP32_THIRTY_TWO 0x42000000u +#define WU_CASE16_ROW_N 32u extern "C" { volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = { @@ -17,6 +17,85 @@ volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = { WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO}; volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16))); +volatile uint32_t g_case16_p_bits[4] __attribute__((aligned(16))); +} + +static inline float wu_case16_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case16_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline uint16_t wu_case16_f32_to_f16_positive(float value) { + const uint32_t bits = wu_case16_f32_to_bits(value); + const uint32_t exp = (bits >> 23) & 0xffu; + uint32_t mant = bits & 0x7fffffu; + + if (exp == 0 || value <= 0.0f) { + return 0; + } + if (exp >= 143u) { + return 0x7c00u; + } + if (exp <= 112u) { + return 0; + } + + uint32_t half_exp = exp - 112u; + mant += 0x1000u; + if (mant & 0x800000u) { + mant = 0; + ++half_exp; + } + if (half_exp >= 31u) { + return 0x7c00u; + } + return static_cast((half_exp << 10) | (mant >> 13)); +} + +static inline uint32_t wu_case16_pack_f16x2(float value) { + const uint32_t h = wu_case16_f32_to_f16_positive(value); + return h | (h << 16); +} + +static inline void wu_case16_softmax_tmem_row_to_p(uint32_t score_frag_base, + uint32_t p_byte_base) { + float row_max = wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base)); + for (uint32_t i = 1; i < WU_CASE16_ROW_N; ++i) { + const float score = + wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + row_max = score > row_max ? score : row_max; + } + + float denom = 0.0f; + for (uint32_t i = 0; i < WU_CASE16_ROW_N; ++i) { + const float score = + wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + denom += wu_fexp_s(score - row_max); + } + + const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES; + for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) { + const uint32_t row_idx = frag % WU_CASE16_ROW_N; + const float score = + wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx)); + const float p = wu_fexp_s(score - row_max) / denom; + if (frag == 0) { + g_case16_p_bits[wu_tid()] = wu_case16_f32_to_bits(p); + } + wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case16_pack_f16x2(p)); + } + asm volatile("fence rw, rw" ::: "memory"); } extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() { @@ -114,6 +193,7 @@ extern "C" int wu_main() { } for (uint32_t i = 0; i < 4; ++i) { g_case16_scalar_seen[i] = 0; + g_case16_p_bits[i] = 0; } wu_bw_fill_smem_tile( reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), @@ -141,8 +221,7 @@ extern "C" int wu_main() { if (g_case_mem[1] == 0) { vx_tmc(wu_bw_all_lanes_mask()); - wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), - WU_BW_FP16_ONE_OVER_32_PACKED); + wu_case16_softmax_tmem_row_to_p(c_frag, wu_bw_tmem_a_byte_base(0)); vx_tmc_one(); } asm volatile("fence rw, rw" ::: "memory"); diff --git a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md index 719c91d4..1df8af22 100644 --- a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md +++ b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md @@ -1,9 +1,8 @@ # case17_flash_exp_softmax_probe -Validates that the current scalar Wu path can execute the `e^x` work needed by -non-uniform FlashAttention softmax. The current ISA/RTL does not expose a -dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to -approximate exp. +Validates that the scalar Wu path can execute the `e^x` work needed by +non-uniform FlashAttention softmax through the custom scalar-only `FEXP.S` +instruction. The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically stable softmax computes `exp(score - row_max)`, so the exp inputs are @@ -12,5 +11,5 @@ inputs are loaded from volatile memory so the compiler cannot fold the result into constants. This is intentionally separate from the tensor PV path. If this case fails, the -problem is in scalar fp32 arithmetic, exp approximation, or normalization rather -than TMEM handoff or BWGMMA. +problem is in scalar fp32 `FEXP.S` execution or normalization rather than TMEM +handoff or BWGMMA. diff --git a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp index c5cb8111..e4d8ee31 100644 --- a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp +++ b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp @@ -26,15 +26,6 @@ static inline float wu_case17_absf(float value) { return value < 0.0f ? -value : value; } -static inline float wu_case17_exp_neg_ln2_to_0(float x) { - const float x2 = x * x; - const float x3 = x2 * x; - const float x4 = x2 * x2; - const float x5 = x4 * x; - return 1.0f + x + (0.5f * x2) + (0.1666666716f * x3) + - (0.0416666679f * x4) + (0.0083333338f * x5); -} - extern "C" int wu_main() { if (vx_core_id() != 0 || vx_warp_id() != 0) { return 0; @@ -51,8 +42,8 @@ extern "C" int wu_main() { const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]); const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]); const float row_max = score0 > score1 ? score0 : score1; - const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max); - const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max); + const float e0 = wu_fexp_s(score0 - row_max); + const float e1 = wu_fexp_s(score1 - row_max); const float inv_sum = 1.0f / (e0 + e1); const float p0 = e0 * inv_sum; const float p1 = e1 * inv_sum; diff --git a/kernels/wu_arch_cases/case18_scalar_fexp/Makefile b/kernels/wu_arch_cases/case18_scalar_fexp/Makefile new file mode 100644 index 00000000..2db25a61 --- /dev/null +++ b/kernels/wu_arch_cases/case18_scalar_fexp/Makefile @@ -0,0 +1,3 @@ +PROJECT = case18_scalar_fexp + +include ../case.mk diff --git a/kernels/wu_arch_cases/case18_scalar_fexp/README.md b/kernels/wu_arch_cases/case18_scalar_fexp/README.md new file mode 100644 index 00000000..906e3b05 --- /dev/null +++ b/kernels/wu_arch_cases/case18_scalar_fexp/README.md @@ -0,0 +1,5 @@ +# case18_scalar_fexp + +Verifies scalar-warp execution of the custom `FEXP.S` instruction. The test +checks representative fp32 inputs used by softmax-style code paths and confirms +the result is close to `expf`. diff --git a/kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp b/kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp new file mode 100644 index 00000000..31ecc222 --- /dev/null +++ b/kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp @@ -0,0 +1,68 @@ +#include "../common_wu_min.h" + +extern "C" { +volatile uint32_t g_case18_out_bits[4] __attribute__((aligned(16))); +} + +static inline uint32_t f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float absf_local(float value) { + return value < 0.0f ? -value : value; +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < 4; ++i) { + g_case18_out_bits[i] = 0; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const float input0 = 0.0f; + const float input1 = 1.0f; + const float input2 = -0.6931471805599453f; + const float input3 = -10.0f; + + const float out0 = wu_fexp_s(input0); + const float out1 = wu_fexp_s(input1); + const float out2 = wu_fexp_s(input2); + const float out3 = wu_fexp_s(input3); + + if (tid == 0) { + g_case18_out_bits[0] = f32_to_bits(out0); + g_case18_out_bits[1] = f32_to_bits(out1); + g_case18_out_bits[2] = f32_to_bits(out2); + g_case18_out_bits[3] = f32_to_bits(out3); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + const float err0 = absf_local(out0 - 1.0f); + const float err1 = absf_local(out1 - 2.7182817459f); + const float err2 = absf_local(out2 - 0.5f); + const float err3 = absf_local(out3 - 0.00004539993f); + if (err0 > 0.00001f || err1 > 0.0002f || err2 > 0.00001f || + err3 > 0.000001f) { + g_aux[0] = g_case18_out_bits[0]; + g_aux[1] = g_case18_out_bits[1]; + g_aux[2] = g_case18_out_bits[2]; + g_aux[3] = g_case18_out_bits[3]; + wu_case_fail(0x18u); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case19_tensor_fexp_illegal/Makefile b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/Makefile new file mode 100644 index 00000000..e1976a75 --- /dev/null +++ b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/Makefile @@ -0,0 +1,3 @@ +PROJECT = case19_tensor_fexp_illegal + +include ../case.mk diff --git a/kernels/wu_arch_cases/case19_tensor_fexp_illegal/README.md b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/README.md new file mode 100644 index 00000000..a88ad1dd --- /dev/null +++ b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/README.md @@ -0,0 +1,5 @@ +# case19_tensor_fexp_illegal + +Negative test for `FEXP.S`: tensor warps must not execute this scalar FPU +instruction. Running this case is expected to trip the existing tensor-FPU +illegal-instruction path in decode/dispatch rather than complete normally. diff --git a/kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp new file mode 100644 index 00000000..0a399c53 --- /dev/null +++ b/kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp @@ -0,0 +1,25 @@ +#include "../common_wu_min.h" + +extern "C" void __attribute__((naked, noinline, used)) tensor_worker() { + asm volatile( + "fmv.w.x f1, x0\n\t" + ".insn r %[custom1], 2, 0x30, f2, f1, x0\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "1: j 1b\n\t" + : + : [custom0] "i"(RISCV_CUSTOM0), + [custom1] "i"(RISCV_CUSTOM1) + : "memory"); +} + +extern "C" int wu_main() { + if (!wu_is_leader()) { + return 0; + } + + wu_case_reset(); + vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker); + + wu_case_fail(0x19u); + return 1; +} diff --git a/kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile b/kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile new file mode 100644 index 00000000..8d108124 --- /dev/null +++ b/kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile @@ -0,0 +1,3 @@ +PROJECT = case20_flash_bwd_fused + +include ../case.mk diff --git a/kernels/wu_arch_cases/case20_flash_bwd_fused/README.md b/kernels/wu_arch_cases/case20_flash_bwd_fused/README.md new file mode 100644 index 00000000..1c9c6ba9 --- /dev/null +++ b/kernels/wu_arch_cases/case20_flash_bwd_fused/README.md @@ -0,0 +1,19 @@ +# case20_flash_bwd_fused + +FlashAttention backward-style fused pipeline smoke test. + +The tensor warp performs one score MMA, then waits for the scalar warp to run +softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back +to TMEM A using signed fp16 values. The tensor warp then performs four more +MMA steps, for five MMA operations total in this case. + +This case verifies: + +- tensor warp MMA sequencing around a scalar TMEM handoff; +- scalar-only `FEXP.S` use for stable softmax; +- dsoftmax shape `dS = P * (dP - sum(P * dP))`; +- signed scalar TMEM stores feeding later tensor MMA operations. + +The input score row is uniform, so `P = 1/32`. The synthetic upstream gradient +uses two buckets, producing exact dS values `-1/32` for row entries 0..15 and +`+1/32` for row entries 16..31. diff --git a/kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp b/kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp new file mode 100644 index 00000000..98c3223c --- /dev/null +++ b/kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp @@ -0,0 +1,289 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE20_SCORE_READY 0xa000u +#define WU_CASE20_DSOFTMAX_READY 0xa100u +#define WU_CASE20_DONE_BASE 0xa200u +#define WU_CASE20_ROW_N 32u +#define WU_CASE20_FP32_ZERO 0x00000000u +#define WU_CASE20_FP32_THIRTY_TWO 0x42000000u + +extern "C" { +volatile uint32_t g_case20_q_row[4] __attribute__((aligned(16))) = { + WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, + WU_BW_FP16_ONE_PACKED}; +volatile uint32_t g_case20_zero_row[4] __attribute__((aligned(16))) = { + WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO, + WU_CASE20_FP32_ZERO}; +volatile uint32_t g_case20_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +volatile uint32_t g_case20_score_bits[4] __attribute__((aligned(16))); +volatile uint32_t g_case20_dsoftmax_bits[4] __attribute__((aligned(16))); +} + +static inline float wu_case20_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case20_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float wu_case20_absf(float value) { + return value < 0.0f ? -value : value; +} + +static inline uint16_t wu_case20_f32_to_f16(float value) { + const uint32_t bits = wu_case20_f32_to_bits(value); + const uint32_t sign = (bits >> 16) & 0x8000u; + const uint32_t exp = (bits >> 23) & 0xffu; + uint32_t mant = bits & 0x7fffffu; + + if ((bits & 0x7fffffffu) == 0 || exp == 0) { + return static_cast(sign); + } + if (exp >= 143u) { + return static_cast(sign | 0x7c00u); + } + if (exp <= 112u) { + return static_cast(sign); + } + + uint32_t half_exp = exp - 112u; + mant += 0x1000u; + if (mant & 0x800000u) { + mant = 0; + ++half_exp; + } + if (half_exp >= 31u) { + return static_cast(sign | 0x7c00u); + } + return static_cast(sign | (half_exp << 10) | (mant >> 13)); +} + +static inline uint32_t wu_case20_pack_f16x2(float value) { + const uint32_t h = wu_case20_f32_to_f16(value); + return h | (h << 16); +} + +static inline float wu_case20_dp(uint32_t row_idx) { + return row_idx < 16u ? 0.0f : 2.0f; +} + +static inline void wu_case20_dsoftmax_tmem_row(uint32_t score_frag_base, + uint32_t ds_byte_base) { + float row_max = wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base)); + for (uint32_t i = 1; i < WU_CASE20_ROW_N; ++i) { + const float score = + wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + row_max = score > row_max ? score : row_max; + } + + float denom = 0.0f; + for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) { + const float score = + wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + denom += wu_fexp_s(score - row_max); + } + + float dot = 0.0f; + for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) { + const float score = + wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + const float p = wu_fexp_s(score - row_max) / denom; + dot += p * wu_case20_dp(i); + } + + const uint32_t ds_frag_base = ds_byte_base / WU_BW_TMEM_FRAGMENT_BYTES; + for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) { + const uint32_t row_idx = frag % WU_CASE20_ROW_N; + const float score = + wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx)); + const float p = wu_fexp_s(score - row_max) / denom; + const float ds = p * (wu_case20_dp(row_idx) - dot); + if (wu_tid() == 0 && row_idx == 0) { + g_case20_dsoftmax_bits[0] = wu_case20_f32_to_bits(ds); + } + if (wu_tid() == 0 && row_idx == 16u) { + g_case20_dsoftmax_bits[1] = wu_case20_f32_to_bits(ds); + } + wu_bw_scalar_tmem_st(ds_frag_base + frag, wu_case20_pack_f16x2(ds)); + } + asm volatile("fence rw, rw" ::: "memory"); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case20_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case20_q_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + "la x3, g_case20_zero_row\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[score_ready]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "3:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[dsoftmax_ready]\n\t" + "bne x7, x4, 3b\n\t" + "la x3, g_case20_zero_row\n\t" + "li x7, 0\n\t" + "4:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 4b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_case20_out\n\t" + "li x7, 0\n\t" + "5:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 5b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "6: j 6b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [score_ready] "i"(WU_CASE20_SCORE_READY), + [dsoftmax_ready] "i"(WU_CASE20_DSOFTMAX_READY), + [done_base] "i"(WU_CASE20_DONE_BASE) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case20_out[i] = 0xffffffffu; + } + for (uint32_t i = 0; i < 4; ++i) { + g_case20_score_bits[i] = 0; + g_case20_dsoftmax_bits[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_ONE_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case20_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE20_SCORE_READY) != 0) { + g_case_mem[1] = 0x41u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t c_frag = + wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES; + const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag); + if (tid == 0) { + g_case20_score_bits[0] = observed; + if (g_case_mem[1] == 0 && observed != WU_CASE20_FP32_THIRTY_TWO) { + g_aux[0] = observed; + g_case_mem[1] = 0x42u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_case20_dsoftmax_tmem_row(c_frag, wu_bw_tmem_a_byte_base(0)); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + if (g_case_mem[1] == 0) { + const float neg = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[0]); + const float pos = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[1]); + if (wu_case20_absf(neg + 0.03125f) > 0.0002f || + wu_case20_absf(pos - 0.03125f) > 0.0002f) { + g_aux[0] = g_case20_dsoftmax_bits[0]; + g_aux[1] = g_case20_dsoftmax_bits[1]; + g_case_mem[1] = 0x43u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + g_case_mem[0] = WU_CASE20_DSOFTMAX_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE20_DONE_BASE) != 0) { + g_case_mem[1] = 0x44u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case20_out, WU_BW_OUT_WORDS, + WU_CASE20_FP32_ZERO, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x45u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case21_moe_gating/Makefile b/kernels/wu_arch_cases/case21_moe_gating/Makefile new file mode 100644 index 00000000..d64f4f39 --- /dev/null +++ b/kernels/wu_arch_cases/case21_moe_gating/Makefile @@ -0,0 +1,3 @@ +PROJECT = case21_moe_gating + +include ../case.mk diff --git a/kernels/wu_arch_cases/case21_moe_gating/README.md b/kernels/wu_arch_cases/case21_moe_gating/README.md new file mode 100644 index 00000000..7aaf79e0 --- /dev/null +++ b/kernels/wu_arch_cases/case21_moe_gating/README.md @@ -0,0 +1,10 @@ +# case21_moe_gating + +MoE gating scalar pipeline test. + +This case runs `softmax -> Top-K -> scatter` on scalar warp 0 using `FEXP.S`. +The logits are `log(1), log(2), log(4), log(8)`, so the expected probabilities +are `1/15, 2/15, 4/15, 8/15`. Top-2 should select experts 3 and 2, then scatter +the token id and weight into the selected expert slots. + +No tensor warp is spawned in this case. diff --git a/kernels/wu_arch_cases/case21_moe_gating/kernel.cpp b/kernels/wu_arch_cases/case21_moe_gating/kernel.cpp new file mode 100644 index 00000000..45363e4c --- /dev/null +++ b/kernels/wu_arch_cases/case21_moe_gating/kernel.cpp @@ -0,0 +1,128 @@ +#include "../common_wu_min.h" + +#define WU_CASE21_TOKEN_ID 0x21u +#define WU_CASE21_EMPTY 0xffffffffu + +extern "C" { +volatile uint32_t g_case21_logits_bits[4] __attribute__((aligned(16))) = { + 0x00000000u, 0x3f317218u, 0x3fb17218u, 0x40051d8fu}; +volatile uint32_t g_case21_prob_bits[4] __attribute__((aligned(16))); +volatile uint32_t g_case21_top_idx[2] __attribute__((aligned(16))); +volatile uint32_t g_case21_expert_token[4] __attribute__((aligned(16))); +volatile uint32_t g_case21_expert_weight_bits[4] __attribute__((aligned(16))); +} + +static inline float wu_case21_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case21_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float wu_case21_absf(float value) { + return value < 0.0f ? -value : value; +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < 4; ++i) { + g_case21_prob_bits[i] = 0; + g_case21_expert_token[i] = WU_CASE21_EMPTY; + g_case21_expert_weight_bits[i] = 0; + } + g_case21_top_idx[0] = WU_CASE21_EMPTY; + g_case21_top_idx[1] = WU_CASE21_EMPTY; + } + asm volatile("fence rw, rw" ::: "memory"); + + float logits[4]; + float row_max = wu_case21_bits_to_f32(g_case21_logits_bits[0]); + for (uint32_t i = 0; i < 4; ++i) { + logits[i] = wu_case21_bits_to_f32(g_case21_logits_bits[i]); + row_max = logits[i] > row_max ? logits[i] : row_max; + } + + float exp_values[4]; + float denom = 0.0f; + for (uint32_t i = 0; i < 4; ++i) { + exp_values[i] = wu_fexp_s(logits[i] - row_max); + denom += exp_values[i]; + } + + float probs[4]; + for (uint32_t i = 0; i < 4; ++i) { + probs[i] = exp_values[i] / denom; + } + + uint32_t top0 = 0; + uint32_t top1 = 1; + if (probs[top1] > probs[top0]) { + const uint32_t tmp = top0; + top0 = top1; + top1 = tmp; + } + for (uint32_t i = 2; i < 4; ++i) { + if (probs[i] > probs[top0]) { + top1 = top0; + top0 = i; + } else if (probs[i] > probs[top1]) { + top1 = i; + } + } + + if (tid == 0) { + for (uint32_t i = 0; i < 4; ++i) { + g_case21_prob_bits[i] = wu_case21_f32_to_bits(probs[i]); + } + g_case21_top_idx[0] = top0; + g_case21_top_idx[1] = top1; + g_case21_expert_token[top0] = WU_CASE21_TOKEN_ID; + g_case21_expert_token[top1] = WU_CASE21_TOKEN_ID; + g_case21_expert_weight_bits[top0] = wu_case21_f32_to_bits(probs[top0]); + g_case21_expert_weight_bits[top1] = wu_case21_f32_to_bits(probs[top1]); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + const float expected[4] = {0.0666666701f, 0.1333333403f, + 0.2666666806f, 0.5333333611f}; + const float tolerance = 0.0015f; + for (uint32_t i = 0; i < 4; ++i) { + if (wu_case21_absf(probs[i] - expected[i]) > tolerance) { + g_aux[0] = i; + g_aux[1] = g_case21_prob_bits[i]; + wu_case_fail(0x21u); + return 1; + } + } + if (top0 != 3u || top1 != 2u || + g_case21_expert_token[3] != WU_CASE21_TOKEN_ID || + g_case21_expert_token[2] != WU_CASE21_TOKEN_ID || + g_case21_expert_token[0] != WU_CASE21_EMPTY || + g_case21_expert_token[1] != WU_CASE21_EMPTY) { + g_aux[0] = top0; + g_aux[1] = top1; + g_aux[2] = g_case21_expert_token[3]; + g_aux[3] = g_case21_expert_token[2]; + wu_case_fail(0x22u); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case22_gemm_silu/Makefile b/kernels/wu_arch_cases/case22_gemm_silu/Makefile new file mode 100644 index 00000000..75ca4ff5 --- /dev/null +++ b/kernels/wu_arch_cases/case22_gemm_silu/Makefile @@ -0,0 +1,3 @@ +PROJECT = case22_gemm_silu + +include ../case.mk diff --git a/kernels/wu_arch_cases/case22_gemm_silu/README.md b/kernels/wu_arch_cases/case22_gemm_silu/README.md new file mode 100644 index 00000000..b1c6c669 --- /dev/null +++ b/kernels/wu_arch_cases/case22_gemm_silu/README.md @@ -0,0 +1,10 @@ +# case22_gemm_silu + +GEMM plus SiLU fusion smoke test. + +The tensor warp computes a compact GEMM with fp16 `A = 0.125` and fp16 `B = 1`, +producing fp32 `C = 4`. Scalar warp 0 reads TMEM C and applies +`SiLU(x) = x / (1 + exp(-x))` using scalar-only `FEXP.S`. + +This case verifies the common `matmul -> nonlinear activation` fusion path +without allowing tensor warp FPU execution. diff --git a/kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp b/kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp new file mode 100644 index 00000000..697d0717 --- /dev/null +++ b/kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp @@ -0,0 +1,160 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE22_INIT_BASE 0xa400u +#define WU_CASE22_DONE_BASE 0xa500u +#define WU_CASE22_FP16_ONE_EIGHTH_PACKED 0x30003000u +#define WU_CASE22_FP32_FOUR 0x40800000u + +extern "C" { +volatile uint32_t g_case22_a_row[4] __attribute__((aligned(16))) = { + WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED, + WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED}; +volatile uint32_t g_case22_zero_row[4] __attribute__((aligned(16))) = { + 0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u}; +volatile uint32_t g_case22_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +volatile uint32_t g_case22_silu_bits[4] __attribute__((aligned(16))); +} + +static inline float wu_case22_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case22_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float wu_case22_absf(float value) { + return value < 0.0f ? -value : value; +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case22_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case22_a_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + "la x3, g_case22_zero_row\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_case22_out\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "4: j 4b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [done_base] "i"(WU_CASE22_DONE_BASE) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case22_out[i] = 0; + } + for (uint32_t i = 0; i < 4; ++i) { + g_case22_silu_bits[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_ONE_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case22_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE22_DONE_BASE) != 0) { + g_case_mem[1] = 0x51u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t c_frag = + wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES; + const uint32_t observed_bits = wu_bw_scalar_tmem_ld(c_frag); + const float observed = wu_case22_bits_to_f32(observed_bits); + const float silu = observed / (1.0f + wu_fexp_s(-observed)); + + if (tid == 0) { + g_case22_silu_bits[0] = wu_case22_f32_to_bits(silu); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + if (g_case_mem[1] == 0 && observed_bits != WU_CASE22_FP32_FOUR) { + g_aux[0] = observed_bits; + g_case_mem[1] = 0x52u; + } + if (g_case_mem[1] == 0) { + const float expected = 3.9280550480f; + if (wu_case22_absf(silu - expected) > 0.004f) { + g_aux[0] = g_case22_silu_bits[0]; + g_case_mem[1] = 0x53u; + } + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case22_out, WU_BW_OUT_WORDS, + WU_CASE22_FP32_FOUR, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x54u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case23_softmax_only/Makefile b/kernels/wu_arch_cases/case23_softmax_only/Makefile new file mode 100644 index 00000000..04e2835b --- /dev/null +++ b/kernels/wu_arch_cases/case23_softmax_only/Makefile @@ -0,0 +1,3 @@ +PROJECT = case23_softmax_only + +include ../case.mk diff --git a/kernels/wu_arch_cases/case23_softmax_only/README.md b/kernels/wu_arch_cases/case23_softmax_only/README.md new file mode 100644 index 00000000..3d2efc33 --- /dev/null +++ b/kernels/wu_arch_cases/case23_softmax_only/README.md @@ -0,0 +1,9 @@ +# case23_softmax_only + +Scalar softmax-only test. + +This case runs a stable 4-way softmax on scalar warp 0 using `FEXP.S`. The logits +are `log(1), log(3), log(5), log(7)`, giving expected probabilities +`1/16, 3/16, 5/16, 7/16`. + +No tensor warp is spawned in this case. diff --git a/kernels/wu_arch_cases/case23_softmax_only/kernel.cpp b/kernels/wu_arch_cases/case23_softmax_only/kernel.cpp new file mode 100644 index 00000000..e64b5bc5 --- /dev/null +++ b/kernels/wu_arch_cases/case23_softmax_only/kernel.cpp @@ -0,0 +1,83 @@ +#include "../common_wu_min.h" + +extern "C" { +volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = { + 0x00000000u, 0x3f8c9f54u, 0x3fcdf854u, 0x3ff91395u}; +volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16))); +} + +static inline float wu_case23_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case23_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float wu_case23_absf(float value) { + return value < 0.0f ? -value : value; +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < 4; ++i) { + g_case23_out_bits[i] = 0; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + float scores[4]; + float row_max = wu_case23_bits_to_f32(g_case23_scores_bits[0]); + for (uint32_t i = 0; i < 4; ++i) { + scores[i] = wu_case23_bits_to_f32(g_case23_scores_bits[i]); + row_max = scores[i] > row_max ? scores[i] : row_max; + } + + float exp_values[4]; + float denom = 0.0f; + for (uint32_t i = 0; i < 4; ++i) { + exp_values[i] = wu_fexp_s(scores[i] - row_max); + denom += exp_values[i]; + } + + float probs[4]; + for (uint32_t i = 0; i < 4; ++i) { + probs[i] = exp_values[i] / denom; + } + + if (tid == 0) { + for (uint32_t i = 0; i < 4; ++i) { + g_case23_out_bits[i] = wu_case23_f32_to_bits(probs[i]); + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + const float expected[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f}; + const float tolerance = 0.0015f; + for (uint32_t i = 0; i < 4; ++i) { + if (wu_case23_absf(probs[i] - expected[i]) > tolerance) { + g_aux[0] = i; + g_aux[1] = g_case23_out_bits[i]; + wu_case_fail(0x23u); + return 1; + } + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile b/kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile new file mode 100644 index 00000000..08525394 --- /dev/null +++ b/kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile @@ -0,0 +1,3 @@ +PROJECT = case24_flash_sw_pipeline + +include ../case.mk diff --git a/kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md b/kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md new file mode 100644 index 00000000..6eac8e72 --- /dev/null +++ b/kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md @@ -0,0 +1,24 @@ +# case24_flash_sw_pipeline + +Software-pipelined FlashAttention-style multi-iteration case. + +This case keeps `case16_flash_full_pipeline` as the single-tile +producer/consumer baseline and adds a four-iteration ping-pong pipeline: + +```text +tensor warp 2 / slot 0: QK(0) -> wait P(0) -> PV(0) -> QK(2) -> wait P(2) -> PV(2) +tensor warp 3 / slot 1: QK(1) -> wait P(1) -> PV(1) -> QK(3) -> wait P(3) -> PV(3) +scalar warp 0: softmax(0) -> softmax(1) -> softmax(2) -> softmax(3) +``` + +Each tensor warp owns one TMEM slot. The tensor warp writes `S = Q @ K` into +TMEM C for its slot, marks `score_ready[iter]`, waits for scalar-generated +`P`, then computes `O = P @ V`. Scalar warp 0 waits on each score in order, +uses scalar-only `FEXP.S` for stable softmax, writes packed fp16 probabilities +back to the same slot's TMEM A, and marks `p_ready[iter]`. + +The first version intentionally uses constant `Q`, `K`, and `V` so the expected +numeric result is simple: every score is fp32 `32.0`, every softmax row is +uniform `1/32`, and every output word is fp32 `1.0`. The test objective is the +multi-iteration overlap structure and per-slot handoff, not non-uniform +FlashAttention numerics. diff --git a/kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp b/kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp new file mode 100644 index 00000000..6b1defde --- /dev/null +++ b/kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp @@ -0,0 +1,318 @@ +#define WU_CASE_WAIT_SPIN 16384u + +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE24_ITER_N 4u +#define WU_CASE24_ROW_N 32u +#define WU_CASE24_SCORE_READY_BASE 0xb000u +#define WU_CASE24_P_READY_BASE 0xb100u +#define WU_CASE24_DONE_BASE 0xb200u +#define WU_CASE24_FP32_ZERO 0x00000000u +#define WU_CASE24_FP32_THIRTY_TWO 0x42000000u +#define WU_CASE24_OUT_WORDS (WU_CASE24_ITER_N * WU_BW_OUT_WORDS) + +extern "C" { +volatile uint32_t g_case24_q_row[4] __attribute__((aligned(16))) = { + WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, + WU_BW_FP16_ONE_PACKED}; +volatile uint32_t g_case24_zero_row[4] __attribute__((aligned(16))) = { + WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO, + WU_CASE24_FP32_ZERO}; +volatile uint32_t g_case24_score_ready[WU_CASE24_ITER_N] + __attribute__((aligned(16))); +volatile uint32_t g_case24_p_ready[WU_CASE24_ITER_N] + __attribute__((aligned(16))); +volatile uint32_t g_case24_done[WU_CASE24_ITER_N] __attribute__((aligned(16))); +volatile uint32_t g_case24_score_bits[WU_CASE24_ITER_N * NUM_THREADS] + __attribute__((aligned(16))); +volatile uint32_t g_case24_p_bits[WU_CASE24_ITER_N * NUM_THREADS] + __attribute__((aligned(16))); +volatile uint32_t g_case24_overlap_hint __attribute__((aligned(16))); +volatile uint32_t g_case24_out[WU_CASE24_OUT_WORDS] + __attribute__((aligned(16))); +} + +static inline float wu_case24_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case24_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline uint16_t wu_case24_f32_to_f16_positive(float value) { + const uint32_t bits = wu_case24_f32_to_bits(value); + const uint32_t exp = (bits >> 23) & 0xffu; + uint32_t mant = bits & 0x7fffffu; + + if (exp == 0 || value <= 0.0f) { + return 0; + } + if (exp >= 143u) { + return 0x7c00u; + } + if (exp <= 112u) { + return 0; + } + + uint32_t half_exp = exp - 112u; + mant += 0x1000u; + if (mant & 0x800000u) { + mant = 0; + ++half_exp; + } + if (half_exp >= 31u) { + return 0x7c00u; + } + return static_cast((half_exp << 10) | (mant >> 13)); +} + +static inline uint32_t wu_case24_pack_f16x2(float value) { + const uint32_t h = wu_case24_f32_to_f16_positive(value); + return h | (h << 16); +} + +static inline int wu_case24_wait_status(volatile uint32_t *status, + uint32_t iter, uint32_t base) { + const uint32_t expected = base | iter; + for (uint32_t spin = 0; spin < WU_CASE_WAIT_SPIN; ++spin) { + if (status[iter] == expected) { + return 0; + } + } + return 1; +} + +static inline void wu_case24_softmax_tmem_row_to_p(uint32_t iter, + uint32_t score_frag_base, + uint32_t p_byte_base) { + float row_max = wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base)); + for (uint32_t i = 1; i < WU_CASE24_ROW_N; ++i) { + const float score = + wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + row_max = score > row_max ? score : row_max; + } + + float denom = 0.0f; + for (uint32_t i = 0; i < WU_CASE24_ROW_N; ++i) { + const float score = + wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i)); + denom += wu_fexp_s(score - row_max); + } + + const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES; + for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) { + const uint32_t row_idx = frag % WU_CASE24_ROW_N; + const float score = + wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx)); + const float p = wu_fexp_s(score - row_max) / denom; + if (frag == 0) { + g_case24_p_bits[iter * NUM_THREADS + wu_tid()] = + wu_case24_f32_to_bits(p); + } + wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case24_pack_f16x2(p)); + } + asm volatile("fence rw, rw" ::: "memory"); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case24_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x8, x5, -%[num_scalar_warps]\n\t" + "slli x1, x8, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "mv x9, x8\n\t" + "1:\n\t" + "li x10, %[iter_n]\n\t" + "bge x9, x10, 9f\n\t" + "la x3, g_case24_q_row\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + "la x3, g_case24_zero_row\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "slli x6, x9, 2\n\t" + "la x7, g_case24_score_ready\n\t" + "add x7, x7, x6\n\t" + "li x6, %[score_ready_base]\n\t" + "or x6, x6, x9\n\t" + "sw x6, 0(x7)\n\t" + "slli x6, x9, 2\n\t" + "la x7, g_case24_p_ready\n\t" + "add x7, x7, x6\n\t" + "li x4, %[p_ready_base]\n\t" + "or x4, x4, x9\n\t" + "4:\n\t" + "lw x6, 0(x7)\n\t" + "bne x6, x4, 4b\n\t" + "la x3, g_case24_zero_row\n\t" + "li x7, 0\n\t" + "5:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 5b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "slli x6, x9, 10\n\t" + "la x3, g_case24_out\n\t" + "add x3, x3, x6\n\t" + "li x7, 0\n\t" + "6:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 6b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x6, x9, 2\n\t" + "la x7, g_case24_done\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x9\n\t" + "sw x6, 0(x7)\n\t" + "addi x9, x9, 2\n\t" + "j 1b\n\t" + "9:\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "10: j 10b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [iter_n] "i"(WU_CASE24_ITER_N), + [score_ready_base] "i"(WU_CASE24_SCORE_READY_BASE), + [p_ready_base] "i"(WU_CASE24_P_READY_BASE), + [done_base] "i"(WU_CASE24_DONE_BASE) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = vx_tensor_warp_mask(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_CASE24_ITER_N; ++i) { + g_case24_score_ready[i] = 0; + g_case24_p_ready[i] = 0; + g_case24_done[i] = 0; + } + for (uint32_t i = 0; i < WU_CASE24_ITER_N * NUM_THREADS; ++i) { + g_case24_score_bits[i] = 0; + g_case24_p_bits[i] = 0; + } + for (uint32_t i = 0; i < WU_CASE24_OUT_WORDS; ++i) { + g_case24_out[i] = 0; + } + g_case24_overlap_hint = 0; + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_ONE_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case24_worker); + } + asm volatile("fence rw, rw" ::: "memory"); + + for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) { + if (wu_case24_wait_status(g_case24_score_ready, iter, + WU_CASE24_SCORE_READY_BASE) != 0) { + if (tid == 0) { + g_case_mem[1] = 0x81u; + g_aux[0] = iter; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t slot = iter & 1u; + const uint32_t c_frag = + wu_bw_tmem_c_byte_base(slot) / WU_BW_TMEM_FRAGMENT_BYTES; + const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag); + g_case24_score_bits[iter * NUM_THREADS + tid] = observed; + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0 && g_case_mem[1] == 0 && + observed != WU_CASE24_FP32_THIRTY_TWO) { + g_aux[0] = iter; + g_aux[1] = observed; + g_case_mem[1] = 0x82u; + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_case24_softmax_tmem_row_to_p(iter, c_frag, + wu_bw_tmem_a_byte_base(slot)); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0 && g_case_mem[1] == 0) { + if (iter == 0 && + g_case24_score_ready[1] == (WU_CASE24_SCORE_READY_BASE | 1u)) { + g_case24_overlap_hint = 1; + } + g_case24_p_ready[iter] = WU_CASE24_P_READY_BASE | iter; + } + asm volatile("fence rw, rw" ::: "memory"); + } + + if (tid == 0) { + for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) { + if (g_case_mem[1] == 0 && + wu_case24_wait_status(g_case24_done, iter, WU_CASE24_DONE_BASE) != + 0) { + g_aux[0] = iter; + g_case_mem[1] = 0x83u; + } + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case24_out, WU_CASE24_OUT_WORDS, + WU_BW_FP32_ONE, &bad_actual); + if (bad != WU_CASE24_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x84u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/common_wu_min.h b/kernels/wu_arch_cases/common_wu_min.h index eb01b4f7..e77409e8 100644 --- a/kernels/wu_arch_cases/common_wu_min.h +++ b/kernels/wu_arch_cases/common_wu_min.h @@ -134,6 +134,14 @@ static inline int wu_is_leader() { return vx_core_id() == 0 && vx_warp_id() == 0 && vx_thread_id() == 0; } +static inline float wu_fexp_s(float value) { + float result; + asm volatile(".insn r %[custom1], 2, 0x30, %[rd], %[rs1], x0" + : [rd] "=f"(result) + : [rs1] "f"(value), [custom1] "i"(RISCV_CUSTOM1)); + return result; +} + static inline void wu_report_tohost(uint32_t exit_code) { asm volatile("fence rw, rw" ::: "memory"); tohost = (static_cast(exit_code) << 1) | 1u;