feat: add fp8 tensor dot product path
This commit is contained in:
@@ -38,17 +38,37 @@ object FP8E4M3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object TensorInputType extends Enumeration {
|
||||||
|
val FP16, FP32, FP8E4M3 = Value
|
||||||
|
|
||||||
|
def fromHalf(half: Boolean): Value = {
|
||||||
|
if (half) FP16 else FP32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||||
// `half`: if True, generate fp16 MACs; if False fp32.
|
// `half`: if True, generate fp16 MACs; if False fp32.
|
||||||
class TensorDotProductUnit(
|
class TensorDotProductUnit(
|
||||||
val dim: Int = 4,
|
val dim: Int,
|
||||||
val half: Boolean
|
val half: Boolean,
|
||||||
|
val inputType: TensorInputType.Value
|
||||||
) extends Module with tile.HasFPUParameters {
|
) extends Module with tile.HasFPUParameters {
|
||||||
val tIn = if (half) tile.FType.H else tile.FType.S
|
def this(dim: Int = 4, half: Boolean) = {
|
||||||
|
this(dim, half, TensorInputType.fromHalf(half))
|
||||||
|
}
|
||||||
|
|
||||||
|
val tIn = inputType match {
|
||||||
|
case TensorInputType.FP16 => tile.FType.H
|
||||||
|
case TensorInputType.FP32 => tile.FType.S
|
||||||
|
case TensorInputType.FP8E4M3 => tile.FType.S
|
||||||
|
}
|
||||||
// output datatype fixed to single-precision
|
// output datatype fixed to single-precision
|
||||||
val tOut = tile.FType.S
|
val tOut = tile.FType.S
|
||||||
|
|
||||||
val inFLen = tIn.ieeeWidth
|
val inFLen = inputType match {
|
||||||
|
case TensorInputType.FP8E4M3 => 8
|
||||||
|
case _ => tIn.ieeeWidth
|
||||||
|
}
|
||||||
val outFLen = tOut.ieeeWidth
|
val outFLen = tOut.ieeeWidth
|
||||||
val fLen = outFLen // needed for HasFPUParameters
|
val fLen = outFLen // needed for HasFPUParameters
|
||||||
val minFLen = 16 // fp16
|
val minFLen = 16 // fp16
|
||||||
@@ -71,9 +91,21 @@ class TensorDotProductUnit(
|
|||||||
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
|
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
|
||||||
// make sure recoding/uncoding happens only at the edge, not at every
|
// make sure recoding/uncoding happens only at the edge, not at every
|
||||||
// pipeline stage inside the dpu
|
// pipeline stage inside the dpu
|
||||||
val tag = if (half) H else S
|
val tag = inputType match {
|
||||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
case TensorInputType.FP16 => H
|
||||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
case TensorInputType.FP32 => S
|
||||||
|
case TensorInputType.FP8E4M3 => S
|
||||||
|
}
|
||||||
|
private def recodeInput(x: Bits): UInt = {
|
||||||
|
inputType match {
|
||||||
|
case TensorInputType.FP8E4M3 =>
|
||||||
|
unbox(recode(FP8E4M3.toFloat32(x.asUInt), S), S, Some(tIn))
|
||||||
|
case _ =>
|
||||||
|
unbox(recode(x.asUInt, tag), tag, Some(tIn))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val in1 = io.in.bits.a.map(recodeInput)
|
||||||
|
val in2 = io.in.bits.b.map(recodeInput)
|
||||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||||
|
|
||||||
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
||||||
|
|||||||
@@ -32,4 +32,52 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
|
||||||
|
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||||
|
c.io.in.valid.poke(true.B)
|
||||||
|
c.io.stall.poke(false.B)
|
||||||
|
for (i <- 0 until 8) {
|
||||||
|
c.io.in.bits.a(i).poke(0x38.U(8.W))
|
||||||
|
c.io.in.bits.b(i).poke(0x40.U(8.W))
|
||||||
|
}
|
||||||
|
c.io.in.bits.c.poke(0x3f800000L.U(32.W))
|
||||||
|
|
||||||
|
c.io.out.valid.expect(false.B)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.in.valid.poke(false.B)
|
||||||
|
c.io.out.valid.expect(false.B)
|
||||||
|
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.io.out.valid.expect(true.B)
|
||||||
|
c.io.out.bits.data.expect(0x41880000L.U)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
|
||||||
|
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||||
|
c.io.in.valid.poke(true.B)
|
||||||
|
c.io.stall.poke(false.B)
|
||||||
|
for (i <- 0 until 8) {
|
||||||
|
c.io.in.bits.a(i).poke(0x30.U(8.W))
|
||||||
|
c.io.in.bits.b(i).poke(0x3c.U(8.W))
|
||||||
|
}
|
||||||
|
c.io.in.bits.c.poke(0x40000000L.U(32.W))
|
||||||
|
|
||||||
|
c.io.out.valid.expect(false.B)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.in.valid.poke(false.B)
|
||||||
|
c.io.out.valid.expect(false.B)
|
||||||
|
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.io.out.valid.expect(true.B)
|
||||||
|
c.io.out.bits.data.expect(0x41000000L.U)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user