feat: add flash pipeline kernel cases
This commit is contained in:
@@ -19,7 +19,13 @@ CASES := \
|
||||
case14_flash_pv_k64 \
|
||||
case15_flash_softmax_pv_stage \
|
||||
case16_flash_full_pipeline \
|
||||
case17_flash_exp_softmax_probe
|
||||
case17_flash_exp_softmax_probe \
|
||||
case18_scalar_fexp \
|
||||
case20_flash_bwd_fused \
|
||||
case21_moe_gating \
|
||||
case22_gemm_silu \
|
||||
case23_softmax_only \
|
||||
case24_flash_sw_pipeline
|
||||
|
||||
SMOKE_CASES := \
|
||||
case00_boot_scalar \
|
||||
|
||||
@@ -25,6 +25,12 @@ This directory contains small bare-metal kernels for incremental Wu architecture
|
||||
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
|
||||
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
|
||||
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
|
||||
- `case18_scalar_fexp`: scalar `FEXP.S` numerical probe.
|
||||
- `case20_flash_bwd_fused`: FlashAttention backward-style fused 5xMMA plus scalar softmax/dsoftmax handoff.
|
||||
- `case21_moe_gating`: scalar `softmax -> Top-K -> scatter` MoE gating pipeline.
|
||||
- `case22_gemm_silu`: tensor GEMM followed by scalar SiLU activation.
|
||||
- `case23_softmax_only`: scalar-only stable softmax probe.
|
||||
- `case24_flash_sw_pipeline`: four-iteration ping-pong FlashAttention-style software pipeline.
|
||||
|
||||
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.
|
||||
|
||||
|
||||
@@ -4,10 +4,12 @@ Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
|
||||
Blackwell path.
|
||||
|
||||
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
|
||||
and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads
|
||||
the score through scalar TMEM load, records it, writes the uniform softmax result
|
||||
`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp.
|
||||
The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`.
|
||||
and `O_init=0.0`, producing a constant fp32 score of `32.0`. The scalar warp
|
||||
reads the score row through scalar TMEM loads, scans the row maximum and
|
||||
normalization denominator with scalar-only `FEXP.S`, converts each probability
|
||||
to packed fp16, writes `P` into TMEM A, refills SMEM with `V=2.0`, and releases
|
||||
the tensor warp. The tensor warp reloads TMEM C with `O=0.0`, then computes
|
||||
`O = P @ V`.
|
||||
|
||||
Every output word is expected to be fp32 `2.0`. This case covers the staged
|
||||
`QK -> scalar softmax handoff -> PV` loop without using the legacy
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
#define WU_CASE16_P_READY 0x9700u
|
||||
#define WU_CASE16_DONE_BASE 0x9800u
|
||||
|
||||
#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u
|
||||
#define WU_BW_FP32_ZERO 0x00000000u
|
||||
#define WU_BW_FP32_TWO 0x40000000u
|
||||
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
|
||||
#define WU_CASE16_ROW_N 32u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
|
||||
@@ -17,6 +17,85 @@ volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
|
||||
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case16_p_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value. No numeric conversion
// takes place: 0x3f800000 comes back as 1.0f.
static inline float wu_case16_bits_to_f32(uint32_t bits) {
  float value;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&value, &bits, sizeof(value));
  return value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value. No numeric conversion
// takes place: 1.0f comes back as 0x3f800000.
static inline uint32_t wu_case16_f32_to_bits(float value) {
  uint32_t bits;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
static inline uint16_t wu_case16_f32_to_f16_positive(float value) {
|
||||
const uint32_t bits = wu_case16_f32_to_bits(value);
|
||||
const uint32_t exp = (bits >> 23) & 0xffu;
|
||||
uint32_t mant = bits & 0x7fffffu;
|
||||
|
||||
if (exp == 0 || value <= 0.0f) {
|
||||
return 0;
|
||||
}
|
||||
if (exp >= 143u) {
|
||||
return 0x7c00u;
|
||||
}
|
||||
if (exp <= 112u) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t half_exp = exp - 112u;
|
||||
mant += 0x1000u;
|
||||
if (mant & 0x800000u) {
|
||||
mant = 0;
|
||||
++half_exp;
|
||||
}
|
||||
if (half_exp >= 31u) {
|
||||
return 0x7c00u;
|
||||
}
|
||||
return static_cast<uint16_t>((half_exp << 10) | (mant >> 13));
|
||||
}
|
||||
|
||||
static inline uint32_t wu_case16_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case16_f32_to_f16_positive(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
static inline void wu_case16_softmax_tmem_row_to_p(uint32_t score_frag_base,
|
||||
uint32_t p_byte_base) {
|
||||
float row_max = wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE16_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE16_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE16_ROW_N;
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
if (frag == 0) {
|
||||
g_case16_p_bits[wu_tid()] = wu_case16_f32_to_bits(p);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case16_pack_f16x2(p));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
|
||||
@@ -114,6 +193,7 @@ extern "C" int wu_main() {
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case16_scalar_seen[i] = 0;
|
||||
g_case16_p_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
@@ -141,8 +221,7 @@ extern "C" int wu_main() {
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0),
|
||||
WU_BW_FP16_ONE_OVER_32_PACKED);
|
||||
wu_case16_softmax_tmem_row_to_p(c_frag, wu_bw_tmem_a_byte_base(0));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# case17_flash_exp_softmax_probe
|
||||
|
||||
Validates that the current scalar Wu path can execute the `e^x` work needed by
|
||||
non-uniform FlashAttention softmax. The current ISA/RTL does not expose a
|
||||
dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to
|
||||
approximate exp.
|
||||
Validates that the scalar Wu path can execute the `e^x` work needed by
|
||||
non-uniform FlashAttention softmax through the custom scalar-only `FEXP.S`
|
||||
instruction.
|
||||
|
||||
The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically
|
||||
stable softmax computes `exp(score - row_max)`, so the exp inputs are
|
||||
@@ -12,5 +11,5 @@ inputs are loaded from volatile memory so the compiler cannot fold the result
|
||||
into constants.
|
||||
|
||||
This is intentionally separate from the tensor PV path. If this case fails, the
|
||||
problem is in scalar fp32 arithmetic, exp approximation, or normalization rather
|
||||
than TMEM handoff or BWGMMA.
|
||||
problem is in scalar fp32 `FEXP.S` execution or normalization rather than TMEM
|
||||
handoff or BWGMMA.
|
||||
|
||||
@@ -26,15 +26,6 @@ static inline float wu_case17_absf(float value) {
|
||||
return value < 0.0f ? -value : value;
|
||||
}
|
||||
|
||||
// Approximates exp(x) with the degree-5 Taylor polynomial around 0:
//   1 + x + x^2/2 + x^3/6 + x^4/24 + x^5/120
// As the name indicates, it is meant for x in [-ln(2), 0]; at x = -ln(2) the
// truncation error is about 1.5e-4.
static inline float wu_case17_exp_neg_ln2_to_0(float x) {
  // Horner form: five multiply-adds instead of forming each power explicitly.
  return 1.0f +
         x * (1.0f +
              x * (0.5f +
                   x * (0.1666666716f +
                        x * (0.0416666679f + x * 0.0083333338f))));
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
@@ -51,8 +42,8 @@ extern "C" int wu_main() {
|
||||
const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]);
|
||||
const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]);
|
||||
const float row_max = score0 > score1 ? score0 : score1;
|
||||
const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max);
|
||||
const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max);
|
||||
const float e0 = wu_fexp_s(score0 - row_max);
|
||||
const float e1 = wu_fexp_s(score1 - row_max);
|
||||
const float inv_sum = 1.0f / (e0 + e1);
|
||||
const float p0 = e0 * inv_sum;
|
||||
const float p1 = e1 * inv_sum;
|
||||
|
||||
3
kernels/wu_arch_cases/case18_scalar_fexp/Makefile
Normal file
3
kernels/wu_arch_cases/case18_scalar_fexp/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case18_scalar_fexp
|
||||
|
||||
include ../case.mk
|
||||
5
kernels/wu_arch_cases/case18_scalar_fexp/README.md
Normal file
5
kernels/wu_arch_cases/case18_scalar_fexp/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# case18_scalar_fexp
|
||||
|
||||
Verifies scalar-warp execution of the custom `FEXP.S` instruction. The test
|
||||
checks representative fp32 inputs used by softmax-style code paths and confirms
|
||||
the result is close to `expf`.
|
||||
68
kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp
Normal file
68
kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
extern "C" {
// Raw fp32 bit patterns of the four FEXP.S results, in input order. Thread 0
// clears and later fills this array.
volatile uint32_t g_case18_out_bits[4] __attribute__((aligned(16)));
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value. No numeric conversion
// takes place: 1.0f comes back as 0x3f800000.
static inline uint32_t f32_to_bits(float value) {
  uint32_t bits;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
// Absolute value for fp32 without depending on libm. Values that do not
// compare below zero (including -0.0f and NaN) are returned unchanged.
static inline float absf_local(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case18_out_bits[i] = 0;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const float input0 = 0.0f;
|
||||
const float input1 = 1.0f;
|
||||
const float input2 = -0.6931471805599453f;
|
||||
const float input3 = -10.0f;
|
||||
|
||||
const float out0 = wu_fexp_s(input0);
|
||||
const float out1 = wu_fexp_s(input1);
|
||||
const float out2 = wu_fexp_s(input2);
|
||||
const float out3 = wu_fexp_s(input3);
|
||||
|
||||
if (tid == 0) {
|
||||
g_case18_out_bits[0] = f32_to_bits(out0);
|
||||
g_case18_out_bits[1] = f32_to_bits(out1);
|
||||
g_case18_out_bits[2] = f32_to_bits(out2);
|
||||
g_case18_out_bits[3] = f32_to_bits(out3);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float err0 = absf_local(out0 - 1.0f);
|
||||
const float err1 = absf_local(out1 - 2.7182817459f);
|
||||
const float err2 = absf_local(out2 - 0.5f);
|
||||
const float err3 = absf_local(out3 - 0.00004539993f);
|
||||
if (err0 > 0.00001f || err1 > 0.0002f || err2 > 0.00001f ||
|
||||
err3 > 0.000001f) {
|
||||
g_aux[0] = g_case18_out_bits[0];
|
||||
g_aux[1] = g_case18_out_bits[1];
|
||||
g_aux[2] = g_case18_out_bits[2];
|
||||
g_aux[3] = g_case18_out_bits[3];
|
||||
wu_case_fail(0x18u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case19_tensor_fexp_illegal
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,5 @@
|
||||
# case19_tensor_fexp_illegal
|
||||
|
||||
Negative test for `FEXP.S`: tensor warps must not execute this scalar FPU
|
||||
instruction. Running this case is expected to trip the existing tensor-FPU
|
||||
illegal-instruction path in decode/dispatch rather than complete normally.
|
||||
25
kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp
Normal file
25
kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
asm volatile(
|
||||
"fmv.w.x f1, x0\n\t"
|
||||
".insn r %[custom1], 2, 0x30, f2, f1, x0\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
:
|
||||
: [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom1] "i"(RISCV_CUSTOM1)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||
|
||||
wu_case_fail(0x19u);
|
||||
return 1;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile
Normal file
3
kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case20_flash_bwd_fused
|
||||
|
||||
include ../case.mk
|
||||
19
kernels/wu_arch_cases/case20_flash_bwd_fused/README.md
Normal file
19
kernels/wu_arch_cases/case20_flash_bwd_fused/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# case20_flash_bwd_fused
|
||||
|
||||
FlashAttention backward-style fused pipeline smoke test.
|
||||
|
||||
The tensor warp performs one score MMA, then waits for the scalar warp to run
|
||||
softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back
|
||||
to TMEM A using signed fp16 values. The tensor warp then performs four more
|
||||
MMA steps, for five MMA operations total in this case.
|
||||
|
||||
This case verifies:
|
||||
|
||||
- tensor warp MMA sequencing around a scalar TMEM handoff;
|
||||
- scalar-only `FEXP.S` use for stable softmax;
|
||||
- dsoftmax shape `dS = P * (dP - sum(P * dP))`;
|
||||
- signed scalar TMEM stores feeding later tensor MMA operations.
|
||||
|
||||
The input score row is uniform, so `P = 1/32`. The synthetic upstream gradient
uses two buckets, `dP = 0` for row entries 0..15 and `dP = 2` for row entries
16..31, so `sum(P * dP) = 1`. This produces exact dS values `-1/32` for row
entries 0..15 and `+1/32` for row entries 16..31.
|
||||
289
kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp
Normal file
289
kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp
Normal file
@@ -0,0 +1,289 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE20_SCORE_READY 0xa000u
|
||||
#define WU_CASE20_DSOFTMAX_READY 0xa100u
|
||||
#define WU_CASE20_DONE_BASE 0xa200u
|
||||
#define WU_CASE20_ROW_N 32u
|
||||
#define WU_CASE20_FP32_ZERO 0x00000000u
|
||||
#define WU_CASE20_FP32_THIRTY_TWO 0x42000000u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case20_q_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||
WU_BW_FP16_ONE_PACKED};
|
||||
volatile uint32_t g_case20_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO,
|
||||
WU_CASE20_FP32_ZERO};
|
||||
volatile uint32_t g_case20_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case20_score_bits[4] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case20_dsoftmax_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value. No numeric conversion
// takes place: 0x3f800000 comes back as 1.0f.
static inline float wu_case20_bits_to_f32(uint32_t bits) {
  float value;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&value, &bits, sizeof(value));
  return value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value. No numeric conversion
// takes place: 1.0f comes back as 0x3f800000.
static inline uint32_t wu_case20_f32_to_bits(float value) {
  uint32_t bits;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
// Absolute value for fp32 without depending on libm. Values that do not
// compare below zero (including -0.0f and NaN) are returned unchanged.
static inline float wu_case20_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
static inline uint16_t wu_case20_f32_to_f16(float value) {
|
||||
const uint32_t bits = wu_case20_f32_to_bits(value);
|
||||
const uint32_t sign = (bits >> 16) & 0x8000u;
|
||||
const uint32_t exp = (bits >> 23) & 0xffu;
|
||||
uint32_t mant = bits & 0x7fffffu;
|
||||
|
||||
if ((bits & 0x7fffffffu) == 0 || exp == 0) {
|
||||
return static_cast<uint16_t>(sign);
|
||||
}
|
||||
if (exp >= 143u) {
|
||||
return static_cast<uint16_t>(sign | 0x7c00u);
|
||||
}
|
||||
if (exp <= 112u) {
|
||||
return static_cast<uint16_t>(sign);
|
||||
}
|
||||
|
||||
uint32_t half_exp = exp - 112u;
|
||||
mant += 0x1000u;
|
||||
if (mant & 0x800000u) {
|
||||
mant = 0;
|
||||
++half_exp;
|
||||
}
|
||||
if (half_exp >= 31u) {
|
||||
return static_cast<uint16_t>(sign | 0x7c00u);
|
||||
}
|
||||
return static_cast<uint16_t>(sign | (half_exp << 10) | (mant >> 13));
|
||||
}
|
||||
|
||||
static inline uint32_t wu_case20_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case20_f32_to_f16(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
// Synthetic upstream gradient dP for one row entry: 0.0 for entries 0..15 and
// 2.0 for every entry from 16 upward.
static inline float wu_case20_dp(uint32_t row_idx) {
  if (row_idx < 16u) {
    return 0.0f;
  }
  return 2.0f;
}
|
||||
|
||||
static inline void wu_case20_dsoftmax_tmem_row(uint32_t score_frag_base,
|
||||
uint32_t ds_byte_base) {
|
||||
float row_max = wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
float dot = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
dot += p * wu_case20_dp(i);
|
||||
}
|
||||
|
||||
const uint32_t ds_frag_base = ds_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE20_ROW_N;
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
const float ds = p * (wu_case20_dp(row_idx) - dot);
|
||||
if (wu_tid() == 0 && row_idx == 0) {
|
||||
g_case20_dsoftmax_bits[0] = wu_case20_f32_to_bits(ds);
|
||||
}
|
||||
if (wu_tid() == 0 && row_idx == 16u) {
|
||||
g_case20_dsoftmax_bits[1] = wu_case20_f32_to_bits(ds);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(ds_frag_base + frag, wu_case20_pack_f16x2(ds));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case20_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case20_q_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
"la x3, g_case20_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[score_ready]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[dsoftmax_ready]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
"la x3, g_case20_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"4:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 4b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case20_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"5:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 5b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"6: j 6b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[score_ready] "i"(WU_CASE20_SCORE_READY),
|
||||
[dsoftmax_ready] "i"(WU_CASE20_DSOFTMAX_READY),
|
||||
[done_base] "i"(WU_CASE20_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case20_out[i] = 0xffffffffu;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case20_score_bits[i] = 0;
|
||||
g_case20_dsoftmax_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case20_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE20_SCORE_READY) != 0) {
|
||||
g_case_mem[1] = 0x41u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
if (tid == 0) {
|
||||
g_case20_score_bits[0] = observed;
|
||||
if (g_case_mem[1] == 0 && observed != WU_CASE20_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = observed;
|
||||
g_case_mem[1] = 0x42u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_case20_dsoftmax_tmem_row(c_frag, wu_bw_tmem_a_byte_base(0));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0) {
|
||||
const float neg = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[0]);
|
||||
const float pos = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[1]);
|
||||
if (wu_case20_absf(neg + 0.03125f) > 0.0002f ||
|
||||
wu_case20_absf(pos - 0.03125f) > 0.0002f) {
|
||||
g_aux[0] = g_case20_dsoftmax_bits[0];
|
||||
g_aux[1] = g_case20_dsoftmax_bits[1];
|
||||
g_case_mem[1] = 0x43u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
g_case_mem[0] = WU_CASE20_DSOFTMAX_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE20_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x44u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case20_out, WU_BW_OUT_WORDS,
|
||||
WU_CASE20_FP32_ZERO, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x45u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case21_moe_gating/Makefile
Normal file
3
kernels/wu_arch_cases/case21_moe_gating/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case21_moe_gating
|
||||
|
||||
include ../case.mk
|
||||
10
kernels/wu_arch_cases/case21_moe_gating/README.md
Normal file
10
kernels/wu_arch_cases/case21_moe_gating/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# case21_moe_gating
|
||||
|
||||
MoE gating scalar pipeline test.
|
||||
|
||||
This case runs `softmax -> Top-K -> scatter` on scalar warp 0 using `FEXP.S`.
|
||||
The logits are `log(1), log(2), log(4), log(8)`, so the expected probabilities
|
||||
are `1/15, 2/15, 4/15, 8/15`. Top-2 should select experts 3 and 2, then scatter
|
||||
the token id and weight into the selected expert slots.
|
||||
|
||||
No tensor warp is spawned in this case.
|
||||
128
kernels/wu_arch_cases/case21_moe_gating/kernel.cpp
Normal file
128
kernels/wu_arch_cases/case21_moe_gating/kernel.cpp
Normal file
@@ -0,0 +1,128 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
// Token id scattered into the selected expert slots.
#define WU_CASE21_TOKEN_ID 0x21u
// Marker for an unused expert slot / unset top-k index.
#define WU_CASE21_EMPTY 0xffffffffu

extern "C" {
// fp32 bit patterns of the gating logits log(1), log(2), log(4), log(8),
// i.e. 0, ln2, 2*ln2, 3*ln2. Kept volatile so the values are loaded at run
// time instead of being folded by the compiler.
volatile uint32_t g_case21_logits_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f317218u, 0x3fb17218u, 0x40051592u};
// Softmax probabilities as fp32 bit patterns, one per expert.
volatile uint32_t g_case21_prob_bits[4] __attribute__((aligned(16)));
// Expert indices of the largest and second-largest probability.
volatile uint32_t g_case21_top_idx[2] __attribute__((aligned(16)));
// Per-expert token slot; WU_CASE21_EMPTY when the expert was not selected.
volatile uint32_t g_case21_expert_token[4] __attribute__((aligned(16)));
// Per-expert gating weight (fp32 bits); 0 when the expert was not selected.
volatile uint32_t g_case21_expert_weight_bits[4] __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value. No numeric conversion
// takes place: 0x3f800000 comes back as 1.0f.
static inline float wu_case21_bits_to_f32(uint32_t bits) {
  float value;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&value, &bits, sizeof(value));
  return value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value. No numeric conversion
// takes place: 1.0f comes back as 0x3f800000.
static inline uint32_t wu_case21_f32_to_bits(float value) {
  uint32_t bits;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
// Absolute value for fp32 without depending on libm. Values that do not
// compare below zero (including -0.0f and NaN) are returned unchanged.
static inline float wu_case21_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case21_prob_bits[i] = 0;
|
||||
g_case21_expert_token[i] = WU_CASE21_EMPTY;
|
||||
g_case21_expert_weight_bits[i] = 0;
|
||||
}
|
||||
g_case21_top_idx[0] = WU_CASE21_EMPTY;
|
||||
g_case21_top_idx[1] = WU_CASE21_EMPTY;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
float logits[4];
|
||||
float row_max = wu_case21_bits_to_f32(g_case21_logits_bits[0]);
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
logits[i] = wu_case21_bits_to_f32(g_case21_logits_bits[i]);
|
||||
row_max = logits[i] > row_max ? logits[i] : row_max;
|
||||
}
|
||||
|
||||
float exp_values[4];
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
exp_values[i] = wu_fexp_s(logits[i] - row_max);
|
||||
denom += exp_values[i];
|
||||
}
|
||||
|
||||
float probs[4];
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
probs[i] = exp_values[i] / denom;
|
||||
}
|
||||
|
||||
uint32_t top0 = 0;
|
||||
uint32_t top1 = 1;
|
||||
if (probs[top1] > probs[top0]) {
|
||||
const uint32_t tmp = top0;
|
||||
top0 = top1;
|
||||
top1 = tmp;
|
||||
}
|
||||
for (uint32_t i = 2; i < 4; ++i) {
|
||||
if (probs[i] > probs[top0]) {
|
||||
top1 = top0;
|
||||
top0 = i;
|
||||
} else if (probs[i] > probs[top1]) {
|
||||
top1 = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case21_prob_bits[i] = wu_case21_f32_to_bits(probs[i]);
|
||||
}
|
||||
g_case21_top_idx[0] = top0;
|
||||
g_case21_top_idx[1] = top1;
|
||||
g_case21_expert_token[top0] = WU_CASE21_TOKEN_ID;
|
||||
g_case21_expert_token[top1] = WU_CASE21_TOKEN_ID;
|
||||
g_case21_expert_weight_bits[top0] = wu_case21_f32_to_bits(probs[top0]);
|
||||
g_case21_expert_weight_bits[top1] = wu_case21_f32_to_bits(probs[top1]);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected[4] = {0.0666666701f, 0.1333333403f,
|
||||
0.2666666806f, 0.5333333611f};
|
||||
const float tolerance = 0.0015f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
if (wu_case21_absf(probs[i] - expected[i]) > tolerance) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_case21_prob_bits[i];
|
||||
wu_case_fail(0x21u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (top0 != 3u || top1 != 2u ||
|
||||
g_case21_expert_token[3] != WU_CASE21_TOKEN_ID ||
|
||||
g_case21_expert_token[2] != WU_CASE21_TOKEN_ID ||
|
||||
g_case21_expert_token[0] != WU_CASE21_EMPTY ||
|
||||
g_case21_expert_token[1] != WU_CASE21_EMPTY) {
|
||||
g_aux[0] = top0;
|
||||
g_aux[1] = top1;
|
||||
g_aux[2] = g_case21_expert_token[3];
|
||||
g_aux[3] = g_case21_expert_token[2];
|
||||
wu_case_fail(0x22u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case22_gemm_silu/Makefile
Normal file
3
kernels/wu_arch_cases/case22_gemm_silu/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case22_gemm_silu
|
||||
|
||||
include ../case.mk
|
||||
10
kernels/wu_arch_cases/case22_gemm_silu/README.md
Normal file
10
kernels/wu_arch_cases/case22_gemm_silu/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# case22_gemm_silu
|
||||
|
||||
GEMM plus SiLU fusion smoke test.
|
||||
|
||||
The tensor warp computes a compact GEMM with fp16 `A = 0.125` and fp16 `B = 1`,
|
||||
producing fp32 `C = 4`. Scalar warp 0 reads TMEM C and applies
|
||||
`SiLU(x) = x / (1 + exp(-x))` using scalar-only `FEXP.S`.
|
||||
|
||||
This case verifies the common `matmul -> nonlinear activation` fusion path
|
||||
without allowing tensor warp FPU execution.
|
||||
160
kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp
Normal file
160
kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp
Normal file
@@ -0,0 +1,160 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE22_INIT_BASE 0xa400u
|
||||
#define WU_CASE22_DONE_BASE 0xa500u
|
||||
#define WU_CASE22_FP16_ONE_EIGHTH_PACKED 0x30003000u
|
||||
#define WU_CASE22_FP32_FOUR 0x40800000u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case22_a_row[4] __attribute__((aligned(16))) = {
|
||||
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED,
|
||||
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED};
|
||||
volatile uint32_t g_case22_zero_row[4] __attribute__((aligned(16))) = {
|
||||
0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u};
|
||||
volatile uint32_t g_case22_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case22_silu_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value. No numeric conversion
// takes place: 0x40800000 comes back as 4.0f.
static inline float wu_case22_bits_to_f32(uint32_t bits) {
  float value;
  // Copying the object representation is well-defined in both C and C++;
  // reading the inactive member of a union is only a compiler extension in C++.
  __builtin_memcpy(&value, &bits, sizeof(value));
  return value;
}
|
||||
|
||||
// Returns the raw binary32 bit pattern of a float (no numeric conversion).
static inline uint32_t wu_case22_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
|
||||
|
||||
// Magnitude of a float, computed with a comparison so no libm call is needed.
static inline float wu_case22_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
// Tensor-warp entry point for case22, written as a naked function in raw
// assembly. It never returns: it ends in a self-loop after publishing its done
// word.
//
// Sequence, as visible below:
//   1. Fill the warp's TMEM A tile from g_case22_a_row, 16 bytes per step.
//   2. Fill the warp's TMEM C tile from g_case22_zero_row.
//   3. Issue the custom3 matrix op with rd = C base, rs1 = A base,
//      rs2 = shared-memory base.
//   4. Copy the C tile out to g_case22_out, 16 bytes per step.
//   5. Store (WU_CASE22_DONE_BASE | warp_id) to g_seen[warp_id].
//
// NOTE(review): the meanings of the custom3 funct3 values (2 = load into TMEM,
// 3 = fence/commit, 0 = MMA, 1 = wait, 6 = store from TMEM) and of the custom0
// instruction (presumably "stop this warp") are inferred from how they are used
// here and from the README; confirm against the ISA definition.
extern "C" void __attribute__((naked, noinline, used)) tensor_case22_worker() {
  asm volatile(
      // x1 = (warp_id - NUM_SCALAR_WARPS) << 11 : byte base of this warp's slot.
      // x2 = x1 + WU_BW_TMEM_C_BYTE_OFFSET      : byte base of the C tile.
      "csrr x5, %[csr_wid]\n\t"
      "addi x1, x5, -%[num_scalar_warps]\n\t"
      "slli x1, x1, 11\n\t"
      "addi x2, x1, %[c_offset]\n\t"
      // Loop 1: walk the A tile in 16-byte steps; the source pointer x3 is not
      // advanced, so the same row is replicated across the whole tile.
      "la x3, g_case22_a_row\n\t"
      "li x7, 0\n\t"
      "1:\n\t"
      "add x4, x1, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 1b\n\t"
      // Loop 2: same walk over the C tile, sourcing the all-zero row.
      "la x3, g_case22_zero_row\n\t"
      "li x7, 0\n\t"
      "2:\n\t"
      "add x4, x2, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 2b\n\t"
      // Presumably orders the tile fills before the matrix op -- TODO confirm.
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Matrix op: rd = C base (x2), rs1 = A base (x1), rs2 = SMEM base (x4).
      "li x4, %[smem_base]\n\t"
      ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
      // Presumably waits for the matrix op to complete -- TODO confirm.
      ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
      // Loop 3: copy the C tile out; here both the TMEM offset (x4) and the
      // memory destination (x6) advance by 16 bytes per step.
      "la x3, g_case22_out\n\t"
      "li x7, 0\n\t"
      "3:\n\t"
      "add x4, x2, x7\n\t"
      "add x6, x3, x7\n\t"
      ".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 3b\n\t"
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Publish completion: g_seen[warp_id] = WU_CASE22_DONE_BASE | warp_id.
      "csrr x5, %[csr_wid]\n\t"
      "slli x6, x5, 2\n\t"
      "la x7, g_seen\n\t"
      "add x7, x7, x6\n\t"
      "li x6, %[done_base]\n\t"
      "or x6, x6, x5\n\t"
      "sw x6, 0(x7)\n\t"
      // custom0 with all-zero operands, then a self-loop as a backstop.
      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
      "4: j 4b\n\t"
      :
      : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
        [custom3] "i"(RISCV_CUSTOM3),
        [num_scalar_warps] "i"(NUM_SCALAR_WARPS),
        [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
        [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
        [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
        [done_base] "i"(WU_CASE22_DONE_BASE)
      : "memory");
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case22_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case22_silu_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case22_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE22_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x51u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed_bits = wu_bw_scalar_tmem_ld(c_frag);
|
||||
const float observed = wu_case22_bits_to_f32(observed_bits);
|
||||
const float silu = observed / (1.0f + wu_fexp_s(-observed));
|
||||
|
||||
if (tid == 0) {
|
||||
g_case22_silu_bits[0] = wu_case22_f32_to_bits(silu);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0 && observed_bits != WU_CASE22_FP32_FOUR) {
|
||||
g_aux[0] = observed_bits;
|
||||
g_case_mem[1] = 0x52u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
const float expected = 3.9280550480f;
|
||||
if (wu_case22_absf(silu - expected) > 0.004f) {
|
||||
g_aux[0] = g_case22_silu_bits[0];
|
||||
g_case_mem[1] = 0x53u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case22_out, WU_BW_OUT_WORDS,
|
||||
WU_CASE22_FP32_FOUR, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x54u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case23_softmax_only/Makefile
Normal file
3
kernels/wu_arch_cases/case23_softmax_only/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case23_softmax_only
|
||||
|
||||
include ../case.mk
|
||||
9
kernels/wu_arch_cases/case23_softmax_only/README.md
Normal file
9
kernels/wu_arch_cases/case23_softmax_only/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# case23_softmax_only
|
||||
|
||||
Scalar softmax-only test.
|
||||
|
||||
This case runs a stable 4-way softmax on scalar warp 0 using `FEXP.S`. The logits
|
||||
are `log(1), log(3), log(5), log(7)`, giving expected probabilities
|
||||
`1/16, 3/16, 5/16, 7/16`.
|
||||
|
||||
No tensor warp is spawned in this case.
|
||||
83
kernels/wu_arch_cases/case23_softmax_only/kernel.cpp
Normal file
83
kernels/wu_arch_cases/case23_softmax_only/kernel.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
extern "C" {
// Softmax logits as raw binary32 bit patterns:
//   log(1) = 0.0, log(3) = 1.0986123, log(5) = 1.6094379, log(7) = 1.9459101.
// With these inputs exp(logit) is 1, 3, 5, 7, so the expected probabilities are
// 1/16, 3/16, 5/16 and 7/16.
// The log(5) entry was previously 0x3fcdf854 (about 1.609141), which is not the
// nearest binary32 value to log(5); the test tolerance merely masked the error.
volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f8c9f54u, 0x3fce0210u, 0x3ff91395u};
// Raw binary32 bits of the four computed probabilities, written by lane 0.
volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as a binary32 float (no numeric conversion).
static inline float wu_case23_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float as_float;
  } pun;
  pun.raw = bits;
  return pun.as_float;
}
|
||||
|
||||
// Returns the raw binary32 bit pattern of a float (no numeric conversion).
static inline uint32_t wu_case23_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
|
||||
|
||||
// Magnitude of a float, computed with a comparison so no libm call is needed.
static inline float wu_case23_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case23_out_bits[i] = 0;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
float scores[4];
|
||||
float row_max = wu_case23_bits_to_f32(g_case23_scores_bits[0]);
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
scores[i] = wu_case23_bits_to_f32(g_case23_scores_bits[i]);
|
||||
row_max = scores[i] > row_max ? scores[i] : row_max;
|
||||
}
|
||||
|
||||
float exp_values[4];
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
exp_values[i] = wu_fexp_s(scores[i] - row_max);
|
||||
denom += exp_values[i];
|
||||
}
|
||||
|
||||
float probs[4];
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
probs[i] = exp_values[i] / denom;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case23_out_bits[i] = wu_case23_f32_to_bits(probs[i]);
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f};
|
||||
const float tolerance = 0.0015f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
if (wu_case23_absf(probs[i] - expected[i]) > tolerance) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_case23_out_bits[i];
|
||||
wu_case_fail(0x23u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile
Normal file
3
kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case24_flash_sw_pipeline
|
||||
|
||||
include ../case.mk
|
||||
24
kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md
Normal file
24
kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# case24_flash_sw_pipeline
|
||||
|
||||
Software-pipelined FlashAttention-style multi-iteration case.
|
||||
|
||||
This case keeps `case16_flash_full_pipeline` as the single-tile
|
||||
producer/consumer baseline and adds a four-iteration ping-pong pipeline:
|
||||
|
||||
```text
|
||||
tensor warp 2 / slot 0: QK(0) -> wait P(0) -> PV(0) -> QK(2) -> wait P(2) -> PV(2)
|
||||
tensor warp 3 / slot 1: QK(1) -> wait P(1) -> PV(1) -> QK(3) -> wait P(3) -> PV(3)
|
||||
scalar warp 0: softmax(0) -> softmax(1) -> softmax(2) -> softmax(3)
|
||||
```
|
||||
|
||||
Each tensor warp owns one TMEM slot. The tensor warp writes `S = Q @ K` into
|
||||
TMEM C for its slot, marks `score_ready[iter]`, waits for scalar-generated
|
||||
`P`, then computes `O = P @ V`. Scalar warp 0 waits on each score in order,
|
||||
uses scalar-only `FEXP.S` for stable softmax, writes packed fp16 probabilities
|
||||
back to the same slot's TMEM A, and marks `p_ready[iter]`.
|
||||
|
||||
The first version intentionally uses constant `Q`, `K`, and `V` so the expected
|
||||
numeric result is simple: every score is fp32 `32.0`, every softmax row is
|
||||
uniform `1/32`, and every output word is fp32 `1.0`. The test objective is the
|
||||
multi-iteration overlap structure and per-slot handoff, not non-uniform
|
||||
FlashAttention numerics.
|
||||
318
kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp
Normal file
318
kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp
Normal file
@@ -0,0 +1,318 @@
|
||||
// Spin budget used by wu_case24_wait_status below. It is defined ahead of the
// common header include.
// NOTE(review): presumably the header only supplies its own default when this
// macro is not already defined -- confirm in common_wu_blackwell_fa.h.
#define WU_CASE_WAIT_SPIN 16384u
|
||||
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Number of pipeline iterations and number of scores read per softmax row.
#define WU_CASE24_ITER_N 4u
#define WU_CASE24_ROW_N 32u
// Tags OR-ed with the iteration index to form the per-iteration status words.
#define WU_CASE24_SCORE_READY_BASE 0xb000u
#define WU_CASE24_P_READY_BASE 0xb100u
#define WU_CASE24_DONE_BASE 0xb200u
// IEEE-754 binary32 encodings of 0.0 and 32.0.
#define WU_CASE24_FP32_ZERO 0x00000000u
#define WU_CASE24_FP32_THIRTY_TWO 0x42000000u
// Total output words: one WU_BW_OUT_WORDS-sized tile per iteration.
#define WU_CASE24_OUT_WORDS (WU_CASE24_ITER_N * WU_BW_OUT_WORDS)

extern "C" {
// One 16-byte row of packed fp16 ones, replayed across the whole A tile.
volatile uint32_t g_case24_q_row[4] __attribute__((aligned(16))) = {
    WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
    WU_BW_FP16_ONE_PACKED};
// One 16-byte row of zeros, used to clear the C tile before each matrix op.
volatile uint32_t g_case24_zero_row[4] __attribute__((aligned(16))) = {
    WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO,
    WU_CASE24_FP32_ZERO};
// Written by the tensor worker once the scores of iteration i are in TMEM C.
volatile uint32_t g_case24_score_ready[WU_CASE24_ITER_N]
    __attribute__((aligned(16)));
// Written by scalar lane 0 once P for iteration i is in TMEM A; the tensor
// worker spins on it.
volatile uint32_t g_case24_p_ready[WU_CASE24_ITER_N]
    __attribute__((aligned(16)));
// Written by the tensor worker after the output tile of iteration i is stored.
volatile uint32_t g_case24_done[WU_CASE24_ITER_N] __attribute__((aligned(16)));
// Diagnostics: per-iteration, per-lane raw score bits and first-fragment P bits.
volatile uint32_t g_case24_score_bits[WU_CASE24_ITER_N * NUM_THREADS]
    __attribute__((aligned(16)));
volatile uint32_t g_case24_p_bits[WU_CASE24_ITER_N * NUM_THREADS]
    __attribute__((aligned(16)));
// Set to 1 when score_ready[1] is already published by the end of iteration 0;
// it is recorded but not checked by the pass/fail logic.
volatile uint32_t g_case24_overlap_hint __attribute__((aligned(16)));
// Output tiles of all iterations, back to back; verified against fp32 1.0.
volatile uint32_t g_case24_out[WU_CASE24_OUT_WORDS]
    __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as a binary32 float (no numeric conversion).
static inline float wu_case24_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float as_float;
  } pun;
  pun.raw = bits;
  return pun.as_float;
}
|
||||
|
||||
// Returns the raw binary32 bit pattern of a float (no numeric conversion).
static inline uint32_t wu_case24_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
|
||||
|
||||
// Converts a positive, normal binary32 value to IEEE-754 half-precision bits.
//
// Simplified on purpose:
//   - zero, negative and binary32-subnormal inputs return 0;
//   - values below the smallest normal half (biased exponent <= 112) flush to 0;
//   - values at or above 2^16 (biased exponent >= 143, which also covers
//     infinities and NaNs) return the half infinity pattern 0x7c00;
//   - rounding adds half an output ULP (0x1000) before truncating, so ties
//     round up; a mantissa carry bumps the exponent and may reach infinity.
static inline uint16_t wu_case24_f32_to_f16_positive(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun = {value};
  const uint32_t raw = pun.raw;
  const uint32_t biased_exp = (raw >> 23) & 0xffu;

  if (value <= 0.0f || biased_exp == 0) {
    return 0;
  }
  if (biased_exp >= 143u) {
    return 0x7c00u;
  }
  if (biased_exp <= 112u) {
    return 0;
  }

  // Re-bias from 127 (binary32) to 15 (half): 127 - 15 = 112.
  uint32_t h_exp = biased_exp - 112u;
  uint32_t frac = (raw & 0x7fffffu) + 0x1000u;
  if ((frac & 0x800000u) != 0) {
    // Rounding carried out of the 23-bit mantissa.
    frac = 0;
    h_exp += 1;
  }
  if (h_exp >= 31u) {
    return 0x7c00u;
  }
  return static_cast<uint16_t>((h_exp << 10) | (frac >> 13));
}
|
||||
|
||||
static inline uint32_t wu_case24_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case24_f32_to_f16_positive(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
// Polls status[iter] until it equals (base | iter) or the spin budget runs out.
//
// Performs at most WU_CASE_WAIT_SPIN reads of the volatile status word.
// Returns 0 as soon as the expected word is observed, 1 on timeout.
static inline int wu_case24_wait_status(volatile uint32_t *status,
                                        uint32_t iter, uint32_t base) {
  const uint32_t wanted = base | iter;
  uint32_t budget = WU_CASE_WAIT_SPIN;
  while (budget != 0) {
    if (status[iter] == wanted) {
      return 0;
    }
    --budget;
  }
  return 1;
}
|
||||
|
||||
static inline void wu_case24_softmax_tmem_row_to_p(uint32_t iter,
|
||||
uint32_t score_frag_base,
|
||||
uint32_t p_byte_base) {
|
||||
float row_max = wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE24_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE24_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE24_ROW_N;
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
if (frag == 0) {
|
||||
g_case24_p_bits[iter * NUM_THREADS + wu_tid()] =
|
||||
wu_case24_f32_to_bits(p);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case24_pack_f16x2(p));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
// Tensor-warp entry point for case24, written as a naked function in raw
// assembly. It never returns: it ends in a self-loop.
//
// Each tensor warp derives slot = warp_id - NUM_SCALAR_WARPS and handles the
// iterations slot, slot + 2, ... below WU_CASE24_ITER_N. Per iteration:
//   1. Fill TMEM A from g_case24_q_row and TMEM C from g_case24_zero_row.
//   2. Run the matrix op (scores), then publish score_ready[iter].
//   3. Spin (with no timeout) until p_ready[iter] is published by the scalar
//      warp.
//   4. Clear TMEM C again and run the matrix op a second time on the A tile.
//   5. Copy the C tile to g_case24_out + iter * 1024 bytes, then publish
//      done[iter].
//
// NOTE(review): the meanings of the custom3 funct3 values (2 = load into TMEM,
// 3 = fence/commit, 0 = MMA, 1 = wait, 6 = store from TMEM) and of the custom0
// instruction (presumably "stop this warp") are inferred from how they are used
// here and from the README; confirm against the ISA definition.
// NOTE(review): the hard-coded shifts assume a 2048-byte TMEM stride per slot
// (<< 11) and a 1024-byte output tile per iteration (<< 10); confirm they match
// the header constants (e.g. WU_BW_OUT_WORDS * 4).
extern "C" void __attribute__((naked, noinline, used)) tensor_case24_worker() {
  asm volatile(
      // x8 = slot, x1 = slot << 11 (A base), x2 = x1 + c_offset (C base),
      // x9 = current iteration, starting at the slot index.
      "csrr x5, %[csr_wid]\n\t"
      "addi x8, x5, -%[num_scalar_warps]\n\t"
      "slli x1, x8, 11\n\t"
      "addi x2, x1, %[c_offset]\n\t"
      "mv x9, x8\n\t"
      // Iteration loop head: leave once x9 >= WU_CASE24_ITER_N.
      "1:\n\t"
      "li x10, %[iter_n]\n\t"
      "bge x9, x10, 9f\n\t"
      // Fill the A tile; the source row pointer x3 is not advanced.
      "la x3, g_case24_q_row\n\t"
      "li x7, 0\n\t"
      "2:\n\t"
      "add x4, x1, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 2b\n\t"
      // Clear the C tile.
      "la x3, g_case24_zero_row\n\t"
      "li x7, 0\n\t"
      "3:\n\t"
      "add x4, x2, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 3b\n\t"
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // First matrix op: rd = C base, rs1 = A base, rs2 = SMEM base.
      "li x4, %[smem_base]\n\t"
      ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
      ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
      // score_ready[iter] = WU_CASE24_SCORE_READY_BASE | iter.
      "slli x6, x9, 2\n\t"
      "la x7, g_case24_score_ready\n\t"
      "add x7, x7, x6\n\t"
      "li x6, %[score_ready_base]\n\t"
      "or x6, x6, x9\n\t"
      "sw x6, 0(x7)\n\t"
      // Spin until p_ready[iter] == WU_CASE24_P_READY_BASE | iter.
      "slli x6, x9, 2\n\t"
      "la x7, g_case24_p_ready\n\t"
      "add x7, x7, x6\n\t"
      "li x4, %[p_ready_base]\n\t"
      "or x4, x4, x9\n\t"
      "4:\n\t"
      "lw x6, 0(x7)\n\t"
      "bne x6, x4, 4b\n\t"
      // Clear the C tile again before the second matrix op.
      "la x3, g_case24_zero_row\n\t"
      "li x7, 0\n\t"
      "5:\n\t"
      "add x4, x2, x7\n\t"
      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 5b\n\t"
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // Second matrix op with the same operands; by now the A tile holds what
      // the scalar warp stored there before publishing p_ready.
      "li x4, %[smem_base]\n\t"
      ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
      ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
      // Copy the C tile to g_case24_out + (iter << 10).
      "slli x6, x9, 10\n\t"
      "la x3, g_case24_out\n\t"
      "add x3, x3, x6\n\t"
      "li x7, 0\n\t"
      "6:\n\t"
      "add x4, x2, x7\n\t"
      "add x6, x3, x7\n\t"
      ".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
      "addi x7, x7, 16\n\t"
      "li x4, %[tile_bytes]\n\t"
      "blt x7, x4, 6b\n\t"
      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
      // done[iter] = WU_CASE24_DONE_BASE | iter.
      "slli x6, x9, 2\n\t"
      "la x7, g_case24_done\n\t"
      "add x7, x7, x6\n\t"
      "li x6, %[done_base]\n\t"
      "or x6, x6, x9\n\t"
      "sw x6, 0(x7)\n\t"
      // Advance by two iterations and loop.
      "addi x9, x9, 2\n\t"
      "j 1b\n\t"
      // Exit: custom0 with all-zero operands, then a self-loop as a backstop.
      "9:\n\t"
      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
      "10: j 10b\n\t"
      :
      : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
        [custom3] "i"(RISCV_CUSTOM3),
        [num_scalar_warps] "i"(NUM_SCALAR_WARPS),
        [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
        [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
        [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
        [iter_n] "i"(WU_CASE24_ITER_N),
        [score_ready_base] "i"(WU_CASE24_SCORE_READY_BASE),
        [p_ready_base] "i"(WU_CASE24_P_READY_BASE),
        [done_base] "i"(WU_CASE24_DONE_BASE)
      : "memory");
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE24_ITER_N; ++i) {
|
||||
g_case24_score_ready[i] = 0;
|
||||
g_case24_p_ready[i] = 0;
|
||||
g_case24_done[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < WU_CASE24_ITER_N * NUM_THREADS; ++i) {
|
||||
g_case24_score_bits[i] = 0;
|
||||
g_case24_p_bits[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < WU_CASE24_OUT_WORDS; ++i) {
|
||||
g_case24_out[i] = 0;
|
||||
}
|
||||
g_case24_overlap_hint = 0;
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case24_worker);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
|
||||
if (wu_case24_wait_status(g_case24_score_ready, iter,
|
||||
WU_CASE24_SCORE_READY_BASE) != 0) {
|
||||
if (tid == 0) {
|
||||
g_case_mem[1] = 0x81u;
|
||||
g_aux[0] = iter;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t slot = iter & 1u;
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(slot) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case24_score_bits[iter * NUM_THREADS + tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
observed != WU_CASE24_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = iter;
|
||||
g_aux[1] = observed;
|
||||
g_case_mem[1] = 0x82u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_case24_softmax_tmem_row_to_p(iter, c_frag,
|
||||
wu_bw_tmem_a_byte_base(slot));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0) {
|
||||
if (iter == 0 &&
|
||||
g_case24_score_ready[1] == (WU_CASE24_SCORE_READY_BASE | 1u)) {
|
||||
g_case24_overlap_hint = 1;
|
||||
}
|
||||
g_case24_p_ready[iter] = WU_CASE24_P_READY_BASE | iter;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_case24_wait_status(g_case24_done, iter, WU_CASE24_DONE_BASE) !=
|
||||
0) {
|
||||
g_aux[0] = iter;
|
||||
g_case_mem[1] = 0x83u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case24_out, WU_CASE24_OUT_WORDS,
|
||||
WU_BW_FP32_ONE, &bad_actual);
|
||||
if (bad != WU_CASE24_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x84u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -134,6 +134,14 @@ static inline int wu_is_leader() {
|
||||
return vx_core_id() == 0 && vx_warp_id() == 0 && vx_thread_id() == 0;
|
||||
}
|
||||
|
||||
// Issues the custom scalar exponential instruction (named FEXP.S in the case
// READMEs) and returns its result.
//
// The instruction is emitted as a raw R-type encoding on the custom1 opcode
// with funct3 = 2 and funct7 = 0x30. `value` is bound to a floating-point
// register as rs1, rs2 is x0, and the result is read from a floating-point rd.
// NOTE(review): the callers use the result as e^value; accuracy and rounding
// are properties of the hardware and are not visible here.
static inline float wu_fexp_s(float value) {
  float result;
  asm volatile(".insn r %[custom1], 2, 0x30, %[rd], %[rs1], x0"
               : [rd] "=f"(result)
               : [rs1] "f"(value), [custom1] "i"(RISCV_CUSTOM1));
  return result;
}
|
||||
|
||||
static inline void wu_report_tohost(uint32_t exit_code) {
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
tohost = (static_cast<uint64_t>(exit_code) << 1) | 1u;
|
||||
|
||||
Reference in New Issue
Block a user