feat: add fp8 tensor dot product path

This commit is contained in:
Zhongdi LUO
2026-07-02 08:06:42 +00:00
parent 2afb96bb14
commit 68a7f66046
2 changed files with 87 additions and 7 deletions

View File

@@ -38,17 +38,37 @@ object FP8E4M3 {
}
}
/** Operand formats that can be selected for the tensor dot-product unit. */
object TensorInputType extends Enumeration {
  val FP16 = Value
  val FP32 = Value
  val FP8E4M3 = Value

  /** Translates the boolean `half` flag into an input type.
    *
    * @param half `true` selects [[FP16]], `false` selects [[FP32]]
    * @return the matching enumeration value; never [[FP8E4M3]]
    */
  def fromHalf(half: Boolean): Value = half match {
    case true  => FP16
    case false => FP32
  }
}
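// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
// `inputType`: operand format (FP16, FP32 or FP8E4M3). FP8E4M3 operands are
// 8 bits wide at the interface and use the fp32 type inside the unit.
// `half`: when `inputType` is not given, selects FP16 (true) or FP32 (false).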
class TensorDotProductUnit(
val dim: Int = 4,
val half: Boolean
val dim: Int,
val half: Boolean,
val inputType: TensorInputType.Value
) extends Module with tile.HasFPUParameters {
val tIn = if (half) tile.FType.H else tile.FType.S
def this(dim: Int = 4, half: Boolean) = {
this(dim, half, TensorInputType.fromHalf(half))
}
val tIn = inputType match {
case TensorInputType.FP16 => tile.FType.H
case TensorInputType.FP32 => tile.FType.S
case TensorInputType.FP8E4M3 => tile.FType.S
}
// output datatype fixed to single-precision
val tOut = tile.FType.S
val inFLen = tIn.ieeeWidth
val inFLen = inputType match {
case TensorInputType.FP8E4M3 => 8
case _ => tIn.ieeeWidth
}
val outFLen = tOut.ieeeWidth
val fLen = outFLen // needed for HasFPUParameters
val minFLen = 16 // fp16
@@ -71,9 +91,21 @@ class TensorDotProductUnit(
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
// make sure recoding/uncoding happens only at the edge, not at every
// pipeline stage inside the dpu
val tag = if (half) H else S
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
val tag = inputType match {
case TensorInputType.FP16 => H
case TensorInputType.FP32 => S
case TensorInputType.FP8E4M3 => S
}
// Converts one raw operand into the unboxed recoded form consumed by the
// dot-product pipe. FP8E4M3 operands are first passed through
// FP8E4M3.toFloat32 and then recoded with the single-precision tag S;
// every other input type is recoded directly with `tag`.
private def recodeInput(x: Bits): UInt = {
  val raw = x.asUInt
  if (inputType == TensorInputType.FP8E4M3) {
    val widened = FP8E4M3.toFloat32(raw)
    unbox(recode(widened, S), S, Some(tIn))
  } else {
    unbox(recode(raw, tag), tag, Some(tIn))
  }
}
val in1 = io.in.bits.a.map(recodeInput)
val in2 = io.in.bits.b.map(recodeInput)
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val dpu = Module(new DotProductPipe(dim, tIn, tOut))

View File

@@ -32,4 +32,52 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
}
}
}
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
  test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { dut =>
    // Under the standard E4M3 encoding (bias 7), 0x38 is 1.0 and 0x40 is 2.0.
    val lhs = 0x38.U(8.W)
    val rhs = 0x40.U(8.W)
    // fp32 accumulator input: 1.0
    val acc = 0x3f800000L.U(32.W)
    // 8 * (1.0 * 2.0) + 1.0 = 17.0 in fp32
    val expected = 0x41880000L.U

    // Present one set of operands on every lane.
    dut.io.in.valid.poke(true.B)
    dut.io.stall.poke(false.B)
    (0 until 8).foreach { lane =>
      dut.io.in.bits.a(lane).poke(lhs)
      dut.io.in.bits.b(lane).poke(rhs)
    }
    dut.io.in.bits.c.poke(acc)
    dut.io.out.valid.expect(false.B)

    // Drop valid after a single cycle so only one operation is in flight.
    dut.clock.step()
    dut.io.in.valid.poke(false.B)
    dut.io.out.valid.expect(false.B)

    // The result is expected five clock steps after the operands were driven.
    dut.clock.step(4)
    dut.io.out.valid.expect(true.B)
    dut.io.out.bits.data.expect(expected)
  }
}
it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
  test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { dut =>
    // Under the standard E4M3 encoding (bias 7), 0x30 is 0.5 and 0x3c is 1.5.
    val lhs = 0x30.U(8.W)
    val rhs = 0x3c.U(8.W)
    // fp32 accumulator input: 2.0
    val acc = 0x40000000L.U(32.W)
    // 8 * (0.5 * 1.5) + 2.0 = 8.0 in fp32
    val expected = 0x41000000L.U

    // Present one set of operands on every lane.
    dut.io.in.valid.poke(true.B)
    dut.io.stall.poke(false.B)
    (0 until 8).foreach { lane =>
      dut.io.in.bits.a(lane).poke(lhs)
      dut.io.in.bits.b(lane).poke(rhs)
    }
    dut.io.in.bits.c.poke(acc)
    dut.io.out.valid.expect(false.B)

    // Drop valid after a single cycle so only one operation is in flight.
    dut.clock.step()
    dut.io.in.valid.poke(false.B)
    dut.io.out.valid.expect(false.B)

    // The result is expected five clock steps after the operands were driven.
    dut.clock.step(4)
    dut.io.out.valid.expect(true.B)
    dut.io.out.bits.data.expect(expected)
  }
}
}