From 2afb96bb145afb180c006c012504dc4d9da3dd46 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Thu, 2 Jul 2026 07:59:01 +0000 Subject: [PATCH] feat: add fp8 e4m3 decode support --- src/main/scala/radiance/core/TensorDPU.scala | 31 +++++++++++++++++ src/test/scala/radiance/FP8E4M3Test.scala | 35 ++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 src/test/scala/radiance/FP8E4M3Test.scala diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index d1a2377..5666dc6 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -7,6 +7,37 @@ import chisel3._ import chisel3.util._ import freechips.rocketchip.tile +object FP8E4M3 { + private val Bias = 7 + + private def decodeToFloat(bits: Int): Float = { + val sign = (bits >> 7) & 0x1 + val exp = (bits >> 3) & 0xf + val frac = bits & 0x7 + + val magnitude = + if (exp == 0) { + if (frac == 0) 0.0 + else (frac.toDouble / 8.0) * Math.pow(2.0, 1 - Bias) + } else { + (1.0 + frac.toDouble / 8.0) * Math.pow(2.0, exp - Bias) + } + + val value = if (sign == 1) -magnitude else magnitude + value.toFloat + } + + private def fp32Bits(bits: Int): BigInt = { + BigInt(java.lang.Float.floatToRawIntBits(decodeToFloat(bits)).toLong & 0xffffffffL) + } + + def toFloat32(x: UInt): UInt = { + MuxLookup(x, 0.U(32.W))((0 until 256).map { bits => + bits.U(8.W) -> fp32Bits(bits).U(32.W) + }) + } +} + // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // `half`: if True, generate fp16 MACs; if False fp32. class TensorDotProductUnit( diff --git a/src/test/scala/radiance/FP8E4M3Test.scala b/src/test/scala/radiance/FP8E4M3Test.scala new file mode 100644 index 0000000..eba0a35 --- /dev/null +++ b/src/test/scala/radiance/FP8E4M3Test.scala @@ -0,0 +1,35 @@ +package radiance.core + +import chisel3._ +import chiseltest._ +import org.scalatest.flatspec.AnyFlatSpec + +class FP8E4M3DecodeHarness extends Module { + val io = IO(new Bundle { + val in = Input(UInt(8.W)) + val out = Output(UInt(32.W)) + }) + + io.out := FP8E4M3.toFloat32(io.in) +} + +class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester { + behavior of "FP8E4M3" + + it should "decode representative E4M3 values to FP32 bits" in { + test(new FP8E4M3DecodeHarness) { c => + Seq( + 0x00 -> 0x00000000L, + 0x80 -> 0x80000000L, + 0x38 -> 0x3f800000L, + 0x40 -> 0x40000000L, + 0x30 -> 0x3f000000L, + 0x3c -> 0x3fc00000L + ).foreach { case (fp8, fp32) => + c.io.in.poke(fp8.U) + c.clock.step() + c.io.out.expect(fp32.U) + } + } + } +}