feat: add fp8 tensor dot product path
This commit is contained in:
@@ -38,17 +38,37 @@ object FP8E4M3 {
|
||||
}
|
||||
}
|
||||
|
||||
object TensorInputType extends Enumeration {
|
||||
val FP16, FP32, FP8E4M3 = Value
|
||||
|
||||
def fromHalf(half: Boolean): Value = {
|
||||
if (half) FP16 else FP32
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||
// `half`: if True, generate fp16 MACs; if False fp32.
|
||||
class TensorDotProductUnit(
|
||||
val dim: Int = 4,
|
||||
val half: Boolean
|
||||
val dim: Int,
|
||||
val half: Boolean,
|
||||
val inputType: TensorInputType.Value
|
||||
) extends Module with tile.HasFPUParameters {
|
||||
val tIn = if (half) tile.FType.H else tile.FType.S
|
||||
def this(dim: Int = 4, half: Boolean) = {
|
||||
this(dim, half, TensorInputType.fromHalf(half))
|
||||
}
|
||||
|
||||
val tIn = inputType match {
|
||||
case TensorInputType.FP16 => tile.FType.H
|
||||
case TensorInputType.FP32 => tile.FType.S
|
||||
case TensorInputType.FP8E4M3 => tile.FType.S
|
||||
}
|
||||
// output datatype fixed to single-precision
|
||||
val tOut = tile.FType.S
|
||||
|
||||
val inFLen = tIn.ieeeWidth
|
||||
val inFLen = inputType match {
|
||||
case TensorInputType.FP8E4M3 => 8
|
||||
case _ => tIn.ieeeWidth
|
||||
}
|
||||
val outFLen = tOut.ieeeWidth
|
||||
val fLen = outFLen // needed for HasFPUParameters
|
||||
val minFLen = 16 // fp16
|
||||
@@ -71,9 +91,21 @@ class TensorDotProductUnit(
|
||||
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
|
||||
// make sure recoding/uncoding happens only at the edge, not at every
|
||||
// pipeline stage inside the dpu
|
||||
val tag = if (half) H else S
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val tag = inputType match {
|
||||
case TensorInputType.FP16 => H
|
||||
case TensorInputType.FP32 => S
|
||||
case TensorInputType.FP8E4M3 => S
|
||||
}
|
||||
private def recodeInput(x: Bits): UInt = {
|
||||
inputType match {
|
||||
case TensorInputType.FP8E4M3 =>
|
||||
unbox(recode(FP8E4M3.toFloat32(x.asUInt), S), S, Some(tIn))
|
||||
case _ =>
|
||||
unbox(recode(x.asUInt, tag), tag, Some(tIn))
|
||||
}
|
||||
}
|
||||
val in1 = io.in.bits.a.map(recodeInput)
|
||||
val in2 = io.in.bits.b.map(recodeInput)
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||
|
||||
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
||||
|
||||
@@ -32,4 +32,52 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
|
||||
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
for (i <- 0 until 8) {
|
||||
c.io.in.bits.a(i).poke(0x38.U(8.W))
|
||||
c.io.in.bits.b(i).poke(0x40.U(8.W))
|
||||
}
|
||||
c.io.in.bits.c.poke(0x3f800000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x41880000L.U)
|
||||
}
|
||||
}
|
||||
|
||||
it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
|
||||
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
for (i <- 0 until 8) {
|
||||
c.io.in.bits.a(i).poke(0x30.U(8.W))
|
||||
c.io.in.bits.b(i).poke(0x3c.U(8.W))
|
||||
}
|
||||
c.io.in.bits.c.poke(0x40000000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x41000000L.U)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user