feat: switch blackwell bwgmma inputs to fp8
This commit is contained in:
@@ -6,6 +6,29 @@ package radiance.core
|
|||||||
import chisel3._
|
import chisel3._
|
||||||
import chisel3.util._
|
import chisel3.util._
|
||||||
|
|
||||||
|
object TensorCoreBlackwellFP8Packing {
|
||||||
|
def fp8Byte(x: UInt, idx: Int): UInt = {
|
||||||
|
x((idx + 1) * 8 - 1, idx * 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
|
||||||
|
if (numLanes == 4) {
|
||||||
|
Mux(elemM.asBool, fp8Byte(operandA, 8 + k), fp8Byte(operandA, k))
|
||||||
|
} else {
|
||||||
|
MuxLookup(elemM, fp8Byte(operandA, k))(Seq(
|
||||||
|
0.U -> fp8Byte(operandA, k),
|
||||||
|
1.U -> fp8Byte(operandA, 8 + k),
|
||||||
|
2.U -> fp8Byte(operandA, 16 + k),
|
||||||
|
3.U -> fp8Byte(operandA, 24 + k)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
|
||||||
|
Mux(elemN.asBool, fp8Byte(operandB, 8 + k), fp8Byte(operandB, k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class TensorCoreBlackwell(
|
class TensorCoreBlackwell(
|
||||||
val numWarps: Int,
|
val numWarps: Int,
|
||||||
val numLanes: Int,
|
val numLanes: Int,
|
||||||
@@ -13,7 +36,7 @@ class TensorCoreBlackwell(
|
|||||||
val numSourceIds: Int = 16,
|
val numSourceIds: Int = 16,
|
||||||
val numFPRegs: Int = 32
|
val numFPRegs: Int = 32
|
||||||
) extends Module {
|
) extends Module {
|
||||||
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
|
||||||
require(numLanes == 4 || numLanes == 8,
|
require(numLanes == 4 || numLanes == 8,
|
||||||
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
|
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
|
||||||
|
|
||||||
@@ -198,30 +221,16 @@ class TensorCoreBlackwell(
|
|||||||
val dpuInValid = WireDefault(false.B)
|
val dpuInValid = WireDefault(false.B)
|
||||||
val dpu = Module(new TensorDotProductUnit(
|
val dpu = Module(new TensorDotProductUnit(
|
||||||
dim = 8,
|
dim = 8,
|
||||||
half = true
|
half = false,
|
||||||
|
inputType = TensorInputType.FP8E4M3
|
||||||
))
|
))
|
||||||
|
|
||||||
private def halfWord(x: UInt, idx: Int): UInt = {
|
|
||||||
x((idx + 1) * 16 - 1, idx * 16)
|
|
||||||
}
|
|
||||||
|
|
||||||
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
|
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
|
||||||
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
|
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
|
||||||
dpu.io.in.valid := dpuInValid
|
dpu.io.in.valid := dpuInValid
|
||||||
for (k <- 0 until 8) {
|
for (k <- 0 until 8) {
|
||||||
dpu.io.in.bits.a(k) := (
|
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
|
||||||
if (numLanes == 4) {
|
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
|
||||||
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
|
|
||||||
} else {
|
|
||||||
MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
|
||||||
0.U -> halfWord(operandA, k),
|
|
||||||
1.U -> halfWord(operandA, 8 + k),
|
|
||||||
2.U -> halfWord(operandA, 16 + k),
|
|
||||||
3.U -> halfWord(operandA, 24 + k)
|
|
||||||
))
|
|
||||||
}
|
|
||||||
)
|
|
||||||
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
|
||||||
}
|
}
|
||||||
dpu.io.in.bits.c := cWords(elemReg)
|
dpu.io.in.bits.c := cWords(elemReg)
|
||||||
dpu.io.stall := false.B
|
dpu.io.stall := false.B
|
||||||
|
|||||||
76
src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
Normal file
76
src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package radiance.core
|
||||||
|
|
||||||
|
import chisel3._
|
||||||
|
import chiseltest._
|
||||||
|
import org.scalatest.flatspec.AnyFlatSpec
|
||||||
|
|
||||||
|
class TensorCoreBlackwellFP8MicrostepHarness extends Module {
|
||||||
|
val io = IO(new Bundle {
|
||||||
|
val valid = Input(Bool())
|
||||||
|
val operandA = Input(UInt(512.W))
|
||||||
|
val operandB = Input(UInt(256.W))
|
||||||
|
val c = Input(UInt(32.W))
|
||||||
|
val elemM = Input(UInt(2.W))
|
||||||
|
val elemN = Input(UInt(1.W))
|
||||||
|
val outValid = Output(Bool())
|
||||||
|
val out = Output(UInt(32.W))
|
||||||
|
})
|
||||||
|
|
||||||
|
val dpu = Module(new TensorDotProductUnit(
|
||||||
|
dim = 8,
|
||||||
|
half = false,
|
||||||
|
inputType = TensorInputType.FP8E4M3
|
||||||
|
))
|
||||||
|
|
||||||
|
dpu.io.in.valid := io.valid
|
||||||
|
for (k <- 0 until 8) {
|
||||||
|
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
|
||||||
|
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
|
||||||
|
}
|
||||||
|
dpu.io.in.bits.c := io.c
|
||||||
|
dpu.io.stall := false.B
|
||||||
|
|
||||||
|
io.outValid := dpu.io.out.valid
|
||||||
|
io.out := dpu.io.out.bits.data
|
||||||
|
}
|
||||||
|
|
||||||
|
class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||||
|
behavior of "TensorCoreBlackwell FP8 operand microstep"
|
||||||
|
|
||||||
|
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||||
|
val mask = (BigInt(1) << width) - 1
|
||||||
|
words.zipWithIndex.foldLeft(BigInt(0)) {
|
||||||
|
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
|
||||||
|
test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
|
||||||
|
val fp8One = BigInt(0x38)
|
||||||
|
val fp8Two = BigInt(0x40)
|
||||||
|
val fp32One = BigInt(0x3f800000L)
|
||||||
|
val fp32Seventeen = BigInt(0x41880000L)
|
||||||
|
val operandA = packWords(Seq.fill(64)(fp8One), 8)
|
||||||
|
val operandB = packWords(Seq.fill(32)(fp8Two), 8)
|
||||||
|
|
||||||
|
c.io.valid.poke(true.B)
|
||||||
|
c.io.operandA.poke(operandA.U)
|
||||||
|
c.io.operandB.poke(operandB.U)
|
||||||
|
c.io.c.poke(fp32One.U)
|
||||||
|
c.io.elemM.poke(0.U)
|
||||||
|
c.io.elemN.poke(0.U)
|
||||||
|
c.io.outValid.expect(false.B)
|
||||||
|
|
||||||
|
c.clock.step()
|
||||||
|
c.io.valid.poke(false.B)
|
||||||
|
c.io.outValid.expect(false.B)
|
||||||
|
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.clock.step()
|
||||||
|
c.io.outValid.expect(true.B)
|
||||||
|
c.io.out.expect(fp32Seventeen.U)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user