feat: switch blackwell bwgmma inputs to fp8

This commit is contained in:
Zhongdi LUO
2026-07-02 08:57:59 +00:00
parent 68a7f66046
commit 12f5c6d92d
2 changed files with 104 additions and 19 deletions

View File

@@ -6,6 +6,29 @@ package radiance.core
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
object TensorCoreBlackwellFP8Packing {
def fp8Byte(x: UInt, idx: Int): UInt = {
x((idx + 1) * 8 - 1, idx * 8)
}
def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
if (numLanes == 4) {
Mux(elemM.asBool, fp8Byte(operandA, 8 + k), fp8Byte(operandA, k))
} else {
MuxLookup(elemM, fp8Byte(operandA, k))(Seq(
0.U -> fp8Byte(operandA, k),
1.U -> fp8Byte(operandA, 8 + k),
2.U -> fp8Byte(operandA, 16 + k),
3.U -> fp8Byte(operandA, 24 + k)
))
}
}
def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
Mux(elemN.asBool, fp8Byte(operandB, 8 + k), fp8Byte(operandB, k))
}
}
class TensorCoreBlackwell( class TensorCoreBlackwell(
val numWarps: Int, val numWarps: Int,
val numLanes: Int, val numLanes: Int,
@@ -13,7 +36,7 @@ class TensorCoreBlackwell(
val numSourceIds: Int = 16, val numSourceIds: Int = 16,
val numFPRegs: Int = 32 val numFPRegs: Int = 32
) extends Module { ) extends Module {
require(half, "Blackwell MMA currently supports FP16 inputs only") require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
require(numLanes == 4 || numLanes == 8, require(numLanes == 4 || numLanes == 8,
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}") s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
@@ -198,30 +221,16 @@ class TensorCoreBlackwell(
val dpuInValid = WireDefault(false.B) val dpuInValid = WireDefault(false.B)
val dpu = Module(new TensorDotProductUnit( val dpu = Module(new TensorDotProductUnit(
dim = 8, dim = 8,
half = true half = false,
inputType = TensorInputType.FP8E4M3
)) ))
private def halfWord(x: UInt, idx: Int): UInt = {
x((idx + 1) * 16 - 1, idx * 16)
}
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0) val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2) val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
dpu.io.in.valid := dpuInValid dpu.io.in.valid := dpuInValid
for (k <- 0 until 8) { for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := ( dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
if (numLanes == 4) { dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
} else {
MuxLookup(elemM, halfWord(operandA, k))(Seq(
0.U -> halfWord(operandA, k),
1.U -> halfWord(operandA, 8 + k),
2.U -> halfWord(operandA, 16 + k),
3.U -> halfWord(operandA, 24 + k)
))
}
)
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
} }
dpu.io.in.bits.c := cWords(elemReg) dpu.io.in.bits.c := cWords(elemReg)
dpu.io.stall := false.B dpu.io.stall := false.B

View File

@@ -0,0 +1,76 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class TensorCoreBlackwellFP8MicrostepHarness extends Module {
val io = IO(new Bundle {
val valid = Input(Bool())
val operandA = Input(UInt(512.W))
val operandB = Input(UInt(256.W))
val c = Input(UInt(32.W))
val elemM = Input(UInt(2.W))
val elemN = Input(UInt(1.W))
val outValid = Output(Bool())
val out = Output(UInt(32.W))
})
val dpu = Module(new TensorDotProductUnit(
dim = 8,
half = false,
inputType = TensorInputType.FP8E4M3
))
dpu.io.in.valid := io.valid
for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
}
dpu.io.in.bits.c := io.c
dpu.io.stall := false.B
io.outValid := dpu.io.out.valid
io.out := dpu.io.out.bits.data
}
class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell FP8 operand microstep"
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
val fp8One = BigInt(0x38)
val fp8Two = BigInt(0x40)
val fp32One = BigInt(0x3f800000L)
val fp32Seventeen = BigInt(0x41880000L)
val operandA = packWords(Seq.fill(64)(fp8One), 8)
val operandB = packWords(Seq.fill(32)(fp8Two), 8)
c.io.valid.poke(true.B)
c.io.operandA.poke(operandA.U)
c.io.operandB.poke(operandB.U)
c.io.c.poke(fp32One.U)
c.io.elemM.poke(0.U)
c.io.elemN.poke(0.U)
c.io.outValid.expect(false.B)
c.clock.step()
c.io.valid.poke(false.B)
c.io.outValid.expect(false.B)
c.clock.step()
c.clock.step()
c.clock.step()
c.clock.step()
c.io.outValid.expect(true.B)
c.io.out.expect(fp32Seventeen.U)
}
}
}