feat: add fp8 e4m3 decode support
This commit is contained in:
@@ -7,6 +7,37 @@ import chisel3._
|
|||||||
import chisel3.util._
|
import chisel3.util._
|
||||||
import freechips.rocketchip.tile
|
import freechips.rocketchip.tile
|
||||||
|
|
||||||
|
object FP8E4M3 {
|
||||||
|
private val Bias = 7
|
||||||
|
|
||||||
|
private def decodeToFloat(bits: Int): Float = {
|
||||||
|
val sign = (bits >> 7) & 0x1
|
||||||
|
val exp = (bits >> 3) & 0xf
|
||||||
|
val frac = bits & 0x7
|
||||||
|
|
||||||
|
val magnitude =
|
||||||
|
if (exp == 0) {
|
||||||
|
if (frac == 0) 0.0
|
||||||
|
else (frac.toDouble / 8.0) * Math.pow(2.0, 1 - Bias)
|
||||||
|
} else {
|
||||||
|
(1.0 + frac.toDouble / 8.0) * Math.pow(2.0, exp - Bias)
|
||||||
|
}
|
||||||
|
|
||||||
|
val value = if (sign == 1) -magnitude else magnitude
|
||||||
|
value.toFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
private def fp32Bits(bits: Int): BigInt = {
|
||||||
|
BigInt(java.lang.Float.floatToRawIntBits(decodeToFloat(bits)).toLong & 0xffffffffL)
|
||||||
|
}
|
||||||
|
|
||||||
|
def toFloat32(x: UInt): UInt = {
|
||||||
|
MuxLookup(x, 0.U(32.W))((0 until 256).map { bits =>
|
||||||
|
bits.U(8.W) -> fp32Bits(bits).U(32.W)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||||
// `half`: if True, generate fp16 MACs; if False fp32.
|
// `half`: if True, generate fp16 MACs; if False fp32.
|
||||||
class TensorDotProductUnit(
|
class TensorDotProductUnit(
|
||||||
|
|||||||
35
src/test/scala/radiance/FP8E4M3Test.scala
Normal file
35
src/test/scala/radiance/FP8E4M3Test.scala
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package radiance.core
|
||||||
|
|
||||||
|
import chisel3._
|
||||||
|
import chiseltest._
|
||||||
|
import org.scalatest.flatspec.AnyFlatSpec
|
||||||
|
|
||||||
|
class FP8E4M3DecodeHarness extends Module {
|
||||||
|
val io = IO(new Bundle {
|
||||||
|
val in = Input(UInt(8.W))
|
||||||
|
val out = Output(UInt(32.W))
|
||||||
|
})
|
||||||
|
|
||||||
|
io.out := FP8E4M3.toFloat32(io.in)
|
||||||
|
}
|
||||||
|
|
||||||
|
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||||
|
behavior of "FP8E4M3"
|
||||||
|
|
||||||
|
it should "decode representative E4M3 values to FP32 bits" in {
|
||||||
|
test(new FP8E4M3DecodeHarness) { c =>
|
||||||
|
Seq(
|
||||||
|
0x00 -> 0x00000000L,
|
||||||
|
0x80 -> 0x80000000L,
|
||||||
|
0x38 -> 0x3f800000L,
|
||||||
|
0x40 -> 0x40000000L,
|
||||||
|
0x30 -> 0x3f000000L,
|
||||||
|
0x3c -> 0x3fc00000L
|
||||||
|
).foreach { case (fp8, fp32) =>
|
||||||
|
c.io.in.poke(fp8.U)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.out.expect(fp32.U)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user