Compare commits
6 Commits
wu-blackwe
...
wu-blackwe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca4c48251d | ||
|
|
12f5c6d92d | ||
|
|
68a7f66046 | ||
|
|
2afb96bb14 | ||
|
|
007350fd5a | ||
|
|
47d6585896 |
Submodule src/main/resources/vsrc/vortex updated: abee301b6e...9251ba0a24
@@ -6,6 +6,29 @@ package radiance.core
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
object TensorCoreBlackwellFP8Packing {
|
||||
def fp8Byte(x: UInt, idx: Int): UInt = {
|
||||
x((idx + 1) * 8 - 1, idx * 8)
|
||||
}
|
||||
|
||||
def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
|
||||
if (numLanes == 4) {
|
||||
Mux(elemM.asBool, fp8Byte(operandA, 8 + k), fp8Byte(operandA, k))
|
||||
} else {
|
||||
MuxLookup(elemM, fp8Byte(operandA, k))(Seq(
|
||||
0.U -> fp8Byte(operandA, k),
|
||||
1.U -> fp8Byte(operandA, 8 + k),
|
||||
2.U -> fp8Byte(operandA, 16 + k),
|
||||
3.U -> fp8Byte(operandA, 24 + k)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
|
||||
Mux(elemN.asBool, fp8Byte(operandB, 8 + k), fp8Byte(operandB, k))
|
||||
}
|
||||
}
|
||||
|
||||
class TensorCoreBlackwell(
|
||||
val numWarps: Int,
|
||||
val numLanes: Int,
|
||||
@@ -13,7 +36,7 @@ class TensorCoreBlackwell(
|
||||
val numSourceIds: Int = 16,
|
||||
val numFPRegs: Int = 32
|
||||
) extends Module {
|
||||
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
||||
require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
|
||||
require(numLanes == 4 || numLanes == 8,
|
||||
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
|
||||
|
||||
@@ -198,30 +221,16 @@ class TensorCoreBlackwell(
|
||||
val dpuInValid = WireDefault(false.B)
|
||||
val dpu = Module(new TensorDotProductUnit(
|
||||
dim = 8,
|
||||
half = true
|
||||
half = false,
|
||||
inputType = TensorInputType.FP8E4M3
|
||||
))
|
||||
|
||||
private def halfWord(x: UInt, idx: Int): UInt = {
|
||||
x((idx + 1) * 16 - 1, idx * 16)
|
||||
}
|
||||
|
||||
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
|
||||
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
|
||||
dpu.io.in.valid := dpuInValid
|
||||
for (k <- 0 until 8) {
|
||||
dpu.io.in.bits.a(k) := (
|
||||
if (numLanes == 4) {
|
||||
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
|
||||
} else {
|
||||
MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
||||
0.U -> halfWord(operandA, k),
|
||||
1.U -> halfWord(operandA, 8 + k),
|
||||
2.U -> halfWord(operandA, 16 + k),
|
||||
3.U -> halfWord(operandA, 24 + k)
|
||||
))
|
||||
}
|
||||
)
|
||||
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
||||
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
|
||||
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
|
||||
}
|
||||
dpu.io.in.bits.c := cWords(elemReg)
|
||||
dpu.io.stall := false.B
|
||||
|
||||
@@ -7,17 +7,116 @@ import chisel3._
|
||||
import chisel3.util._
|
||||
import freechips.rocketchip.tile
|
||||
|
||||
object FP8E4M3 {
|
||||
private val Bias = 7
|
||||
|
||||
private def decodeToFloat(bits: Int): Float = {
|
||||
val sign = (bits >> 7) & 0x1
|
||||
val exp = (bits >> 3) & 0xf
|
||||
val frac = bits & 0x7
|
||||
|
||||
val magnitude =
|
||||
if (exp == 0) {
|
||||
if (frac == 0) 0.0
|
||||
else (frac.toDouble / 8.0) * Math.pow(2.0, 1 - Bias)
|
||||
} else {
|
||||
(1.0 + frac.toDouble / 8.0) * Math.pow(2.0, exp - Bias)
|
||||
}
|
||||
|
||||
val value = if (sign == 1) -magnitude else magnitude
|
||||
value.toFloat
|
||||
}
|
||||
|
||||
private def fp32Bits(bits: Int): BigInt = {
|
||||
BigInt(java.lang.Float.floatToRawIntBits(decodeToFloat(bits)).toLong & 0xffffffffL)
|
||||
}
|
||||
|
||||
def toFloat32(x: UInt): UInt = {
|
||||
MuxLookup(x, 0.U(32.W))((0 until 256).map { bits =>
|
||||
bits.U(8.W) -> fp32Bits(bits).U(32.W)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
object FP8E4M3MulToFloat32 {
|
||||
private val Bias = 7
|
||||
|
||||
def apply(a: UInt, b: UInt): UInt = {
|
||||
val sign = a(7) ^ b(7)
|
||||
val expA = a(6, 3)
|
||||
val expB = b(6, 3)
|
||||
val fracA = a(2, 0)
|
||||
val fracB = b(2, 0)
|
||||
val zeroA = expA === 0.U && fracA === 0.U
|
||||
val zeroB = expB === 0.U && fracB === 0.U
|
||||
val isZero = zeroA || zeroB
|
||||
|
||||
val sigA = Mux(expA === 0.U, Cat(0.U(1.W), fracA), Cat(1.U(1.W), fracA))
|
||||
val sigB = Mux(expB === 0.U, Cat(0.U(1.W), fracB), Cat(1.U(1.W), fracB))
|
||||
val prodSig = sigA * sigB
|
||||
|
||||
val scaleA = Wire(SInt(6.W))
|
||||
val scaleB = Wire(SInt(6.W))
|
||||
scaleA := Mux(expA === 0.U, -9.S(6.W), expA.zext - (Bias + 3).S(6.W))
|
||||
scaleB := Mux(expB === 0.U, -9.S(6.W), expB.zext - (Bias + 3).S(6.W))
|
||||
|
||||
val msb = Wire(UInt(3.W))
|
||||
when(prodSig(7)) {
|
||||
msb := 7.U
|
||||
}.elsewhen(prodSig(6)) {
|
||||
msb := 6.U
|
||||
}.elsewhen(prodSig(5)) {
|
||||
msb := 5.U
|
||||
}.elsewhen(prodSig(4)) {
|
||||
msb := 4.U
|
||||
}.elsewhen(prodSig(3)) {
|
||||
msb := 3.U
|
||||
}.elsewhen(prodSig(2)) {
|
||||
msb := 2.U
|
||||
}.elsewhen(prodSig(1)) {
|
||||
msb := 1.U
|
||||
}.otherwise {
|
||||
msb := 0.U
|
||||
}
|
||||
|
||||
val normalized = (prodSig << (7.U - msb))(7, 0)
|
||||
val exponent = (scaleA + scaleB + msb.zext + 127.S(10.W)).asUInt(7, 0)
|
||||
val fraction = Cat(normalized(6, 0), 0.U(16.W))
|
||||
Mux(isZero, Cat(sign, 0.U(31.W)), Cat(sign, exponent, fraction))
|
||||
}
|
||||
}
|
||||
|
||||
object TensorInputType extends Enumeration {
|
||||
val FP16, FP32, FP8E4M3 = Value
|
||||
|
||||
def fromHalf(half: Boolean): Value = {
|
||||
if (half) FP16 else FP32
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||
// `half`: if True, generate fp16 MACs; if False fp32.
|
||||
class TensorDotProductUnit(
|
||||
val dim: Int = 4,
|
||||
val half: Boolean
|
||||
val dim: Int,
|
||||
val half: Boolean,
|
||||
val inputType: TensorInputType.Value
|
||||
) extends Module with tile.HasFPUParameters {
|
||||
val tIn = if (half) tile.FType.H else tile.FType.S
|
||||
def this(dim: Int = 4, half: Boolean) = {
|
||||
this(dim, half, TensorInputType.fromHalf(half))
|
||||
}
|
||||
|
||||
val tIn = inputType match {
|
||||
case TensorInputType.FP16 => tile.FType.H
|
||||
case TensorInputType.FP32 => tile.FType.S
|
||||
case TensorInputType.FP8E4M3 => tile.FType.S
|
||||
}
|
||||
// output datatype fixed to single-precision
|
||||
val tOut = tile.FType.S
|
||||
|
||||
val inFLen = tIn.ieeeWidth
|
||||
val inFLen = inputType match {
|
||||
case TensorInputType.FP8E4M3 => 8
|
||||
case _ => tIn.ieeeWidth
|
||||
}
|
||||
val outFLen = tOut.ieeeWidth
|
||||
val fLen = outFLen // needed for HasFPUParameters
|
||||
val minFLen = 16 // fp16
|
||||
@@ -40,20 +139,39 @@ class TensorDotProductUnit(
|
||||
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
|
||||
// make sure recoding/uncoding happens only at the edge, not at every
|
||||
// pipeline stage inside the dpu
|
||||
val tag = if (half) H else S
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||
val tag = inputType match {
|
||||
case TensorInputType.FP16 => H
|
||||
case TensorInputType.FP32 => S
|
||||
case TensorInputType.FP8E4M3 => S
|
||||
}
|
||||
if (inputType == TensorInputType.FP8E4M3) {
|
||||
val dpu = Module(new DotProductPipeFP8E4M3(dim))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := io.in.bits.a
|
||||
dpu.io.in.bits.b := io.in.bits.b
|
||||
dpu.io.in.bits.c := io.in.bits.c
|
||||
dpu.io.stall := io.stall
|
||||
|
||||
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
dpu.io.in.bits.c := in3
|
||||
dpu.io.stall := io.stall
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := dpu.io.out.bits.data
|
||||
} else {
|
||||
def recodeInput(x: Bits): UInt = {
|
||||
unbox(recode(x.asUInt, tag), tag, Some(tIn))
|
||||
}
|
||||
val in1 = io.in.bits.a.map(recodeInput)
|
||||
val in2 = io.in.bits.b.map(recodeInput)
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
dpu.io.in.bits.c := in3
|
||||
dpu.io.stall := io.stall
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
}
|
||||
|
||||
// An implementation of chisel3.util.Pipe that supports stalls.
|
||||
@@ -236,6 +354,89 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
||||
io.out.bits.data := accStageOut.bits
|
||||
}
|
||||
|
||||
class DotProductPipeFP8E4M3(dim: Int) extends Module with tile.HasFPUParameters {
|
||||
val tOut = tile.FType.S
|
||||
val outExpWidth = tOut.exp
|
||||
val outSigWidth = tOut.sig
|
||||
val recOutFLen = outExpWidth + outSigWidth + 1
|
||||
val fLen = tOut.ieeeWidth
|
||||
val minFLen = 16
|
||||
def xLen = 32
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Vec(dim, Bits(8.W))
|
||||
val b = Vec(dim, Bits(8.W))
|
||||
val c = Bits(32.W)
|
||||
}))
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits(32.W)
|
||||
})
|
||||
})
|
||||
|
||||
val productRecoded = io.in.bits.a.zip(io.in.bits.b).map { case (a, b) =>
|
||||
unbox(recode(FP8E4M3MulToFloat32(a.asUInt, b.asUInt), S), S, Some(tOut))
|
||||
}
|
||||
val inC = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
||||
|
||||
val productStageOut = StallingPipe(io.stall, io.in.valid, VecInit(productRecoded))
|
||||
val productStageC = StallingPipe(io.stall, io.in.valid, inC)
|
||||
|
||||
val log2Dim = log2Ceil(dim)
|
||||
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
|
||||
|
||||
val interim = (log2Dim to 0 by -1).map { i =>
|
||||
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
|
||||
}
|
||||
val interimC = (log2Dim to 0 by -1).map(_ => Wire(Valid(Bits(recOutFLen.W))))
|
||||
interim(0) := productStageOut
|
||||
interimC(0) := productStageC
|
||||
|
||||
val (addStageOut, addStageC) = (interim zip interimC).reduce {
|
||||
(inputsAndC, outputsAndC) => {
|
||||
val (inputs, inC) = inputsAndC
|
||||
val (outputs, outC) = outputsAndC
|
||||
|
||||
require(inputs.bits.length == 2 * outputs.bits.length)
|
||||
val thisDim = inputs.bits.length
|
||||
val adders = Seq.fill(thisDim / 2)(
|
||||
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
)
|
||||
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
||||
a.io.subOp := 0.U
|
||||
a.io.a := inputs.bits(2 * i + 0)
|
||||
a.io.b := inputs.bits(2 * i + 1)
|
||||
a.io.roundingMode := hardfloat.consts.round_near_even
|
||||
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
a.io.out
|
||||
}
|
||||
|
||||
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
|
||||
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
|
||||
when(inputs.valid =/= inC.valid) {
|
||||
printf("WARN: DotProductPipeFP8E4M3 input/C valid mismatch: inputs=%d c=%d\n",
|
||||
inputs.valid, inC.valid)
|
||||
}
|
||||
|
||||
(outputs, outC)
|
||||
}
|
||||
}
|
||||
require(addStageOut.bits.length == 1)
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
acc.io.subOp := 0.U
|
||||
acc.io.a := addStageOut.bits(0)
|
||||
acc.io.b := addStageC.bits
|
||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
|
||||
|
||||
io.out.valid := accStageOut.valid
|
||||
io.out.bits.data := ieee(box(accStageOut.bits, S))
|
||||
}
|
||||
|
||||
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
require(latency <= 2)
|
||||
|
||||
|
||||
@@ -851,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
core.io.tc_tmem_C_rready := DontCare
|
||||
core.io.tc_tmem_C_rdata := DontCare
|
||||
core.io.tc_tmem_C_wready := DontCare
|
||||
core.io.sc_tmem_rready := DontCare
|
||||
core.io.sc_tmem_rdata := DontCare
|
||||
core.io.sc_tmem_wready := DontCare
|
||||
}
|
||||
|
||||
def connectTensorBlackwell = {
|
||||
@@ -885,59 +888,166 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
tcDData.foreach(_ := 0.U)
|
||||
tcDTag.foreach(_ := 0.U)
|
||||
|
||||
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
|
||||
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
|
||||
// reads can proceed together when bank placement avoids conflicts.
|
||||
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
|
||||
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
|
||||
val tmemBytesPerWarp = 2048
|
||||
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
|
||||
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||
val tmemBanks = 4
|
||||
val tmemBankBits = log2Ceil(tmemBanks)
|
||||
val tmemBankDepth = tmemDepth / tmemBanks
|
||||
require(isPow2(tmemBanks))
|
||||
require(tmemDepth % tmemBanks == 0)
|
||||
val tmem = Seq.fill(tmemBanks) {
|
||||
Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||
}
|
||||
|
||||
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
||||
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
||||
class TmemReadReq extends Bundle {
|
||||
val addr = UInt(tmemAddrBits.W)
|
||||
val src = UInt(2.W)
|
||||
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||
}
|
||||
|
||||
class TmemWriteReq extends Bundle {
|
||||
val addr = UInt(tmemAddrBits.W)
|
||||
val data = UInt(tmemDataBits.W)
|
||||
val mask = UInt(tmemMaskBits.W)
|
||||
}
|
||||
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
|
||||
|
||||
(0 until nTC).foreach { tc =>
|
||||
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
|
||||
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
||||
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
|
||||
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
||||
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
|
||||
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
||||
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
||||
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
||||
val src = UInt(1.W)
|
||||
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||
}
|
||||
|
||||
aReadArb.io.out.ready := true.B
|
||||
cReadArb.io.out.ready := true.B
|
||||
cWriteArb.io.out.ready := true.B
|
||||
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
|
||||
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
|
||||
|
||||
tmem.io.ren0 := aReadArb.io.out.fire
|
||||
tmem.io.raddr0 := aReadArb.io.out.bits
|
||||
tmem.io.ren1 := cReadArb.io.out.fire
|
||||
tmem.io.raddr1 := cReadArb.io.out.bits
|
||||
tmem.io.wen := cWriteArb.io.out.fire
|
||||
tmem.io.waddr := cWriteArb.io.out.bits.addr
|
||||
tmem.io.wdata := cWriteArb.io.out.bits.data
|
||||
tmem.io.mask := cWriteArb.io.out.bits.mask
|
||||
val aReady = Wire(Vec(nTC, Bool()))
|
||||
val cReady = Wire(Vec(nTC, Bool()))
|
||||
val wReady = Wire(Vec(nTC, Bool()))
|
||||
val scReadReady = Wire(Bool())
|
||||
val scWriteReady = Wire(Bool())
|
||||
aReady.foreach(_ := false.B)
|
||||
cReady.foreach(_ := false.B)
|
||||
wReady.foreach(_ := false.B)
|
||||
scReadReady := false.B
|
||||
scWriteReady := false.B
|
||||
|
||||
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
||||
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
||||
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
|
||||
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
|
||||
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
|
||||
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||
val read0Valid = Wire(Vec(tmemBanks, Bool()))
|
||||
val read1Valid = Wire(Vec(tmemBanks, Bool()))
|
||||
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
|
||||
val writeValid = Wire(Vec(tmemBanks, Bool()))
|
||||
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||
read0Valid.foreach(_ := false.B)
|
||||
read1Valid.foreach(_ := false.B)
|
||||
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
|
||||
writeValid.foreach(_ := false.B)
|
||||
|
||||
(0 until tmemBanks).foreach { b =>
|
||||
val requests = (0 until nTC).flatMap { tc =>
|
||||
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
||||
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
||||
Seq(
|
||||
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
|
||||
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
|
||||
)
|
||||
} ++ Seq(
|
||||
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
|
||||
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
|
||||
)
|
||||
|
||||
var used0 = false.B
|
||||
var used1 = false.B
|
||||
requests.foreach { case (valid, addr, src, tc) =>
|
||||
val grant0 = valid && !used0
|
||||
val grant1 = valid && used0 && !used1
|
||||
when(grant0) {
|
||||
read0Grant(b).addr := addr
|
||||
read0Grant(b).src := src
|
||||
read0Grant(b).tc := tc
|
||||
}
|
||||
when(grant1) {
|
||||
read1Grant(b).addr := addr
|
||||
read1Grant(b).src := src
|
||||
read1Grant(b).tc := tc
|
||||
}
|
||||
used0 = used0 || grant0
|
||||
used1 = used1 || grant1
|
||||
when(grant0 || grant1) {
|
||||
when(src === 0.U) { aReady(tc) := true.B }
|
||||
when(src === 1.U) { cReady(tc) := true.B }
|
||||
when(src === 2.U) { scReadReady := true.B }
|
||||
}
|
||||
}
|
||||
read0Valid(b) := used0
|
||||
read1Valid(b) := used1
|
||||
|
||||
var writeUsed = false.B
|
||||
(0 until nTC).foreach { tc =>
|
||||
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
||||
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
|
||||
val grant = valid && !writeUsed
|
||||
when(grant) {
|
||||
writeValid(b) := true.B
|
||||
writeGrant(b).addr := addr
|
||||
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
||||
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
||||
writeGrant(b).src := 0.U
|
||||
writeGrant(b).tc := tc.U
|
||||
wReady(tc) := true.B
|
||||
}
|
||||
writeUsed = writeUsed || grant
|
||||
}
|
||||
|
||||
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
|
||||
val scWGrant = scWValid && !writeUsed
|
||||
when(scWGrant) {
|
||||
writeValid(b) := true.B
|
||||
writeGrant(b).addr := core.io.sc_tmem_waddr
|
||||
writeGrant(b).data := core.io.sc_tmem_wdata
|
||||
writeGrant(b).mask := core.io.sc_tmem_mask
|
||||
writeGrant(b).src := 1.U
|
||||
writeGrant(b).tc := 0.U
|
||||
scWriteReady := true.B
|
||||
}
|
||||
|
||||
tmem(b).io.ren0 := read0Valid(b)
|
||||
tmem(b).io.raddr0 := row(read0Grant(b).addr)
|
||||
tmem(b).io.ren1 := read1Valid(b)
|
||||
tmem(b).io.raddr1 := row(read1Grant(b).addr)
|
||||
tmem(b).io.wen := writeValid(b)
|
||||
tmem(b).io.waddr := row(writeGrant(b).addr)
|
||||
tmem(b).io.wdata := writeGrant(b).data
|
||||
tmem(b).io.mask := writeGrant(b).mask
|
||||
}
|
||||
|
||||
val read0GrantReg = RegNext(read0Grant)
|
||||
val read1GrantReg = RegNext(read1Grant)
|
||||
val read0ValidReg = RegNext(read0Valid)
|
||||
val read1ValidReg = RegNext(read1Valid)
|
||||
core.io.tc_tmem_A_rready := aReady.asUInt
|
||||
core.io.tc_tmem_C_rready := cReady.asUInt
|
||||
core.io.tc_tmem_C_wready := wReady.asUInt
|
||||
core.io.sc_tmem_rready := scReadReady.asUInt
|
||||
core.io.sc_tmem_wready := scWriteReady.asUInt
|
||||
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
|
||||
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
|
||||
VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
}).asUInt
|
||||
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
|
||||
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
|
||||
VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
}).asUInt
|
||||
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
|
||||
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
|
||||
(0 until nTC).foreach { tc =>
|
||||
@@ -1025,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
core.io.tc_tmem_C_rready := DontCare
|
||||
core.io.tc_tmem_C_rdata := DontCare
|
||||
core.io.tc_tmem_C_wready := DontCare
|
||||
core.io.sc_tmem_rready := DontCare
|
||||
core.io.sc_tmem_rdata := DontCare
|
||||
core.io.sc_tmem_wready := DontCare
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
|
||||
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
|
||||
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
|
||||
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
|
||||
val sc_tmem_ren = Output(UInt(1.W))
|
||||
val sc_tmem_rready = Input(UInt(1.W))
|
||||
val sc_tmem_raddr = Output(UInt(9.W))
|
||||
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
|
||||
val sc_tmem_wen = Output(UInt(1.W))
|
||||
val sc_tmem_wready = Input(UInt(1.W))
|
||||
val sc_tmem_waddr = Output(UInt(9.W))
|
||||
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
|
||||
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
|
||||
|
||||
// FIXME: hardcoded
|
||||
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
||||
@@ -351,6 +360,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_exp.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")
|
||||
|
||||
110
src/test/scala/radiance/FP8E4M3Test.scala
Normal file
110
src/test/scala/radiance/FP8E4M3Test.scala
Normal file
@@ -0,0 +1,110 @@
|
||||
package radiance.core
|
||||
|
||||
import chisel3._
|
||||
import chiseltest._
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
|
||||
class FP8E4M3DecodeHarness extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val in = Input(UInt(8.W))
|
||||
val out = Output(UInt(32.W))
|
||||
})
|
||||
|
||||
io.out := FP8E4M3.toFloat32(io.in)
|
||||
}
|
||||
|
||||
class FP8E4M3MulHarness extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val a = Input(UInt(8.W))
|
||||
val b = Input(UInt(8.W))
|
||||
val out = Output(UInt(32.W))
|
||||
})
|
||||
|
||||
io.out := FP8E4M3MulToFloat32(io.a, io.b)
|
||||
}
|
||||
|
||||
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||
behavior of "FP8E4M3"
|
||||
|
||||
it should "decode representative E4M3 values to FP32 bits" in {
|
||||
test(new FP8E4M3DecodeHarness) { c =>
|
||||
Seq(
|
||||
0x00 -> 0x00000000L,
|
||||
0x80 -> 0x80000000L,
|
||||
0x38 -> 0x3f800000L,
|
||||
0x40 -> 0x40000000L,
|
||||
0x30 -> 0x3f000000L,
|
||||
0x3c -> 0x3fc00000L
|
||||
).foreach { case (fp8, fp32) =>
|
||||
c.io.in.poke(fp8.U)
|
||||
c.clock.step()
|
||||
c.io.out.expect(fp32.U)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it should "multiply E4M3 operands with FP8-width significands and return FP32 bits" in {
|
||||
test(new FP8E4M3MulHarness) { c =>
|
||||
Seq(
|
||||
(0x38, 0x40, 0x40000000L), // 1.0 * 2.0 = 2.0
|
||||
(0x30, 0x3c, 0x3f400000L), // 0.5 * 1.5 = 0.75
|
||||
(0xb8, 0x40, 0xc0000000L), // -1.0 * 2.0 = -2.0
|
||||
(0x00, 0x40, 0x00000000L), // 0.0 * 2.0 = 0.0
|
||||
(0x80, 0x40, 0x80000000L) // -0.0 * 2.0 = -0.0
|
||||
).foreach { case (a, b, out) =>
|
||||
c.io.a.poke(a.U)
|
||||
c.io.b.poke(b.U)
|
||||
c.clock.step()
|
||||
c.io.out.expect(out.U)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
|
||||
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
for (i <- 0 until 8) {
|
||||
c.io.in.bits.a(i).poke(0x38.U(8.W))
|
||||
c.io.in.bits.b(i).poke(0x40.U(8.W))
|
||||
}
|
||||
c.io.in.bits.c.poke(0x3f800000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x41880000L.U)
|
||||
}
|
||||
}
|
||||
|
||||
it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
|
||||
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
c.io.stall.poke(false.B)
|
||||
for (i <- 0 until 8) {
|
||||
c.io.in.bits.a(i).poke(0x30.U(8.W))
|
||||
c.io.in.bits.b(i).poke(0x3c.U(8.W))
|
||||
}
|
||||
c.io.in.bits.c.poke(0x40000000L.U(32.W))
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x41000000L.U)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -68,22 +68,22 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
|
||||
val bBase = BigInt(0x800)
|
||||
|
||||
val fp16One = BigInt(0x3c00)
|
||||
val fp8One = BigInt(0x38)
|
||||
val fp32Zero = BigInt(0)
|
||||
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
|
||||
val fp32SixtyFour = BigInt(0x42800000L)
|
||||
|
||||
// Populate TMEM A at offset aBase (all 32 frags)
|
||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
||||
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
|
||||
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
|
||||
for (i <- 0 until 32) {
|
||||
tmem(aBase / fragBytes + i) = aFrag
|
||||
tmem(cBase / fragBytes + i) = cFrag
|
||||
}
|
||||
|
||||
// SMEM B with fp16 2.0
|
||||
val fp16Two = BigInt(0x4000)
|
||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
||||
// SMEM B with packed FP8 E4M3 2.0
|
||||
val fp8Two = BigInt(0x40)
|
||||
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
|
||||
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
|
||||
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
|
||||
|
||||
|
||||
76
src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
Normal file
76
src/test/scala/radiance/TensorCoreBlackwellFP8Test.scala
Normal file
@@ -0,0 +1,76 @@
|
||||
package radiance.core
|
||||
|
||||
import chisel3._
|
||||
import chiseltest._
|
||||
import org.scalatest.flatspec.AnyFlatSpec
|
||||
|
||||
class TensorCoreBlackwellFP8MicrostepHarness extends Module {
|
||||
val io = IO(new Bundle {
|
||||
val valid = Input(Bool())
|
||||
val operandA = Input(UInt(512.W))
|
||||
val operandB = Input(UInt(256.W))
|
||||
val c = Input(UInt(32.W))
|
||||
val elemM = Input(UInt(2.W))
|
||||
val elemN = Input(UInt(1.W))
|
||||
val outValid = Output(Bool())
|
||||
val out = Output(UInt(32.W))
|
||||
})
|
||||
|
||||
val dpu = Module(new TensorDotProductUnit(
|
||||
dim = 8,
|
||||
half = false,
|
||||
inputType = TensorInputType.FP8E4M3
|
||||
))
|
||||
|
||||
dpu.io.in.valid := io.valid
|
||||
for (k <- 0 until 8) {
|
||||
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
|
||||
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
|
||||
}
|
||||
dpu.io.in.bits.c := io.c
|
||||
dpu.io.stall := false.B
|
||||
|
||||
io.outValid := dpu.io.out.valid
|
||||
io.out := dpu.io.out.bits.data
|
||||
}
|
||||
|
||||
class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
|
||||
behavior of "TensorCoreBlackwell FP8 operand microstep"
|
||||
|
||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||
val mask = (BigInt(1) << width) - 1
|
||||
words.zipWithIndex.foldLeft(BigInt(0)) {
|
||||
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
|
||||
}
|
||||
}
|
||||
|
||||
it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
|
||||
test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
|
||||
val fp8One = BigInt(0x38)
|
||||
val fp8Two = BigInt(0x40)
|
||||
val fp32One = BigInt(0x3f800000L)
|
||||
val fp32Seventeen = BigInt(0x41880000L)
|
||||
val operandA = packWords(Seq.fill(64)(fp8One), 8)
|
||||
val operandB = packWords(Seq.fill(32)(fp8Two), 8)
|
||||
|
||||
c.io.valid.poke(true.B)
|
||||
c.io.operandA.poke(operandA.U)
|
||||
c.io.operandB.poke(operandB.U)
|
||||
c.io.c.poke(fp32One.U)
|
||||
c.io.elemM.poke(0.U)
|
||||
c.io.elemN.poke(0.U)
|
||||
c.io.outValid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.io.valid.poke(false.B)
|
||||
c.io.outValid.expect(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.io.outValid.expect(true.B)
|
||||
c.io.out.expect(fp32Seventeen.U)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,13 +249,13 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
val bBase = BigInt(0x800)
|
||||
val cBase = BigInt(0x1000)
|
||||
|
||||
// A: all fp16 1.0 (0x3c00), 16 halves per frag
|
||||
val fp16One = BigInt(0x3c00)
|
||||
val fp16Two = BigInt(0x4000)
|
||||
// A/B: packed FP8 E4M3 bytes, 32 elements per 256-bit frag
|
||||
val fp8One = BigInt(0x38)
|
||||
val fp8Two = BigInt(0x40)
|
||||
val fp32One = BigInt(0x3f800000)
|
||||
val fp32SixtyFive = BigInt(0x42820000)
|
||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
||||
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
|
||||
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
|
||||
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
|
||||
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user