From ca4c48251d26df971e1dc97d2b5239448e4e7d00 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Fri, 3 Jul 2026 08:40:10 +0000 Subject: [PATCH] feat: use fp8 multiply for blackwell bwgmma --- src/main/scala/radiance/core/TensorDPU.scala | 176 ++++++++++++++++-- src/test/scala/radiance/FP8E4M3Test.scala | 27 +++ .../TensorCoreBlackwellExtendedTest.scala | 10 +- .../radiance/TensorCoreBlackwellTest.scala | 10 +- 4 files changed, 194 insertions(+), 29 deletions(-) diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 2504974..8bed9b6 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -38,6 +38,54 @@ object FP8E4M3 { } } +object FP8E4M3MulToFloat32 { + private val Bias = 7 + + def apply(a: UInt, b: UInt): UInt = { + val sign = a(7) ^ b(7) + val expA = a(6, 3) + val expB = b(6, 3) + val fracA = a(2, 0) + val fracB = b(2, 0) + val zeroA = expA === 0.U && fracA === 0.U + val zeroB = expB === 0.U && fracB === 0.U + val isZero = zeroA || zeroB + + val sigA = Mux(expA === 0.U, Cat(0.U(1.W), fracA), Cat(1.U(1.W), fracA)) + val sigB = Mux(expB === 0.U, Cat(0.U(1.W), fracB), Cat(1.U(1.W), fracB)) + val prodSig = sigA * sigB + + val scaleA = Wire(SInt(6.W)) + val scaleB = Wire(SInt(6.W)) + scaleA := Mux(expA === 0.U, -9.S(6.W), expA.zext - (Bias + 3).S(6.W)) + scaleB := Mux(expB === 0.U, -9.S(6.W), expB.zext - (Bias + 3).S(6.W)) + + val msb = Wire(UInt(3.W)) + when(prodSig(7)) { + msb := 7.U + }.elsewhen(prodSig(6)) { + msb := 6.U + }.elsewhen(prodSig(5)) { + msb := 5.U + }.elsewhen(prodSig(4)) { + msb := 4.U + }.elsewhen(prodSig(3)) { + msb := 3.U + }.elsewhen(prodSig(2)) { + msb := 2.U + }.elsewhen(prodSig(1)) { + msb := 1.U + }.otherwise { + msb := 0.U + } + + val normalized = (prodSig << (7.U - msb))(7, 0) + val exponent = (scaleA + scaleB + msb.zext + 127.S(10.W)).asUInt(7, 0) + val fraction = Cat(normalized(6, 0), 0.U(16.W)) + Mux(isZero, Cat(sign, 0.U(31.W)), Cat(sign, exponent, fraction)) + } +} + object TensorInputType extends Enumeration { val FP16, FP32, FP8E4M3 = Value @@ -96,27 +144,34 @@ class TensorDotProductUnit( case TensorInputType.FP32 => S case TensorInputType.FP8E4M3 => S } - private def recodeInput(x: Bits): UInt = { - inputType match { - case TensorInputType.FP8E4M3 => - unbox(recode(FP8E4M3.toFloat32(x.asUInt), S), S, Some(tIn)) - case _ => - unbox(recode(x.asUInt, tag), tag, Some(tIn)) + if (inputType == TensorInputType.FP8E4M3) { + val dpu = Module(new DotProductPipeFP8E4M3(dim)) + dpu.io.in.valid := io.in.valid + dpu.io.in.bits.a := io.in.bits.a + dpu.io.in.bits.b := io.in.bits.b + dpu.io.in.bits.c := io.in.bits.c + dpu.io.stall := io.stall + + io.out.valid := dpu.io.out.valid + io.out.bits.data := dpu.io.out.bits.data + } else { + def recodeInput(x: Bits): UInt = { + unbox(recode(x.asUInt, tag), tag, Some(tIn)) } + val in1 = io.in.bits.a.map(recodeInput) + val in2 = io.in.bits.b.map(recodeInput) + val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) + + val dpu = Module(new DotProductPipe(dim, tIn, tOut)) + dpu.io.in.valid := io.in.valid + dpu.io.in.bits.a := in1 + dpu.io.in.bits.b := in2 + dpu.io.in.bits.c := in3 + dpu.io.stall := io.stall + + io.out.valid := dpu.io.out.valid + io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } - val in1 = io.in.bits.a.map(recodeInput) - val in2 = io.in.bits.b.map(recodeInput) - val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) - - val dpu = Module(new DotProductPipe(dim, tIn, tOut)) - dpu.io.in.valid := io.in.valid - dpu.io.in.bits.a := in1 - dpu.io.in.bits.b := in2 - dpu.io.in.bits.c := in3 - dpu.io.stall := io.stall - - io.out.valid := dpu.io.out.valid - io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } // An implementation of chisel3.util.Pipe that supports stalls. @@ -299,6 +354,89 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex io.out.bits.data := accStageOut.bits } +class DotProductPipeFP8E4M3(dim: Int) extends Module with tile.HasFPUParameters { + val tOut = tile.FType.S + val outExpWidth = tOut.exp + val outSigWidth = tOut.sig + val recOutFLen = outExpWidth + outSigWidth + 1 + val fLen = tOut.ieeeWidth + val minFLen = 16 + def xLen = 32 + + val io = IO(new Bundle { + val in = Flipped(Valid(new Bundle { + val a = Vec(dim, Bits(8.W)) + val b = Vec(dim, Bits(8.W)) + val c = Bits(32.W) + })) + val stall = Input(Bool()) + val out = Valid(new Bundle { + val data = Bits(32.W) + }) + }) + + val productRecoded = io.in.bits.a.zip(io.in.bits.b).map { case (a, b) => + unbox(recode(FP8E4M3MulToFloat32(a.asUInt, b.asUInt), S), S, Some(tOut)) + } + val inC = unbox(recode(io.in.bits.c, S), S, Some(tOut)) + + val productStageOut = StallingPipe(io.stall, io.in.valid, VecInit(productRecoded)) + val productStageC = StallingPipe(io.stall, io.in.valid, inC) + + val log2Dim = log2Ceil(dim) + require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!") + + val interim = (log2Dim to 0 by -1).map { i => + Wire(Valid(Vec(1 << i, Bits(recOutFLen.W)))) + } + val interimC = (log2Dim to 0 by -1).map(_ => Wire(Valid(Bits(recOutFLen.W)))) + interim(0) := productStageOut + interimC(0) := productStageC + + val (addStageOut, addStageC) = (interim zip interimC).reduce { + (inputsAndC, outputsAndC) => { + val (inputs, inC) = inputsAndC + val (outputs, outC) = outputsAndC + + require(inputs.bits.length == 2 * outputs.bits.length) + val thisDim = inputs.bits.length + val adders = Seq.fill(thisDim / 2)( + Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) + ) + val addOuts = adders.zipWithIndex.map { case (a, i) => + a.io.subOp := 0.U + a.io.a := inputs.bits(2 * i + 0) + a.io.b := inputs.bits(2 * i + 1) + a.io.roundingMode := hardfloat.consts.round_near_even + a.io.detectTininess := hardfloat.consts.tininess_afterRounding + a.io.out + } + + outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts)) + outC := StallingPipe(io.stall, inputs.valid, inC.bits) + when(inputs.valid =/= inC.valid) { + printf("WARN: DotProductPipeFP8E4M3 input/C valid mismatch: inputs=%d c=%d\n", + inputs.valid, inC.valid) + } + + (outputs, outC) + } + } + require(addStageOut.bits.length == 1) + + val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) + acc.io.subOp := 0.U + acc.io.a := addStageOut.bits(0) + acc.io.b := addStageC.bits + acc.io.roundingMode := hardfloat.consts.round_near_even + acc.io.detectTininess := hardfloat.consts.tininess_afterRounding + + val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out) + + io.out.valid := accStageOut.valid + io.out.bits.data := ieee(box(accStageOut.bits, S)) +} + class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { require(latency <= 2) diff --git a/src/test/scala/radiance/FP8E4M3Test.scala b/src/test/scala/radiance/FP8E4M3Test.scala index e67d49f..7f201c8 100644 --- a/src/test/scala/radiance/FP8E4M3Test.scala +++ b/src/test/scala/radiance/FP8E4M3Test.scala @@ -13,6 +13,16 @@ class FP8E4M3DecodeHarness extends Module { io.out := FP8E4M3.toFloat32(io.in) } +class FP8E4M3MulHarness extends Module { + val io = IO(new Bundle { + val a = Input(UInt(8.W)) + val b = Input(UInt(8.W)) + val out = Output(UInt(32.W)) + }) + + io.out := FP8E4M3MulToFloat32(io.a, io.b) +} + class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester { behavior of "FP8E4M3" @@ -33,6 +43,23 @@ class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester { } } + it should "multiply E4M3 operands with FP8-width significands and return FP32 bits" in { + test(new FP8E4M3MulHarness) { c => + Seq( + (0x38, 0x40, 0x40000000L), // 1.0 * 2.0 = 2.0 + (0x30, 0x3c, 0x3f400000L), // 0.5 * 1.5 = 0.75 + (0xb8, 0x40, 0xc0000000L), // -1.0 * 2.0 = -2.0 + (0x00, 0x40, 0x00000000L), // 0.0 * 2.0 = 0.0 + (0x80, 0x40, 0x80000000L) // -0.0 * 2.0 = -0.0 + ).foreach { case (a, b, out) => + c.io.a.poke(a.U) + c.io.b.poke(b.U) + c.clock.step() + c.io.out.expect(out.U) + } + } + } + it should "run an 8-wide FP8 dot product with FP32 accumulation" in { test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c => c.io.in.valid.poke(true.B) diff --git a/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala b/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala index 7073311..72d7309 100644 --- a/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala +++ b/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala @@ -68,22 +68,22 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A) val bBase = BigInt(0x800) - val fp16One = BigInt(0x3c00) + val fp8One = BigInt(0x38) val fp32Zero = BigInt(0) // 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f val fp32SixtyFour = BigInt(0x42800000L) // Populate TMEM A at offset aBase (all 32 frags) - val aFrag = packWords(Seq.fill(16)(fp16One), 16) + val aFrag = packWords(Seq.fill(32)(fp8One), 8) val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32) for (i <- 0 until 32) { tmem(aBase / fragBytes + i) = aFrag tmem(cBase / fragBytes + i) = cFrag } - // SMEM B with fp16 2.0 - val fp16Two = BigInt(0x4000) - val bFrag = packWords(Seq.fill(16)(fp16Two), 16) + // SMEM B with packed FP8 E4M3 2.0 + val fp8Two = BigInt(0x40) + val bFrag = packWords(Seq.fill(32)(fp8Two), 8) val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag) for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag diff --git a/src/test/scala/radiance/TensorCoreBlackwellTest.scala b/src/test/scala/radiance/TensorCoreBlackwellTest.scala index feb8008..f85b0de 100644 --- a/src/test/scala/radiance/TensorCoreBlackwellTest.scala +++ b/src/test/scala/radiance/TensorCoreBlackwellTest.scala @@ -249,13 +249,13 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester { val bBase = BigInt(0x800) val cBase = BigInt(0x1000) - // A: all fp16 1.0 (0x3c00), 16 halves per frag - val fp16One = BigInt(0x3c00) - val fp16Two = BigInt(0x4000) + // A/B: packed FP8 E4M3 bytes, 32 elements per 256-bit frag + val fp8One = BigInt(0x38) + val fp8Two = BigInt(0x40) val fp32One = BigInt(0x3f800000) val fp32SixtyFive = BigInt(0x42820000) - val aFrag = packWords(Seq.fill(16)(fp16One), 16) - val bFrag = packWords(Seq.fill(16)(fp16Two), 16) + val aFrag = packWords(Seq.fill(32)(fp8One), 8) + val bFrag = packWords(Seq.fill(32)(fp8Two), 8) val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32) val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)