84 lines
1.9 KiB
C++
84 lines
1.9 KiB
C++
#include "../common_wu_min.h"
|
|
|
|
extern "C" {

// Input row for the softmax case, stored as raw IEEE-754 binary32 bit
// patterns. Each entry is the float nearest to {0, ln 3, ln 5, ln 7}, so
// exp() of the row is {1, 3, 5, 7} and the analytic softmax is
// {1, 3, 5, 7} / 16 -- the expected values wu_main compares against.
volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = {
    0x00000000u,  // 0.0f       = ln 1
    0x3f8c9f54u,  // 1.0986123f ~ ln 3
    0x3fce0210u,  // 1.6094379f ~ ln 5
    0x3ff91395u,  // 1.9459101f ~ ln 7
};

// Softmax output: thread 0 of wu_main stores each probability here as raw
// float bits (cleared to 0 at the start of the case).
volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16)));

}
|
|
|
|
// Reinterpret a raw IEEE-754 binary32 bit pattern as a float.
//
// Uses a 4-byte memcpy instead of a union: reading a union member other than
// the one last written is undefined behaviour in C++ (GCC/Clang only accept it
// as an extension), whereas memcpy-based punning is well-defined. The builtin
// is used so no extra header is needed; compilers typically lower it to a
// plain register move.
static inline float wu_case23_bits_to_f32(uint32_t bits) {
    float value;
    __builtin_memcpy(&value, &bits, sizeof value);
    return value;
}
|
|
|
|
// Return the raw IEEE-754 binary32 bit pattern of a float (sign, exponent and
// mantissa preserved exactly, including for -0.0f).
//
// Uses a 4-byte memcpy instead of a union: reading a union member other than
// the one last written is undefined behaviour in C++ (GCC/Clang only accept it
// as an extension), whereas memcpy-based punning is well-defined.
static inline uint32_t wu_case23_f32_to_bits(float value) {
    uint32_t bits;
    __builtin_memcpy(&bits, &value, sizeof bits);
    return bits;
}
|
|
|
|
// Absolute value of a float via a single comparison.
// Only strictly negative inputs are negated, so -0.0f and NaN are returned
// unchanged (no sign-bit clearing as fabsf would do).
static inline float wu_case23_absf(float value) {
    if (value < 0.0f) {
        return -value;
    }
    return value;
}
|
|
|
|
extern "C" int wu_main() {
|
|
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
|
return 0;
|
|
}
|
|
|
|
const uint32_t tid = wu_tid();
|
|
if (tid == 0) {
|
|
wu_case_reset();
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
g_case23_out_bits[i] = 0;
|
|
}
|
|
}
|
|
asm volatile("fence rw, rw" ::: "memory");
|
|
|
|
float scores[4];
|
|
float row_max = wu_case23_bits_to_f32(g_case23_scores_bits[0]);
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
scores[i] = wu_case23_bits_to_f32(g_case23_scores_bits[i]);
|
|
row_max = scores[i] > row_max ? scores[i] : row_max;
|
|
}
|
|
|
|
float exp_values[4];
|
|
float denom = 0.0f;
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
exp_values[i] = wu_fexp_s(scores[i] - row_max);
|
|
denom += exp_values[i];
|
|
}
|
|
|
|
float probs[4];
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
probs[i] = exp_values[i] / denom;
|
|
}
|
|
|
|
if (tid == 0) {
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
g_case23_out_bits[i] = wu_case23_f32_to_bits(probs[i]);
|
|
}
|
|
}
|
|
asm volatile("fence rw, rw" ::: "memory");
|
|
|
|
if (tid == 0) {
|
|
const float expected[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f};
|
|
const float tolerance = 0.0015f;
|
|
for (uint32_t i = 0; i < 4; ++i) {
|
|
if (wu_case23_absf(probs[i] - expected[i]) > tolerance) {
|
|
g_aux[0] = i;
|
|
g_aux[1] = g_case23_out_bits[i];
|
|
wu_case_fail(0x23u);
|
|
return 1;
|
|
}
|
|
}
|
|
wu_case_pass();
|
|
}
|
|
return 0;
|
|
}
|