Files
kernels/kernels/wu_arch_cases/case23_softmax_only/kernel.cpp
2026-07-02 07:24:59 +00:00

84 lines
1.9 KiB
C++

#include "../common_wu_min.h"
// Input and output buffers for case 23. Declared inside an extern "C" block so
// the symbols have C linkage (unmangled names).
extern "C" {
// Softmax inputs, stored as raw 32-bit patterns and reinterpreted as float by
// wu_main. Assuming float is IEEE-754 binary32, they decode to approximately
// 0.0, 1.09861, 1.60914 and 1.94591.
// NOTE(review): the second and fourth values match ln(3) and ln(7) to float
// precision, but the third (~1.60914) is about 3e-4 below ln(5) (~1.60944).
// The result still lands within wu_main's tolerance; confirm this is intended.
volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = {
0x00000000u, 0x3f8c9f54u, 0x3fcdf854u, 0x3ff91395u};
// Output buffer: wu_main's thread 0 zeroes it, then stores the raw bit
// patterns of the four computed probabilities here.
volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16)));
}
// Reinterprets a 32-bit pattern as a float: the bits are copied unchanged,
// no numeric conversion takes place.
// NOTE(review): this reads a union member other than the one last written.
// C permits that; in C++ it relies on the compiler supporting union punning.
static inline float wu_case23_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float as_float;
  } pun;
  pun.raw = bits;
  return pun.as_float;
}
// Returns the raw 32-bit pattern of a float: the bits are copied unchanged,
// no numeric conversion takes place.
// NOTE(review): this reads a union member other than the one last written.
// C permits that; in C++ it relies on the compiler supporting union punning.
static inline uint32_t wu_case23_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
// Absolute value of a float without calling into a math library.
// Only values that compare less than zero are negated, so -0.0f and NaN are
// returned unchanged.
static inline float wu_case23_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// Case 23 entry point: computes a 4-element softmax over
// g_case23_scores_bits, stores the result in g_case23_out_bits and checks it
// against the reference {1/16, 3/16, 5/16, 7/16}.
//
// Only threads of warp 0 on core 0 do any work. Each of them computes the
// softmax, but only thread 0 resets the case state, writes the output and
// reports pass/fail.
//
// Returns 1 from thread 0 when a probability is out of tolerance (after
// recording the failing index and its stored bit pattern in g_aux[0..1] and
// calling wu_case_fail(0x23u)); returns 0 in every other case.
extern "C" int wu_main() {
  // Every warp other than warp 0 of core 0 exits immediately.
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }

  const uint32_t tid = wu_tid();

  // Thread 0 clears the case state and the output buffer before the run.
  if (tid == 0) {
    wu_case_reset();
    for (uint32_t i = 0; i < 4; ++i) {
      g_case23_out_bits[i] = 0;
    }
  }
  // Hardware fence; the "memory" clobber also acts as a compiler barrier.
  asm volatile("fence rw, rw" ::: "memory");

  // Load the scores and find their maximum.
  float scores[4];
  float row_max = wu_case23_bits_to_f32(g_case23_scores_bits[0]);
  for (uint32_t i = 0; i < 4; ++i) {
    scores[i] = wu_case23_bits_to_f32(g_case23_scores_bits[i]);
    row_max = scores[i] > row_max ? scores[i] : row_max;
  }

  // Subtract the maximum before exponentiating so that every argument passed
  // to wu_fexp_s is <= 0 (wu_fexp_s is presumably e^x; defined elsewhere).
  float exp_values[4];
  float denom = 0.0f;
  for (uint32_t i = 0; i < 4; ++i) {
    exp_values[i] = wu_fexp_s(scores[i] - row_max);
    denom += exp_values[i];
  }

  // Normalise by the sum of the exponentials.
  float probs[4];
  for (uint32_t i = 0; i < 4; ++i) {
    probs[i] = exp_values[i] / denom;
  }

  // Thread 0 publishes the raw bit patterns of the probabilities.
  if (tid == 0) {
    for (uint32_t i = 0; i < 4; ++i) {
      g_case23_out_bits[i] = wu_case23_f32_to_bits(probs[i]);
    }
  }
  asm volatile("fence rw, rw" ::: "memory");

  if (tid == 0) {
    const float expected[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f};
    const float tolerance = 0.0015f;
    for (uint32_t i = 0; i < 4; ++i) {
      const float err = wu_case23_absf(probs[i] - expected[i]);
      // Written as !(err <= tolerance) rather than (err > tolerance): every
      // ordered comparison with NaN is false, so the latter would let a NaN
      // probability pass the case unnoticed.
      if (!(err <= tolerance)) {
        g_aux[0] = i;                     // index of the failing element
        g_aux[1] = g_case23_out_bits[i];  // bit pattern read back from memory
        wu_case_fail(0x23u);
        return 1;
      }
    }
    wu_case_pass();
  }
  return 0;
}