From 12f5c6d92d4dde26ff8a896ea7ce548480fb2868 Mon Sep 17 00:00:00 2001
From: Zhongdi LUO
Date: Thu, 2 Jul 2026 08:57:59 +0000
Subject: [PATCH] feat: switch blackwell bwgmma inputs to fp8

---
 .../radiance/core/TensorCoreBlackwell.scala   |  68 +++++++---
 .../radiance/TensorCoreBlackwellFP8Test.scala | 123 ++++++++++++++++++
 2 files changed, 172 insertions(+), 19 deletions(-)
 create mode 100644 src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala

diff --git a/src/main/scala/radiance/core/TensorCoreBlackwell.scala b/src/main/scala/radiance/core/TensorCoreBlackwell.scala
index 30f2f48..6c80682 100644
--- a/src/main/scala/radiance/core/TensorCoreBlackwell.scala
+++ b/src/main/scala/radiance/core/TensorCoreBlackwell.scala
@@ -6,6 +6,50 @@ package radiance.core
 import chisel3._
 import chisel3.util._
 
+/** Byte-lane selection helpers for the packed FP8 operands of the Blackwell MMA datapath.
+  *
+  * An operand is addressed as a flat vector of 8-bit elements: element `i` occupies bits
+  * `[8 * i + 7 : 8 * i]`, and consecutive rows are eight elements apart.
+  */
+object TensorCoreBlackwellFP8Packing {
+  /** Extracts the `idx`-th 8-bit element of `x`, i.e. bits `[8 * idx + 7 : 8 * idx]`. */
+  def fp8Byte(x: UInt, idx: Int): UInt = {
+    x((idx + 1) * 8 - 1, idx * 8)
+  }
+
+  /** Picks element `k` of the A-operand row addressed by `elemM`.
+    *
+    * With four lanes `elemM` is converted with `asBool` and chooses between rows 0 and 1.
+    * With eight lanes it chooses among rows 0 to 3; any other value falls back to row 0.
+    *
+    * @param operandA packed A operand
+    * @param k        element index inside the selected row (callers iterate `0 until 8`)
+    * @param elemM    row selector
+    * @param numLanes lane count of the enclosing core; only 4 and 8 are supported
+    */
+  def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
+    // Fail at elaboration instead of silently applying the 8-lane layout to an
+    // unsupported lane count.
+    require(numLanes == 4 || numLanes == 8,
+      s"FP8 operand A selection supports 4 or 8 lanes, got ${numLanes}")
+    if (numLanes == 4) {
+      Mux(elemM.asBool, fp8Byte(operandA, 8 + k), fp8Byte(operandA, k))
+    } else {
+      MuxLookup(elemM, fp8Byte(operandA, k))(Seq(
+        0.U -> fp8Byte(operandA, k),
+        1.U -> fp8Byte(operandA, 8 + k),
+        2.U -> fp8Byte(operandA, 16 + k),
+        3.U -> fp8Byte(operandA, 24 + k)
+      ))
+    }
+  }
+
+  /** Picks element `k` of the B-operand row addressed by `elemN` (row 1 when set, else row 0). */
+  def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
+    Mux(elemN.asBool, fp8Byte(operandB, 8 + k), fp8Byte(operandB, k))
+  }
+}
+
 class TensorCoreBlackwell(
   val numWarps: Int,
   val numLanes: Int,
@@ -13,7 +57,7 @@ class TensorCoreBlackwell(
   val numSourceIds: Int = 16,
   val numFPRegs: Int = 32
 ) extends Module {
-  require(half, "Blackwell MMA currently supports FP16 inputs only")
+  require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
   require(numLanes == 4 || numLanes == 8,
     s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
 
@@ -198,30 +242,16 @@ class TensorCoreBlackwell(
   val dpuInValid = WireDefault(false.B)
   val dpu = Module(new TensorDotProductUnit(
     dim = 8,
-    half = true
+    half = false,
+    inputType = TensorInputType.FP8E4M3
   ))
 
-  private def halfWord(x: UInt, idx: Int): UInt = {
-    x((idx + 1) * 16 - 1, idx * 16)
-  }
-
   val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
   val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
   dpu.io.in.valid := dpuInValid
   for (k <- 0 until 8) {
-    dpu.io.in.bits.a(k) := (
-      if (numLanes == 4) {
-        Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
-      } else {
-        MuxLookup(elemM, halfWord(operandA, k))(Seq(
-          0.U -> halfWord(operandA, k),
-          1.U -> halfWord(operandA, 8 + k),
-          2.U -> halfWord(operandA, 16 + k),
-          3.U -> halfWord(operandA, 24 + k)
-        ))
-      }
-    )
-    dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
+    dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
+    dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
   }
   dpu.io.in.bits.c := cWords(elemReg)
   dpu.io.stall := false.B
diff --git a/src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala b/src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
new file mode 100644
index 0000000..c2aeb72
--- /dev/null
+++ b/src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
@@ -0,0 +1,123 @@
+package radiance.core
+
+import chisel3._
+import chiseltest._
+import org.scalatest.flatspec.AnyFlatSpec
+
+/** Minimal wrapper that wires the FP8 operand selectors into a [[TensorDotProductUnit]]
+  * so that a single dot-product microstep can be driven from a test.
+  *
+  * The selectors are instantiated for the 8-lane layout, which reads only the low
+  * 256 bits of `operandA` and the low 128 bits of `operandB`.
+  */
+class TensorCoreBlackwellFP8MicrostepHarness extends Module {
+  val io = IO(new Bundle {
+    val valid = Input(Bool())
+    val operandA = Input(UInt(512.W))
+    val operandB = Input(UInt(256.W))
+    val c = Input(UInt(32.W))
+    val elemM = Input(UInt(2.W))
+    val elemN = Input(UInt(1.W))
+    val outValid = Output(Bool())
+    val out = Output(UInt(32.W))
+  })
+
+  val dpu = Module(new TensorDotProductUnit(
+    dim = 8,
+    half = false,
+    inputType = TensorInputType.FP8E4M3
+  ))
+
+  dpu.io.in.valid := io.valid
+  for (k <- 0 until 8) {
+    dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
+    dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
+  }
+  dpu.io.in.bits.c := io.c
+  dpu.io.stall := false.B
+
+  io.outValid := dpu.io.out.valid
+  io.out := dpu.io.out.bits.data
+}
+
+class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
+  behavior of "TensorCoreBlackwell FP8 operand microstep"
+
+  // FP8 E4M3 encodings of small powers of two (exponent bias 7, zero mantissa).
+  private val fp8One = BigInt(0x38)
+  private val fp8Two = BigInt(0x40)
+  private val fp8Four = BigInt(0x48)
+  private val fp8Eight = BigInt(0x50)
+  private val fp8Sixteen = BigInt(0x58)
+  private val fp32One = BigInt(0x3f800000L)
+
+  /** Packs `words` into one integer with element 0 in the least-significant `width` bits. */
+  private def packWords(words: Seq[BigInt], width: Int): BigInt = {
+    val mask = (BigInt(1) << width) - 1
+    words.zipWithIndex.foldLeft(BigInt(0)) {
+      case (acc, (word, i)) => acc | ((word & mask) << (i * width))
+    }
+  }
+
+  /** Packs rows of eight identical FP8 elements, row 0 in the least-significant bytes. */
+  private def packRows(rowValues: Seq[BigInt]): BigInt =
+    packWords(rowValues.flatMap(value => Seq.fill(8)(value)), 8)
+
+  /** Issues one microstep with `c = 1.0` and checks the FP32 result once it becomes valid. */
+  private def expectMicrostep(
+    dut: TensorCoreBlackwellFP8MicrostepHarness,
+    operandA: BigInt,
+    operandB: BigInt,
+    elemM: Int,
+    elemN: Int,
+    expected: BigInt
+  ): Unit = {
+    dut.io.valid.poke(true.B)
+    dut.io.operandA.poke(operandA.U)
+    dut.io.operandB.poke(operandB.U)
+    dut.io.c.poke(fp32One.U)
+    dut.io.elemM.poke(elemM.U)
+    dut.io.elemN.poke(elemN.U)
+    dut.io.outValid.expect(false.B)
+
+    dut.clock.step()
+    dut.io.valid.poke(false.B)
+    dut.io.outValid.expect(false.B)
+
+    // The result is expected five clock edges after the issue cycle.
+    dut.clock.step(4)
+    dut.io.outValid.expect(true.B)
+    dut.io.out.expect(expected.U)
+  }
+
+  it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
+    test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
+      val fp32Seventeen = BigInt(0x41880000L)
+      val operandA = packWords(Seq.fill(64)(fp8One), 8)
+      val operandB = packWords(Seq.fill(32)(fp8Two), 8)
+
+      // 8 * (1.0 * 2.0) + 1.0 = 17.0
+      expectMicrostep(c, operandA, operandB, elemM = 0, elemN = 0, expected = fp32Seventeen)
+    }
+  }
+
+  // Distinct power-of-two rows give every (elemM, elemN) pair a different, exactly
+  // representable result, so a wrong row offset or selector wiring changes the output.
+  private val rowsA = Seq(fp8One, fp8Two, fp8Four, fp8Eight)
+  private val rowsB = Seq(fp8One, fp8Sixteen)
+  private val selectorCases = Seq(
+    // (elemM, elemN, FP32 bits of 8 * a * b + 1.0)
+    (1, 0, BigInt(0x41880000L)), // 8 * 2 * 1 + 1 = 17
+    (3, 0, BigInt(0x42820000L)), // 8 * 8 * 1 + 1 = 65
+    (0, 1, BigInt(0x43010000L)), // 8 * 1 * 16 + 1 = 129
+    (2, 1, BigInt(0x44004000L)) // 8 * 4 * 16 + 1 = 513
+  )
+
+  for ((elemM, elemN, expected) <- selectorCases) {
+    it should s"feed A row $elemM and B row $elemN to the dot-product unit" in {
+      test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
+        expectMicrostep(c, packRows(rowsA), packRows(rowsB), elemM, elemN, expected)
+      }
+    }
+  }
+}