feat: use fp8 multiply for blackwell bwgmma
This commit is contained in:
@@ -38,6 +38,54 @@ object FP8E4M3 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object FP8E4M3MulToFloat32 {
|
||||||
|
private val Bias = 7
|
||||||
|
|
||||||
|
def apply(a: UInt, b: UInt): UInt = {
|
||||||
|
val sign = a(7) ^ b(7)
|
||||||
|
val expA = a(6, 3)
|
||||||
|
val expB = b(6, 3)
|
||||||
|
val fracA = a(2, 0)
|
||||||
|
val fracB = b(2, 0)
|
||||||
|
val zeroA = expA === 0.U && fracA === 0.U
|
||||||
|
val zeroB = expB === 0.U && fracB === 0.U
|
||||||
|
val isZero = zeroA || zeroB
|
||||||
|
|
||||||
|
val sigA = Mux(expA === 0.U, Cat(0.U(1.W), fracA), Cat(1.U(1.W), fracA))
|
||||||
|
val sigB = Mux(expB === 0.U, Cat(0.U(1.W), fracB), Cat(1.U(1.W), fracB))
|
||||||
|
val prodSig = sigA * sigB
|
||||||
|
|
||||||
|
val scaleA = Wire(SInt(6.W))
|
||||||
|
val scaleB = Wire(SInt(6.W))
|
||||||
|
scaleA := Mux(expA === 0.U, -9.S(6.W), expA.zext - (Bias + 3).S(6.W))
|
||||||
|
scaleB := Mux(expB === 0.U, -9.S(6.W), expB.zext - (Bias + 3).S(6.W))
|
||||||
|
|
||||||
|
val msb = Wire(UInt(3.W))
|
||||||
|
when(prodSig(7)) {
|
||||||
|
msb := 7.U
|
||||||
|
}.elsewhen(prodSig(6)) {
|
||||||
|
msb := 6.U
|
||||||
|
}.elsewhen(prodSig(5)) {
|
||||||
|
msb := 5.U
|
||||||
|
}.elsewhen(prodSig(4)) {
|
||||||
|
msb := 4.U
|
||||||
|
}.elsewhen(prodSig(3)) {
|
||||||
|
msb := 3.U
|
||||||
|
}.elsewhen(prodSig(2)) {
|
||||||
|
msb := 2.U
|
||||||
|
}.elsewhen(prodSig(1)) {
|
||||||
|
msb := 1.U
|
||||||
|
}.otherwise {
|
||||||
|
msb := 0.U
|
||||||
|
}
|
||||||
|
|
||||||
|
val normalized = (prodSig << (7.U - msb))(7, 0)
|
||||||
|
val exponent = (scaleA + scaleB + msb.zext + 127.S(10.W)).asUInt(7, 0)
|
||||||
|
val fraction = Cat(normalized(6, 0), 0.U(16.W))
|
||||||
|
Mux(isZero, Cat(sign, 0.U(31.W)), Cat(sign, exponent, fraction))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
object TensorInputType extends Enumeration {
|
object TensorInputType extends Enumeration {
|
||||||
val FP16, FP32, FP8E4M3 = Value
|
val FP16, FP32, FP8E4M3 = Value
|
||||||
|
|
||||||
@@ -96,14 +144,20 @@ class TensorDotProductUnit(
|
|||||||
case TensorInputType.FP32 => S
|
case TensorInputType.FP32 => S
|
||||||
case TensorInputType.FP8E4M3 => S
|
case TensorInputType.FP8E4M3 => S
|
||||||
}
|
}
|
||||||
private def recodeInput(x: Bits): UInt = {
|
if (inputType == TensorInputType.FP8E4M3) {
|
||||||
inputType match {
|
val dpu = Module(new DotProductPipeFP8E4M3(dim))
|
||||||
case TensorInputType.FP8E4M3 =>
|
dpu.io.in.valid := io.in.valid
|
||||||
unbox(recode(FP8E4M3.toFloat32(x.asUInt), S), S, Some(tIn))
|
dpu.io.in.bits.a := io.in.bits.a
|
||||||
case _ =>
|
dpu.io.in.bits.b := io.in.bits.b
|
||||||
|
dpu.io.in.bits.c := io.in.bits.c
|
||||||
|
dpu.io.stall := io.stall
|
||||||
|
|
||||||
|
io.out.valid := dpu.io.out.valid
|
||||||
|
io.out.bits.data := dpu.io.out.bits.data
|
||||||
|
} else {
|
||||||
|
def recodeInput(x: Bits): UInt = {
|
||||||
unbox(recode(x.asUInt, tag), tag, Some(tIn))
|
unbox(recode(x.asUInt, tag), tag, Some(tIn))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
val in1 = io.in.bits.a.map(recodeInput)
|
val in1 = io.in.bits.a.map(recodeInput)
|
||||||
val in2 = io.in.bits.b.map(recodeInput)
|
val in2 = io.in.bits.b.map(recodeInput)
|
||||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||||
@@ -117,6 +171,7 @@ class TensorDotProductUnit(
|
|||||||
|
|
||||||
io.out.valid := dpu.io.out.valid
|
io.out.valid := dpu.io.out.valid
|
||||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// An implementation of chisel3.util.Pipe that supports stalls.
|
// An implementation of chisel3.util.Pipe that supports stalls.
|
||||||
@@ -299,6 +354,89 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
io.out.bits.data := accStageOut.bits
|
io.out.bits.data := accStageOut.bits
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class DotProductPipeFP8E4M3(dim: Int) extends Module with tile.HasFPUParameters {
|
||||||
|
val tOut = tile.FType.S
|
||||||
|
val outExpWidth = tOut.exp
|
||||||
|
val outSigWidth = tOut.sig
|
||||||
|
val recOutFLen = outExpWidth + outSigWidth + 1
|
||||||
|
val fLen = tOut.ieeeWidth
|
||||||
|
val minFLen = 16
|
||||||
|
def xLen = 32
|
||||||
|
|
||||||
|
val io = IO(new Bundle {
|
||||||
|
val in = Flipped(Valid(new Bundle {
|
||||||
|
val a = Vec(dim, Bits(8.W))
|
||||||
|
val b = Vec(dim, Bits(8.W))
|
||||||
|
val c = Bits(32.W)
|
||||||
|
}))
|
||||||
|
val stall = Input(Bool())
|
||||||
|
val out = Valid(new Bundle {
|
||||||
|
val data = Bits(32.W)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
val productRecoded = io.in.bits.a.zip(io.in.bits.b).map { case (a, b) =>
|
||||||
|
unbox(recode(FP8E4M3MulToFloat32(a.asUInt, b.asUInt), S), S, Some(tOut))
|
||||||
|
}
|
||||||
|
val inC = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||||
|
|
||||||
|
val productStageOut = StallingPipe(io.stall, io.in.valid, VecInit(productRecoded))
|
||||||
|
val productStageC = StallingPipe(io.stall, io.in.valid, inC)
|
||||||
|
|
||||||
|
val log2Dim = log2Ceil(dim)
|
||||||
|
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
|
||||||
|
|
||||||
|
val interim = (log2Dim to 0 by -1).map { i =>
|
||||||
|
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
|
||||||
|
}
|
||||||
|
val interimC = (log2Dim to 0 by -1).map(_ => Wire(Valid(Bits(recOutFLen.W))))
|
||||||
|
interim(0) := productStageOut
|
||||||
|
interimC(0) := productStageC
|
||||||
|
|
||||||
|
val (addStageOut, addStageC) = (interim zip interimC).reduce {
|
||||||
|
(inputsAndC, outputsAndC) => {
|
||||||
|
val (inputs, inC) = inputsAndC
|
||||||
|
val (outputs, outC) = outputsAndC
|
||||||
|
|
||||||
|
require(inputs.bits.length == 2 * outputs.bits.length)
|
||||||
|
val thisDim = inputs.bits.length
|
||||||
|
val adders = Seq.fill(thisDim / 2)(
|
||||||
|
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||||
|
)
|
||||||
|
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
||||||
|
a.io.subOp := 0.U
|
||||||
|
a.io.a := inputs.bits(2 * i + 0)
|
||||||
|
a.io.b := inputs.bits(2 * i + 1)
|
||||||
|
a.io.roundingMode := hardfloat.consts.round_near_even
|
||||||
|
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||||
|
a.io.out
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
|
||||||
|
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
|
||||||
|
when(inputs.valid =/= inC.valid) {
|
||||||
|
printf("WARN: DotProductPipeFP8E4M3 input/C valid mismatch: inputs=%d c=%d\n",
|
||||||
|
inputs.valid, inC.valid)
|
||||||
|
}
|
||||||
|
|
||||||
|
(outputs, outC)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require(addStageOut.bits.length == 1)
|
||||||
|
|
||||||
|
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||||
|
acc.io.subOp := 0.U
|
||||||
|
acc.io.a := addStageOut.bits(0)
|
||||||
|
acc.io.b := addStageC.bits
|
||||||
|
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||||
|
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||||
|
|
||||||
|
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
|
||||||
|
|
||||||
|
io.out.valid := accStageOut.valid
|
||||||
|
io.out.bits.data := ieee(box(accStageOut.bits, S))
|
||||||
|
}
|
||||||
|
|
||||||
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||||
require(latency <= 2)
|
require(latency <= 2)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,16 @@ class FP8E4M3DecodeHarness extends Module {
|
|||||||
io.out := FP8E4M3.toFloat32(io.in)
|
io.out := FP8E4M3.toFloat32(io.in)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class FP8E4M3MulHarness extends Module {
|
||||||
|
val io = IO(new Bundle {
|
||||||
|
val a = Input(UInt(8.W))
|
||||||
|
val b = Input(UInt(8.W))
|
||||||
|
val out = Output(UInt(32.W))
|
||||||
|
})
|
||||||
|
|
||||||
|
io.out := FP8E4M3MulToFloat32(io.a, io.b)
|
||||||
|
}
|
||||||
|
|
||||||
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||||
behavior of "FP8E4M3"
|
behavior of "FP8E4M3"
|
||||||
|
|
||||||
@@ -33,6 +43,23 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "multiply E4M3 operands with FP8-width significands and return FP32 bits" in {
|
||||||
|
test(new FP8E4M3MulHarness) { c =>
|
||||||
|
Seq(
|
||||||
|
(0x38, 0x40, 0x40000000L), // 1.0 * 2.0 = 2.0
|
||||||
|
(0x30, 0x3c, 0x3f400000L), // 0.5 * 1.5 = 0.75
|
||||||
|
(0xb8, 0x40, 0xc0000000L), // -1.0 * 2.0 = -2.0
|
||||||
|
(0x00, 0x40, 0x00000000L), // 0.0 * 2.0 = 0.0
|
||||||
|
(0x80, 0x40, 0x80000000L) // -0.0 * 2.0 = -0.0
|
||||||
|
).foreach { case (a, b, out) =>
|
||||||
|
c.io.a.poke(a.U)
|
||||||
|
c.io.b.poke(b.U)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.out.expect(out.U)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
|
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
|
||||||
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||||
c.io.in.valid.poke(true.B)
|
c.io.in.valid.poke(true.B)
|
||||||
|
|||||||
@@ -68,22 +68,22 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
|
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
|
||||||
val bBase = BigInt(0x800)
|
val bBase = BigInt(0x800)
|
||||||
|
|
||||||
val fp16One = BigInt(0x3c00)
|
val fp8One = BigInt(0x38)
|
||||||
val fp32Zero = BigInt(0)
|
val fp32Zero = BigInt(0)
|
||||||
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
|
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
|
||||||
val fp32SixtyFour = BigInt(0x42800000L)
|
val fp32SixtyFour = BigInt(0x42800000L)
|
||||||
|
|
||||||
// Populate TMEM A at offset aBase (all 32 frags)
|
// Populate TMEM A at offset aBase (all 32 frags)
|
||||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
|
||||||
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
|
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
|
||||||
for (i <- 0 until 32) {
|
for (i <- 0 until 32) {
|
||||||
tmem(aBase / fragBytes + i) = aFrag
|
tmem(aBase / fragBytes + i) = aFrag
|
||||||
tmem(cBase / fragBytes + i) = cFrag
|
tmem(cBase / fragBytes + i) = cFrag
|
||||||
}
|
}
|
||||||
|
|
||||||
// SMEM B with fp16 2.0
|
// SMEM B with packed FP8 E4M3 2.0
|
||||||
val fp16Two = BigInt(0x4000)
|
val fp8Two = BigInt(0x40)
|
||||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
|
||||||
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
|
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
|
||||||
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
|
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
|
||||||
|
|
||||||
|
|||||||
@@ -249,13 +249,13 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
val bBase = BigInt(0x800)
|
val bBase = BigInt(0x800)
|
||||||
val cBase = BigInt(0x1000)
|
val cBase = BigInt(0x1000)
|
||||||
|
|
||||||
// A: all fp16 1.0 (0x3c00), 16 halves per frag
|
// A/B: packed FP8 E4M3 bytes, 32 elements per 256-bit frag
|
||||||
val fp16One = BigInt(0x3c00)
|
val fp8One = BigInt(0x38)
|
||||||
val fp16Two = BigInt(0x4000)
|
val fp8Two = BigInt(0x40)
|
||||||
val fp32One = BigInt(0x3f800000)
|
val fp32One = BigInt(0x3f800000)
|
||||||
val fp32SixtyFive = BigInt(0x42820000)
|
val fp32SixtyFive = BigInt(0x42820000)
|
||||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
|
||||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
|
||||||
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
|
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
|
||||||
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
|
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user