11 Commits

Author SHA1 Message Date
Zhongdi LUO
ca4c48251d feat: use fp8 multiply for blackwell bwgmma 2026-07-03 08:40:10 +00:00
Zhongdi LUO
12f5c6d92d feat: switch blackwell bwgmma inputs to fp8 2026-07-02 08:57:59 +00:00
Zhongdi LUO
68a7f66046 feat: add fp8 tensor dot product path 2026-07-02 08:06:42 +00:00
Zhongdi LUO
2afb96bb14 feat: add fp8 e4m3 decode support 2026-07-02 07:59:01 +00:00
Zhongdi LUO
007350fd5a feat: include vortex fexp RTL 2026-07-02 07:25:32 +00:00
Zhongdi LUO
47d6585896 Wire scalar TMEM through Radiance tile 2026-06-24 06:25:10 +00:00
Zhongdi LUO
f88085331e Save pre-TMEM-bank Radiance changes 2026-06-21 08:20:21 +00:00
Zhongdi LUO
1e78574113 Fix Blackwell SMEM fragment alignment 2026-05-27 08:43:36 +00:00
Zhongdi LUO
c6c30ec0dc Add 4-lane Blackwell tensor core support 2026-05-27 05:54:39 +00:00
126523c5d2 Track vortex masked barrier fix 2026-05-27 09:08:03 +08:00
87a4bbc757 Integrate WU architecture in Radiance 2026-05-25 19:25:59 +08:00
12 changed files with 1081 additions and 296 deletions

2
.gitmodules vendored
View File

@@ -1,3 +1,3 @@
[submodule "src/main/resources/vsrc/vortex"] [submodule "src/main/resources/vsrc/vortex"]
path = src/main/resources/vsrc/vortex path = src/main/resources/vsrc/vortex
url = https://github.com/hansungk/vortex.git url = https://git.nudt.space/wu-arch/vortex.git

View File

@@ -24,7 +24,7 @@ ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
endif endif
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL
endif endif
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4

View File

@@ -6,6 +6,29 @@ package radiance.core
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
/** Byte-selection helpers for packed FP8 operands.
  *
  * Operands are treated as arrays of 8-bit elements, with element `i` occupying
  * bits `8*i+7 .. 8*i`. Rows are laid out 8 elements apart, and `k` picks the
  * element within a row.
  */
object TensorCoreBlackwellFP8Packing {
  /** Returns byte number `idx` of `x`, i.e. bits `8*idx+7 .. 8*idx`. */
  def fp8Byte(x: UInt, idx: Int): UInt = {
    val lsb = idx * 8
    x(lsb + 7, lsb)
  }

  /** Picks element `k` of the A row selected by `elemM`.
    *
    * With `numLanes == 4` only rows 0 and 1 are referenced and `elemM` is used
    * as a single select bit; for any other lane count rows 0 to 3 are
    * selectable, with row 0 as the lookup default.
    */
  def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
    // Slices are built lazily per row so the 4-lane case never touches rows 2/3.
    def row(r: Int): UInt = fp8Byte(operandA, r * 8 + k)

    numLanes match {
      case 4 =>
        Mux(elemM.asBool, row(1), row(0))
      case _ =>
        val table = (0 until 4).map(r => r.U -> row(r))
        MuxLookup(elemM, row(0))(table)
    }
  }

  /** Picks element `k` of B row 0 or row 1, selected by `elemN`. */
  def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
    val upperRow = fp8Byte(operandB, 8 + k)
    val lowerRow = fp8Byte(operandB, k)
    Mux(elemN.asBool, upperRow, lowerRow)
  }
}
class TensorCoreBlackwell( class TensorCoreBlackwell(
val numWarps: Int, val numWarps: Int,
val numLanes: Int, val numLanes: Int,
@@ -13,8 +36,9 @@ class TensorCoreBlackwell(
val numSourceIds: Int = 16, val numSourceIds: Int = 16,
val numFPRegs: Int = 32 val numFPRegs: Int = 32
) extends Module { ) extends Module {
require(half, "Blackwell MMA currently supports FP16 inputs only") require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes") require(numLanes == 4 || numLanes == 8,
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
val numWarpBits = log2Ceil(numWarps) val numWarpBits = log2Ceil(numWarps)
val sourceWidth = log2Ceil(numSourceIds) val sourceWidth = log2Ceil(numSourceIds)
@@ -26,11 +50,16 @@ class TensorCoreBlackwell(
val fragOffsetBits = log2Ceil(memWidth / 8) val fragOffsetBits = log2Ceil(memWidth / 8)
val numSets = 4 val numSets = 4
val numAFragsPerSet = 8
val numBGroups = 4 val numBGroups = 4
val numBFragsPerGroup = 2 val numSubsteps = 2
val numMGroups = 4 val mElemsPerFrag = if (numLanes == 4) 2 else 4
val numCFrags = 32 val numMGroups = 16 / mElemsPerFrag
val numAFragsPerMGroup = 2
val numAFragsPerSet = numMGroups * numAFragsPerMGroup
val numBFragsPerSubstep = if (numLanes == 4) 2 else 1
val numBFragsPerGroup = numSubsteps * numBFragsPerSubstep
val numBFragsPerSet = numBGroups * numBFragsPerGroup
val numCFrags = numBGroups * numMGroups * numSubsteps
object Ops { object Ops {
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7) val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
@@ -57,13 +86,21 @@ class TensorCoreBlackwell(
// Direct SRAM port for TMEM (no TileLink overhead) // Direct SRAM port for TMEM (no TileLink overhead)
class TmemSramPort extends Bundle { class TmemSramPort extends Bundle {
val wen = Output(Bool()) val aRen = Output(Bool())
val ren = Output(Bool()) val aRready = Input(Bool())
val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) val aRdata = Input(UInt(memWidth.W))
val wdata = Output(UInt(memWidth.W))
val mask = Output(UInt(maskWidth.W)) val cRen = Output(Bool())
val rdata = Input(UInt(memWidth.W)) val cRready = Input(Bool())
val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cRdata = Input(UInt(memWidth.W))
val cWen = Output(Bool())
val cWready = Input(Bool())
val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cWdata = Output(UInt(memWidth.W))
val cMask = Output(UInt(maskWidth.W))
} }
val io = IO(new Bundle { val io = IO(new Bundle {
@@ -94,7 +131,7 @@ class TensorCoreBlackwell(
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp, val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq, bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb, bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
cbRead, cbWrite = Value cbRead, cbCapture, cbWrite = Value
} }
val state = RegInit(State.idle) val state = RegInit(State.idle)
@@ -128,10 +165,11 @@ class TensorCoreBlackwell(
base + (fragIndex << fragOffsetBits).asUInt base + (fragIndex << fragOffsetBits).asUInt
} }
val aFragIndex = (setReg << 3) + aIndexReg val aFragIndex = (setReg * numAFragsPerSet.U) + aIndexReg
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg val bFragIndex =
val stepIndex = Cat(bGroupReg, mGroupReg) (setReg * numBFragsPerSet.U) + (bGroupReg * numBFragsPerGroup.U) + bIndexReg
val cFragIndex = (stepIndex << 1) + substepReg val cFragIndex =
(((bGroupReg * numMGroups.U) + mGroupReg) * numSubsteps.U) + substepReg
val aReqAddress = byteAddress(addrAReg, aFragIndex) val aReqAddress = byteAddress(addrAReg, aFragIndex)
val bReqAddress = byteAddress(addrBReg, bFragIndex) val bReqAddress = byteAddress(addrBReg, bFragIndex)
val cReqAddress = byteAddress(addrCReg, cFragIndex) val cReqAddress = byteAddress(addrCReg, cFragIndex)
@@ -147,12 +185,14 @@ class TensorCoreBlackwell(
io.reqA <> reqA io.reqA <> reqA
io.reqB <> reqB io.reqB <> reqB
io.tmemC.wen := false.B io.tmemC.aRen := false.B
io.tmemC.ren := false.B io.tmemC.aRaddr := 0.U
io.tmemC.waddr := 0.U io.tmemC.cRen := false.B
io.tmemC.raddr := 0.U io.tmemC.cRaddr := 0.U
io.tmemC.wdata := 0.U io.tmemC.cWen := false.B
io.tmemC.mask := 0.U io.tmemC.cWaddr := 0.U
io.tmemC.cWdata := 0.U
io.tmemC.cMask := 0.U
val wbValid = RegInit(false.B) val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W))) val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
@@ -171,29 +211,26 @@ class TensorCoreBlackwell(
io.initiate.ready := state === State.idle && !wbValid io.initiate.ready := state === State.idle && !wbValid
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1)) val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
val operandB = bBuf(substepReg) val operandB =
if (numLanes == 4) {
Cat(bBuf((substepReg << 1) + 1.U), bBuf(substepReg << 1))
} else {
bBuf(substepReg)
}
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val dpuInValid = WireDefault(false.B) val dpuInValid = WireDefault(false.B)
val dpu = Module(new TensorDotProductUnit( val dpu = Module(new TensorDotProductUnit(
dim = 8, dim = 8,
half = true half = false,
inputType = TensorInputType.FP8E4M3
)) ))
private def halfWord(x: UInt, idx: Int): UInt = { val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
x((idx + 1) * 16 - 1, idx * 16) val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
}
val elemM = elemReg(1, 0)
val elemN = elemReg(2)
dpu.io.in.valid := dpuInValid dpu.io.in.valid := dpuInValid
for (k <- 0 until 8) { for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq( dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
0.U -> halfWord(operandA, k), dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
1.U -> halfWord(operandA, 8 + k),
2.U -> halfWord(operandA, 16 + k),
3.U -> halfWord(operandA, 24 + k)
))
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
} }
dpu.io.in.bits.c := cWords(elemReg) dpu.io.in.bits.c := cWords(elemReg)
dpu.io.stall := false.B dpu.io.stall := false.B
@@ -229,13 +266,15 @@ class TensorCoreBlackwell(
} }
when(state === State.bwLoadAReq) { when(state === State.bwLoadAReq) {
io.tmemC.ren := true.B io.tmemC.aRen := true.B
io.tmemC.raddr := tmemABase + aFragIndex io.tmemC.aRaddr := tmemABase + aFragIndex
when(io.tmemC.aRready) {
state := State.bwLoadAResp state := State.bwLoadAResp
} }
}
when(state === State.bwLoadAResp) { when(state === State.bwLoadAResp) {
aBuf(aIndexReg) := io.tmemC.rdata aBuf(aIndexReg) := io.tmemC.aRdata
when(aIndexReg === (numAFragsPerSet - 1).U) { when(aIndexReg === (numAFragsPerSet - 1).U) {
bGroupReg := 0.U bGroupReg := 0.U
bIndexReg := 0.U bIndexReg := 0.U
@@ -274,13 +313,15 @@ class TensorCoreBlackwell(
} }
when(state === State.bwReadCReq) { when(state === State.bwReadCReq) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := tmemCBase + cFragIndex io.tmemC.cRaddr := tmemCBase + cFragIndex
when(io.tmemC.cRready) {
state := State.bwReadCResp state := State.bwReadCResp
} }
}
when(state === State.bwReadCResp) { when(state === State.bwReadCResp) {
cDataReg := io.tmemC.rdata cDataReg := io.tmemC.cRdata
elemReg := 0.U elemReg := 0.U
state := State.bwCompute state := State.bwCompute
} }
@@ -303,10 +344,11 @@ class TensorCoreBlackwell(
} }
when(state === State.bwWriteCReq) { when(state === State.bwWriteCReq) {
io.tmemC.wen := true.B io.tmemC.cWen := true.B
io.tmemC.waddr := tmemCBase + cFragIndex io.tmemC.cWaddr := tmemCBase + cFragIndex
io.tmemC.wdata := mmaDataReg.asUInt io.tmemC.cWdata := mmaDataReg.asUInt
io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
when(substepReg === 0.U) { when(substepReg === 0.U) {
substepReg := 1.U substepReg := 1.U
state := State.bwReadCReq state := State.bwReadCReq
@@ -333,6 +375,7 @@ class TensorCoreBlackwell(
state := State.bwWriteCWait state := State.bwWriteCWait
} }
} }
}
when(state === State.bwWriteCWait) { when(state === State.bwWriteCWait) {
when(waitCounter === 0.U) { when(waitCounter === 0.U) {
@@ -361,24 +404,26 @@ class TensorCoreBlackwell(
} }
when(state === State.cpWrite) { when(state === State.cpWrite) {
io.respA.ready := true.B io.respA.ready := io.tmemC.cWready
io.tmemC.cWen := io.respA.valid
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respA.bits.data
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.respA.fire) { when(io.respA.fire) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respA.bits.data
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle state := State.idle
} }
} }
when(state === State.ldReq) { when(state === State.ldReq) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.waitWb state := State.waitWb
} }
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) { when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B wbValid := true.B
state := State.idle state := State.idle
} }
@@ -389,16 +434,25 @@ class TensorCoreBlackwell(
} }
when(state === State.stWrite) { when(state === State.stWrite) {
io.tmemC.wen := true.B io.tmemC.cWen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respC io.tmemC.cWdata := io.respC
io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
state := State.idle state := State.idle
} }
}
when(state === State.cbRead) { when(state === State.cbRead) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.cbCapture
}
}
when(state === State.cbCapture) {
cDataReg := io.tmemC.cRdata
state := State.cbWrite state := State.cbWrite
} }
@@ -408,7 +462,7 @@ class TensorCoreBlackwell(
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter reqA.bits.source := sourceCounter
reqA.bits.data := io.tmemC.rdata reqA.bits.data := cDataReg
when(reqA.fire) { when(reqA.fire) {
bumpSource() bumpSource()
state := State.waitWb state := State.waitWb

View File

@@ -7,17 +7,116 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import freechips.rocketchip.tile import freechips.rocketchip.tile
/** Elaboration-time decoder from FP8 E4M3 to IEEE-754 single precision.
  *
  * Layout: 1 sign bit, 4 exponent bits (bias 7), 3 fraction bits. An exponent
  * field of zero denotes zero/subnormals. The encoding S.1111.111 (bytes 0x7F
  * and 0xFF) is NaN; the format has no infinities, so every other pattern with
  * an all-ones exponent is an ordinary normal number.
  */
object FP8E4M3 {
  private val Bias = 7
  // Magnitude bits (exponent and fraction all ones) reserved for NaN.
  private val NaNMagnitude = 0x7f

  /** Software model: value of the 8-bit pattern `bits` as a Float. */
  private def decodeToFloat(bits: Int): Float = {
    val sign = (bits >> 7) & 0x1
    val exp = (bits >> 3) & 0xf
    val frac = bits & 0x7
    if ((bits & 0x7f) == NaNMagnitude) {
      // Previously decoded as +/-480, which is not an E4M3 value.
      Float.NaN
    } else {
      val magnitude =
        if (exp == 0) (frac.toDouble / 8.0) * Math.pow(2.0, 1 - Bias) // zero or subnormal
        else (1.0 + frac.toDouble / 8.0) * Math.pow(2.0, exp - Bias)
      // Negating 0.0 yields -0.0, so the sign of zero is preserved.
      (if (sign == 1) -magnitude else magnitude).toFloat
    }
  }

  /** FP32 bit pattern of `bits` as an unsigned BigInt; NaN is canonicalised to 0x7fc00000. */
  private def fp32Bits(bits: Int): BigInt = {
    BigInt(java.lang.Float.floatToIntBits(decodeToFloat(bits)).toLong & 0xffffffffL)
  }

  /** Converts an 8-bit E4M3 value to its 32-bit IEEE-754 pattern via a 256-entry lookup.
    *
    * @param x FP8 E4M3 encoded input
    * @return 32-bit single-precision bit pattern (canonical quiet NaN for NaN inputs)
    */
  def toFloat32(x: UInt): UInt = {
    MuxLookup(x, 0.U(32.W))((0 until 256).map { bits =>
      bits.U(8.W) -> fp32Bits(bits).U(32.W)
    })
  }
}

/** Combinational FP8 E4M3 x FP8 E4M3 multiplier producing an IEEE-754 FP32 bit pattern.
  *
  * The two 4-bit significands multiply to at most 8 bits and the resulting
  * exponent stays inside the FP32 normal range, so no rounding step is needed.
  * A NaN on either input (magnitude bits all ones) yields the canonical quiet
  * NaN 0x7fc00000, taking priority over the zero case so that NaN * 0 is NaN.
  */
object FP8E4M3MulToFloat32 {
  private val Bias = 7

  /** @param a FP8 E4M3 multiplicand
    * @param b FP8 E4M3 multiplier
    * @return 32-bit IEEE-754 single-precision bit pattern of `a * b`
    */
  def apply(a: UInt, b: UInt): UInt = {
    val sign = a(7) ^ b(7)
    val expA = a(6, 3)
    val expB = b(6, 3)
    val fracA = a(2, 0)
    val fracB = b(2, 0)

    // S.1111.111 is the only NaN encoding; E4M3 has no infinities.
    val isNaN = a(6, 0).andR || b(6, 0).andR

    val zeroA = expA === 0.U && fracA === 0.U
    val zeroB = expB === 0.U && fracB === 0.U
    val isZero = zeroA || zeroB

    // Significands with the hidden bit made explicit (0 for subnormals).
    val sigA = Mux(expA === 0.U, Cat(0.U(1.W), fracA), Cat(1.U(1.W), fracA))
    val sigB = Mux(expB === 0.U, Cat(0.U(1.W), fracB), Cat(1.U(1.W), fracB))
    val prodSig = sigA * sigB

    // value = sig * 2^scale: normals use exp - Bias - 3, subnormals use 1 - Bias - 3 = -9.
    val scaleA = Wire(SInt(6.W))
    val scaleB = Wire(SInt(6.W))
    scaleA := Mux(expA === 0.U, -9.S(6.W), expA.zext - (Bias + 3).S(6.W))
    scaleB := Mux(expB === 0.U, -9.S(6.W), expB.zext - (Bias + 3).S(6.W))

    // Position of the leading one of the 8-bit product (0 when the product is zero).
    val msb = PriorityMux(
      (7 to 1 by -1).map(i => prodSig(i) -> i.U(3.W)) :+ (true.B -> 0.U(3.W))
    )

    // Left-align so the leading one sits in bit 7; bits 6..0 become the fraction MSBs.
    val normalized = (prodSig << (7.U - msb))(7, 0)
    val exponent = (scaleA + scaleB + msb.zext + 127.S(10.W)).asUInt(7, 0)
    val fraction = Cat(normalized(6, 0), 0.U(16.W))

    val quietNaN = "h7fc00000".U(32.W)
    val signedZero = Cat(sign, 0.U(31.W))
    val finite = Cat(sign, exponent, fraction)
    Mux(isNaN, quietNaN, Mux(isZero, signedZero, finite))
  }
}
/** Element formats accepted on the A/B inputs of the tensor dot-product unit. */
object TensorInputType extends Enumeration {
  val FP16, FP32, FP8E4M3 = Value

  /** Maps the legacy `half` flag to an input type: `true` gives FP16, `false` gives FP32. */
  def fromHalf(half: Boolean): Value = half match {
    case true  => FP16
    case false => FP32
  }
}
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
// `half`: if True, generate fp16 MACs; if False fp32. // `half`: if True, generate fp16 MACs; if False fp32.
class TensorDotProductUnit( class TensorDotProductUnit(
val dim: Int = 4, val dim: Int,
val half: Boolean val half: Boolean,
val inputType: TensorInputType.Value
) extends Module with tile.HasFPUParameters { ) extends Module with tile.HasFPUParameters {
val tIn = if (half) tile.FType.H else tile.FType.S def this(dim: Int = 4, half: Boolean) = {
this(dim, half, TensorInputType.fromHalf(half))
}
val tIn = inputType match {
case TensorInputType.FP16 => tile.FType.H
case TensorInputType.FP32 => tile.FType.S
case TensorInputType.FP8E4M3 => tile.FType.S
}
// output datatype fixed to single-precision // output datatype fixed to single-precision
val tOut = tile.FType.S val tOut = tile.FType.S
val inFLen = tIn.ieeeWidth val inFLen = inputType match {
case TensorInputType.FP8E4M3 => 8
case _ => tIn.ieeeWidth
}
val outFLen = tOut.ieeeWidth val outFLen = tOut.ieeeWidth
val fLen = outFLen // needed for HasFPUParameters val fLen = outFLen // needed for HasFPUParameters
val minFLen = 16 // fp16 val minFLen = 16 // fp16
@@ -40,9 +139,27 @@ class TensorDotProductUnit(
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE] // [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
// make sure recoding/uncoding happens only at the edge, not at every // make sure recoding/uncoding happens only at the edge, not at every
// pipeline stage inside the dpu // pipeline stage inside the dpu
val tag = if (half) H else S val tag = inputType match {
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn))) case TensorInputType.FP16 => H
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn))) case TensorInputType.FP32 => S
case TensorInputType.FP8E4M3 => S
}
if (inputType == TensorInputType.FP8E4M3) {
val dpu = Module(new DotProductPipeFP8E4M3(dim))
dpu.io.in.valid := io.in.valid
dpu.io.in.bits.a := io.in.bits.a
dpu.io.in.bits.b := io.in.bits.b
dpu.io.in.bits.c := io.in.bits.c
dpu.io.stall := io.stall
io.out.valid := dpu.io.out.valid
io.out.bits.data := dpu.io.out.bits.data
} else {
def recodeInput(x: Bits): UInt = {
unbox(recode(x.asUInt, tag), tag, Some(tIn))
}
val in1 = io.in.bits.a.map(recodeInput)
val in2 = io.in.bits.b.map(recodeInput)
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val dpu = Module(new DotProductPipe(dim, tIn, tOut)) val dpu = Module(new DotProductPipe(dim, tIn, tOut))
@@ -54,6 +171,7 @@ class TensorDotProductUnit(
io.out.valid := dpu.io.out.valid io.out.valid := dpu.io.out.valid
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
}
} }
// An implementation of chisel3.util.Pipe that supports stalls. // An implementation of chisel3.util.Pipe that supports stalls.
@@ -236,6 +354,89 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
io.out.bits.data := accStageOut.bits io.out.bits.data := accStageOut.bits
} }
// Dot-product pipeline for FP8 inputs: computes sum(a(i) * b(i)) + c, where
// a and b are vectors of `dim` 8-bit elements and c / the result are 32-bit values.
//
// Structure, in order:
//   1. each (a, b) pair goes through FP8E4M3MulToFloat32 and is recoded to
//      hardfloat's recoded format; the products and c pass through a StallingPipe;
//   2. log2(dim) levels of pairwise hardfloat.AddRecFN adders, each level
//      followed by a StallingPipe, with c carried alongside unchanged;
//   3. a final AddRecFN adds c to the reduced sum, followed by one more
//      StallingPipe before conversion back to IEEE format.
// All adders use round-to-nearest-even and detect tininess after rounding.
// `dim` must be a power of two (enforced by a require below).
// NOTE(review): the total latency depends on StallingPipe's per-stage latency,
// which is defined outside this block -- confirm before relying on a cycle count.
class DotProductPipeFP8E4M3(dim: Int) extends Module with tile.HasFPUParameters {
// Results and all intermediate sums are single precision.
val tOut = tile.FType.S
val outExpWidth = tOut.exp
val outSigWidth = tOut.sig
// Width of a value in hardfloat recoded form (one bit wider than exp + sig).
val recOutFLen = outExpWidth + outSigWidth + 1
// fLen, minFLen and xLen are provided for the HasFPUParameters mix-in.
val fLen = tOut.ieeeWidth
val minFLen = 16
def xLen = 32
val io = IO(new Bundle {
val in = Flipped(Valid(new Bundle {
val a = Vec(dim, Bits(8.W))
val b = Vec(dim, Bits(8.W))
val c = Bits(32.W)
}))
val stall = Input(Bool())
val out = Valid(new Bundle {
val data = Bits(32.W)
})
})
// Element-wise products, formed combinationally and converted to recoded form.
val productRecoded = io.in.bits.a.zip(io.in.bits.b).map { case (a, b) =>
unbox(recode(FP8E4M3MulToFloat32(a.asUInt, b.asUInt), S), S, Some(tOut))
}
val inC = unbox(recode(io.in.bits.c, S), S, Some(tOut))
// First pipeline stage: products and the accumulator input, both qualified by io.in.valid.
val productStageOut = StallingPipe(io.stall, io.in.valid, VecInit(productRecoded))
val productStageC = StallingPipe(io.stall, io.in.valid, inC)
val log2Dim = log2Ceil(dim)
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
// One wire bundle per tree level, holding dim, dim/2, ..., 1 partial sums.
val interim = (log2Dim to 0 by -1).map { i =>
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
}
// Matching wires that carry c next to each tree level.
val interimC = (log2Dim to 0 by -1).map(_ => Wire(Valid(Bits(recOutFLen.W))))
interim(0) := productStageOut
interimC(0) := productStageC
// reduce is used for its side effects: each step wires level N into level N+1
// and returns level N+1, so the final result is the last (single-element) level.
val (addStageOut, addStageC) = (interim zip interimC).reduce {
(inputsAndC, outputsAndC) => {
// Note: this inC shadows the outer inC within the closure.
val (inputs, inC) = inputsAndC
val (outputs, outC) = outputsAndC
require(inputs.bits.length == 2 * outputs.bits.length)
val thisDim = inputs.bits.length
val adders = Seq.fill(thisDim / 2)(
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
)
// Adder i sums the adjacent pair (2i, 2i+1) of this level.
val addOuts = adders.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U
a.io.a := inputs.bits(2 * i + 0)
a.io.b := inputs.bits(2 * i + 1)
a.io.roundingMode := hardfloat.consts.round_near_even
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
a.io.out
}
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
// Simulation-time warning if the sum path and the c path fall out of step.
when(inputs.valid =/= inC.valid) {
printf("WARN: DotProductPipeFP8E4M3 input/C valid mismatch: inputs=%d c=%d\n",
inputs.valid, inC.valid)
}
(outputs, outC)
}
}
require(addStageOut.bits.length == 1)
// Final accumulation: reduced sum + c.
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U
acc.io.a := addStageOut.bits(0)
acc.io.b := addStageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
io.out.valid := accStageOut.valid
// Convert from recoded form back to an IEEE-754 bit pattern at the output edge.
io.out.bits.data := ieee(box(accStageOut.bits, S))
}
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
require(latency <= 2) require(latency <= 2)

View File

@@ -51,6 +51,7 @@ class WithRadianceCores(
tensorCoreFP16: Boolean, tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean, tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean, tensorCoreBlackwell: Boolean,
numTensorWarps: Int,
startupAddress: BigInt, startupAddress: BigInt,
useVxCache: Boolean useVxCache: Boolean
) extends Config((site, _, up) => { ) extends Config((site, _, up) => {
@@ -63,6 +64,7 @@ class WithRadianceCores(
tensorCoreFP16 = tensorCoreFP16, tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled, tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell, tensorCoreBlackwell = tensorCoreBlackwell,
numTensorWarps = numTensorWarps,
startupAddress = startupAddress startupAddress = startupAddress
), ),
btb = None, btb = None,
@@ -101,6 +103,7 @@ class WithRadianceCores(
def this(n: Int, location: HierarchicalLocation = InSubsystem, def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false, tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false, tensorCoreBlackwell: Boolean = false,
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), startupAddress: BigInt = BigInt("10100", 16),
useVxCache: Boolean = false) useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams( = this(n, location, RocketCrossingParams(
@@ -110,7 +113,7 @@ class WithRadianceCores(
case InSubsystem => CBUS case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId) case InCluster(clusterId) => CCBUS(clusterId)
} }
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache) ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache)
} }
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => { class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {

View File

@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
import org.chipsalliance.cde.config._ import org.chipsalliance.cde.config._
import radiance.core._ import radiance.core._
import radiance.memory._ import radiance.memory._
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs} import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
/** For determining radiance core id. This may be different from /** For determining radiance core id. This may be different from
* RadianceTileParams.tileId, when a cluster contains non-core tiles */ * RadianceTileParams.tileId, when a cluster contains non-core tiles */
@@ -102,6 +102,7 @@ case class VortexCoreParams(
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction haveCease: Boolean = true, // non-standard CEASE instruction
@@ -210,7 +211,9 @@ class RadianceTile private (
case Some(false) => 1 case Some(false) => 1
case None => 1 case None => 1
} }
val imemTagWidth = UUID_WIDTH + NW_WIDTH // Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid.
val imemDomainWidth = 1
val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH
require(numWarps >= numLsuLanes, require(numWarps >= numLsuLanes,
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})") s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
@@ -285,18 +288,37 @@ class RadianceTile private (
) )
} }
val tcSmemSize = 32 val tcSmemSize = numLsuLanes * 4
val tcSmemLineSize = p(RadianceSharedMemKey)
.map(k => k.numWords * k.wordSize)
.getOrElse(tcSmemSize)
val tcSmemClientMaxSize =
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
val numTensorWarps = radianceParams.core.numTensorWarps
val numScalarWarps = numWarps - numTensorWarps
require(numTensorWarps > 0 && numTensorWarps < numWarps,
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
if (radianceParams.core.tensorCoreBlackwell) {
require(numCoreLanes == numLsuLanes,
s"Wu Blackwell binding requires matching core lanes (${numCoreLanes}) and memory lanes (${numLsuLanes})")
require(numLsuLanes == 4 || numLsuLanes == 8,
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
}
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0 val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i => val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2( TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2( masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}_$i", name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth), sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize), probe = TransferSizes(1, tcSmemClientMaxSize),
get = TransferSizes(1, tcSmemSize), get = TransferSizes(1, tcSmemClientMaxSize),
putFull = TransferSizes(1, tcSmemSize), putFull = TransferSizes(1, tcSmemClientMaxSize),
), ),
requestFifo = true requestFifo = true
)) ))
@@ -304,10 +326,11 @@ class RadianceTile private (
} }
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand) // For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
// tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global) // tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq( val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) {
TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2( Seq.tabulate(numTensorCores) { i =>
name = s"rad_tc_gmem_${radianceParams.coreId}", TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << dmemSourceWidth), sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize), probe = TransferSizes(1, tcSmemSize),
@@ -315,8 +338,9 @@ class RadianceTile private (
putFull = TransferSizes(1, tcSmemSize), putFull = TransferSizes(1, tcSmemSize),
), ),
requestFifo = true requestFifo = true
))) )))))
))) else None }
} else Seq.empty
// combine outgoing per-lane dmemNode into 1 idenity node // combine outgoing per-lane dmemNode into 1 idenity node
// //
@@ -406,7 +430,7 @@ class RadianceTile private (
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ } // imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n } tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
} }
/* below are copied from rocket */ /* below are copied from rocket */
@@ -822,86 +846,298 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare core.io.tc_d_bits_tag := DontCare
} }
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
} }
def connectTensorBlackwell = { def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) { if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tcSmemNodes.nonEmpty) require(outer.tcSmemNodes.nonEmpty)
require(outer.tcSmemNodes.length == outer.numTensorCores)
require(outer.tcGmemNodes.length == outer.numTensorCores)
// TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO val nTC = outer.numTensorCores
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB val tcPorts = 3
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows val tcCoreDataBits = 32 * 8
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( val tcDataBits = outer.tcSmemSize * 8
tmemDepth, UInt((outer.tcSmemSize * 8).W))) val tcSmemLineBits = outer.tcSmemLineSize * 8
tmem.io.ren0 := core.io.tc_tmem_C_ren val tmemAddrBits = 9
tmem.io.raddr0 := core.io.tc_tmem_C_raddr val tmemDataBits = tcDataBits
core.io.tc_tmem_C_rdata := tmem.io.rdata0 val tmemMaskBits = outer.tcSmemSize
tmem.io.ren1 := false.B val tcTlSize = log2Ceil(outer.tcSmemSize)
tmem.io.raddr1 := 0.U val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
tmem.io.wen := core.io.tc_tmem_C_wen
tmem.io.waddr := core.io.tc_tmem_C_waddr
tmem.io.wdata := core.io.tc_tmem_C_wdata
tmem.io.mask := core.io.tc_tmem_C_mask
// smem_B (port 2): Global Memory via TileLink def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
val smemBBundle = new { def port(tc: Int, p: Int): Int = tc * tcPorts + p
val addr = core.io.tc_a_bits_address(95, 64) def padToCoreData(u: UInt): UInt = {
val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8) if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
val write = core.io.tc_a_bits_write(2)
val mask = core.io.tc_a_bits_mask(95, 64)
val data = core.io.tc_a_bits_data(767, 512)
val aValid = core.io.tc_a_valid(2)
val dReady = core.io.tc_d_ready(2)
} }
val client = outer.tcSmemNodes.head.out.head
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
tcAReady.foreach(_ := false.B)
tcDValid.foreach(_ := false.B)
tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U)
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
// reads can proceed together when bank placement avoids conflicts.
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
val tmemBytesPerWarp = 2048
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
val tmemBanks = 4
val tmemBankBits = log2Ceil(tmemBanks)
val tmemBankDepth = tmemDepth / tmemBanks
require(isPow2(tmemBanks))
require(tmemDepth % tmemBanks == 0)
val tmem = Seq.fill(tmemBanks) {
Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
}
class TmemReadReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val src = UInt(2.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
class TmemWriteReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val data = UInt(tmemDataBits.W)
val mask = UInt(tmemMaskBits.W)
val src = UInt(1.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
val aReady = Wire(Vec(nTC, Bool()))
val cReady = Wire(Vec(nTC, Bool()))
val wReady = Wire(Vec(nTC, Bool()))
val scReadReady = Wire(Bool())
val scWriteReady = Wire(Bool())
aReady.foreach(_ := false.B)
cReady.foreach(_ := false.B)
wReady.foreach(_ := false.B)
scReadReady := false.B
scWriteReady := false.B
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read0Valid = Wire(Vec(tmemBanks, Bool()))
val read1Valid = Wire(Vec(tmemBanks, Bool()))
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
val writeValid = Wire(Vec(tmemBanks, Bool()))
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read0Valid.foreach(_ := false.B)
read1Valid.foreach(_ := false.B)
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
writeValid.foreach(_ := false.B)
(0 until tmemBanks).foreach { b =>
val requests = (0 until nTC).flatMap { tc =>
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
Seq(
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
)
} ++ Seq(
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
)
var used0 = false.B
var used1 = false.B
requests.foreach { case (valid, addr, src, tc) =>
val grant0 = valid && !used0
val grant1 = valid && used0 && !used1
when(grant0) {
read0Grant(b).addr := addr
read0Grant(b).src := src
read0Grant(b).tc := tc
}
when(grant1) {
read1Grant(b).addr := addr
read1Grant(b).src := src
read1Grant(b).tc := tc
}
used0 = used0 || grant0
used1 = used1 || grant1
when(grant0 || grant1) {
when(src === 0.U) { aReady(tc) := true.B }
when(src === 1.U) { cReady(tc) := true.B }
when(src === 2.U) { scReadReady := true.B }
}
}
read0Valid(b) := used0
read1Valid(b) := used1
var writeUsed = false.B
(0 until nTC).foreach { tc =>
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
val grant = valid && !writeUsed
when(grant) {
writeValid(b) := true.B
writeGrant(b).addr := addr
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
writeGrant(b).src := 0.U
writeGrant(b).tc := tc.U
wReady(tc) := true.B
}
writeUsed = writeUsed || grant
}
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
val scWGrant = scWValid && !writeUsed
when(scWGrant) {
writeValid(b) := true.B
writeGrant(b).addr := core.io.sc_tmem_waddr
writeGrant(b).data := core.io.sc_tmem_wdata
writeGrant(b).mask := core.io.sc_tmem_mask
writeGrant(b).src := 1.U
writeGrant(b).tc := 0.U
scWriteReady := true.B
}
tmem(b).io.ren0 := read0Valid(b)
tmem(b).io.raddr0 := row(read0Grant(b).addr)
tmem(b).io.ren1 := read1Valid(b)
tmem(b).io.raddr1 := row(read1Grant(b).addr)
tmem(b).io.wen := writeValid(b)
tmem(b).io.waddr := row(writeGrant(b).addr)
tmem(b).io.wdata := writeGrant(b).data
tmem(b).io.mask := writeGrant(b).mask
}
val read0GrantReg = RegNext(read0Grant)
val read1GrantReg = RegNext(read1Grant)
val read0ValidReg = RegNext(read0Valid)
val read1ValidReg = RegNext(read1Valid)
core.io.tc_tmem_A_rready := aReady.asUInt
core.io.tc_tmem_C_rready := cReady.asUInt
core.io.tc_tmem_C_wready := wReady.asUInt
core.io.sc_tmem_rready := scReadReady.asUInt
core.io.sc_tmem_wready := scWriteReady.asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc =>
val p2 = port(tc, 2)
val client = outer.tcSmemNodes(tc).out.head
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
val adapter = Module(new VortexTLAdapter( val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth, outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
client client
)) ))
adapter.io.inReq.bits <> DontCare adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := smemBBundle.aValid adapter.io.inReq.valid := core.io.tc_a_valid(p2)
adapter.io.inReq.bits.address := smemBBundle.addr adapter.io.inReq.bits.address := lineAddress
adapter.io.inReq.bits.source := smemBBundle.tag adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
adapter.io.inReq.bits.size := 5.U adapter.io.inReq.bits.size := tcSmemLineTlSize.U
adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get) adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := smemBBundle.mask adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
adapter.io.inReq.bits.data := smemBBundle.data adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
adapter.io.inResp.ready := smemBBundle.dReady adapter.io.inResp.ready := core.io.tc_d_ready(p2)
client._1.a <> adapter.io.outReq client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d adapter.io.outResp <> client._1.d
val lineData = adapter.io.inResp.bits.data
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
lineData
} else {
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
log2Ceil(outer.tcSmemSize)).asUInt
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
when(adapter.io.inReq.fire) {
fragmentIndex := requestFragmentIndex
}
lineFragments(fragmentIndex)
}
// port 0: global memory (cp/cb) tcAReady(p2) := adapter.io.inReq.ready
val gmemClient = outer.tcGmemNode.get.out.head tcDValid(p2) := adapter.io.inResp.valid
tcDData(p2) := padToCoreData(fragmentData)
tcDTag(p2) := adapter.io.inResp.bits.source
}
// port 0: global memory (cp/cb), one TL client per tensor core.
(0 until nTC).foreach { tc =>
val p0 = port(tc, 0)
val gmemClient = outer.tcGmemNodes(tc).out.head
val gmemAdapter = Module(new VortexTLAdapter( val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth, outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
gmemClient gmemClient
)) ))
gmemAdapter.io.inReq.bits <> DontCare gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0) gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0) gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
gmemAdapter.io.inReq.bits.size := 5.U gmemAdapter.io.inReq.bits.size := tcTlSize.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get) gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0) gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0) gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0) gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
gmemClient._1.a <> gmemAdapter.io.outReq gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d gmemAdapter.io.outResp <> gmemClient._1.d
core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready) tcAReady(p0) := gmemAdapter.io.inReq.ready
core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid) tcDValid(p0) := gmemAdapter.io.inResp.valid
core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data) tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source) tcDTag(p0) := gmemAdapter.io.inResp.bits.source
}
core.io.tc_a_ready := tcAReady.asUInt
core.io.tc_d_valid := tcDValid.asUInt
core.io.tc_d_bits_data := tcDData.asUInt
core.io.tc_d_bits_tag := tcDTag.asUInt
} else { } else {
core.io.tc_a_ready := false.B core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare core.io.tc_d_bits_tag := DontCare
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
} }
} }
@@ -995,7 +1231,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
} else if (outer.radianceParams.core.tensorCoreBlackwell) { } else if (outer.radianceParams.core.tensorCoreBlackwell) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth) val tensorNumSourceIds = (1 << outer.tensorTagWidth)
val tensor = Module(new radiance.core.TensorCoreBlackwell( val tensor = Module(new radiance.core.TensorCoreBlackwell(
8, 8, half = true, tensorNumSourceIds)) outer.numWarps, outer.numLsuLanes, half = true, tensorNumSourceIds))
tensor.io.initiate.valid := false.B tensor.io.initiate.valid := false.B
tensor.io.initiate.bits := DontCare tensor.io.initiate.bits := DontCare
tensor.io.respA.valid := false.B tensor.io.respA.valid := false.B
@@ -1006,7 +1242,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B tensor.io.writeback.ready := false.B
tensor.io.tmemC.rdata := DontCare tensor.io.tmemC.aRready := false.B
tensor.io.tmemC.aRdata := DontCare
tensor.io.tmemC.cRready := false.B
tensor.io.tmemC.cRdata := DontCare
tensor.io.tmemC.cWready := false.B
dontTouch(tensor.io) dontTouch(tensor.io)
} else { } else {
if (outer.radianceParams.core.tensorCoreFP16) { if (outer.radianceParams.core.tensorCoreFP16) {
@@ -1070,18 +1310,6 @@ class VortexTLAdapter(
val outResp = chiselTypeOf(outTL._1.d) val outResp = chiselTypeOf(outTL._1.d)
}) })
val (bundle, edge) = outTL val (bundle, edge) = outTL
val sourceGen = Module(
new SourceGenerator(
newSourceWidth,
Some(inReqT.source),
ignoreInUse = false
)
)
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
sourceGen.io.reclaim.valid := io.outResp.fire
sourceGen.io.reclaim.bits := io.outResp.bits.source
sourceGen.io.meta := io.inReq.bits.source
// io passthrough logic // io passthrough logic
// TLBundleA <> VortexBundleA // TLBundleA <> VortexBundleA
io.outReq.valid := io.inReq.valid io.outReq.valid := io.inReq.valid
@@ -1090,29 +1318,70 @@ class VortexTLAdapter(
io.outReq.bits.size := io.inReq.bits.size io.outReq.bits.size := io.inReq.bits.size
io.outReq.bits.source := io.inReq.bits.source io.outReq.bits.source := io.inReq.bits.source
io.outReq.bits.address := io.inReq.bits.address io.outReq.bits.address := io.inReq.bits.address
// Get requires contiguous mask; only copy core's potentially-partial mask val outMaskWidth = io.outReq.bits.mask.getWidth
// when writing val inMaskWidth = io.inReq.bits.mask.getWidth
val outDataWidth = io.outReq.bits.data.getWidth
val inDataWidth = io.inReq.bits.data.getWidth
val byteOffset = io.inReq.bits.address(log2Ceil(outMaskWidth) - 1, 0)
val responseOffsetWidth = log2Ceil(outMaskWidth)
val responseSourceWidth = inReqT.source.getWidth
val sourceGen = Module(
new SourceGenerator(
newSourceWidth,
Some(UInt((responseSourceWidth + responseOffsetWidth).W)),
ignoreInUse = false
)
)
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
sourceGen.io.reclaim.valid := io.outResp.fire
sourceGen.io.reclaim.bits := io.outResp.bits.source
sourceGen.io.meta := Cat(byteOffset, io.inReq.bits.source)
val alignedMask = Wire(UInt(outMaskWidth.W))
val alignedData = Wire(UInt(outDataWidth.W))
if (outMaskWidth == inMaskWidth) {
alignedMask := io.inReq.bits.mask
} else {
val paddedMask = Wire(UInt(outMaskWidth.W))
paddedMask := io.inReq.bits.mask
alignedMask := (paddedMask << byteOffset)(outMaskWidth - 1, 0)
}
if (outDataWidth == inDataWidth) {
alignedData := io.inReq.bits.data
} else {
val paddedData = Wire(UInt(outDataWidth.W))
paddedData := io.inReq.bits.data
alignedData := (paddedData << (byteOffset << 3))(outDataWidth - 1, 0)
}
// PutFull requires the TL-canonical full mask for address+size; PutPartial
// can carry the core-provided byte mask.
io.outReq.bits.mask := Mux( io.outReq.bits.mask := Mux(
edge.hasData(io.outReq.bits), io.outReq.bits.opcode === TLMessages.PutPartialData,
io.inReq.bits.mask, alignedMask,
// generate TL-correct mask
edge.mask(io.inReq.bits.address, io.inReq.bits.size) edge.mask(io.inReq.bits.address, io.inReq.bits.size)
) )
io.outReq.bits.data := io.inReq.bits.data io.outReq.bits.data := alignedData
io.outReq.bits.corrupt := 0.U io.outReq.bits.corrupt := 0.U
io.inReq.ready := io.outReq.ready io.inReq.ready := io.outReq.ready
// VortexBundleD <> TLBundleD // VortexBundleD <> TLBundleD
io.inResp.valid := io.outResp.valid io.inResp.valid := io.outResp.valid
io.inResp.bits.opcode := io.outResp.bits.opcode io.inResp.bits.opcode := io.outResp.bits.opcode
io.inResp.bits.size := io.outResp.bits.size io.inResp.bits.size := io.outResp.bits.size
io.inResp.bits.source := io.outResp.bits.source val responseMeta = sourceGen.io.peek.asUInt
val responseSource = responseMeta(responseSourceWidth - 1, 0)
val responseByteOffset =
responseMeta(responseSourceWidth + responseOffsetWidth - 1, responseSourceWidth)
io.inResp.bits.source := responseSource
if (outDataWidth == inDataWidth) {
io.inResp.bits.data := io.outResp.bits.data io.inResp.bits.data := io.outResp.bits.data
} else {
io.inResp.bits.data := (io.outResp.bits.data >> (responseByteOffset << 3))(inDataWidth - 1, 0)
}
io.outResp.ready := io.inResp.ready io.outResp.ready := io.inResp.ready
// "man-in-the-middle" // "man-in-the-middle"
io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid
io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid
io.outReq.bits.source := sourceGen.io.id.bits io.outReq.bits.source := sourceGen.io.id.bits
// translate upstream response back to its old sourceId
io.inResp.bits.source := sourceGen.io.peek
} }

View File

@@ -90,28 +90,45 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1
val tcPortCount = 3 val tcPortCount = 3
val tc_a_valid = Output(UInt(tcPortCount.W)) val tcFlatPortCount = tcPortCount * numTensorCores
val tc_a_bits_write = Output(UInt(tcPortCount.W)) val tc_a_valid = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_address = Output(UInt((tcPortCount * 32).W)) val tc_a_bits_write = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W)) val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W)) val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W))
val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W)) val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_ready = Input(UInt(tcPortCount.W)) val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_valid = Input(UInt(tcPortCount.W)) val tc_a_ready = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W)) val tc_d_valid = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W)) val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_ready = Output(UInt(tcPortCount.W)) val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W))
val tc_d_ready = Output(UInt(tcFlatPortCount.W))
// Direct SRAM port for TMEM C (bypasses TileLink) // Direct SRAM ports for shared TMEM (bypasses TileLink)
val numLanes = tile.numLsuLanes val numLanes = tile.numLsuLanes
val tc_tmem_C_wen = Output(Bool()) val tc_tmem_A_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_ren = Output(Bool()) val tc_tmem_A_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt(9.W)) val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_raddr = Output(UInt(9.W)) val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W)) val tc_tmem_C_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_mask = Output(UInt((numLanes * 4).W)) val tc_tmem_C_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W)) val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wen = Output(UInt(numTensorCores.W))
val tc_tmem_C_wready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
val sc_tmem_ren = Output(UInt(1.W))
val sc_tmem_rready = Input(UInt(1.W))
val sc_tmem_raddr = Output(UInt(9.W))
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
val sc_tmem_wen = Output(UInt(1.W))
val sc_tmem_wready = Input(UInt(1.W))
val sc_tmem_waddr = Output(UInt(9.W))
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
// FIXME: hardcoded // FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -147,7 +164,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
"CORE_ID" -> tile.radianceParams.coreId, "CORE_ID" -> tile.radianceParams.coreId,
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0), "TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
"STARTUP_ADDR" -> tile.radianceParams.core.startupAddress, "STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
"NUM_THREADS" -> tile.numLsuLanes "NUM_THREADS" -> tile.numLsuLanes,
"NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1)
) )
) )
with HasBlackBoxResource with HasBlackBoxPath { with HasBlackBoxResource with HasBlackBoxPath {
@@ -211,6 +229,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh") addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")
@@ -341,6 +360,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_exp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")

View File

@@ -0,0 +1,110 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
/** Test-only wrapper that exposes `FP8E4M3.toFloat32` through module IO.
  *
  * `io.in` carries the 8-bit encoded operand and `io.out` carries the 32-bit
  * result produced by the decoder.
  */
class FP8E4M3DecodeHarness extends Module {
  val io = IO(new Bundle {
    val in = Input(UInt(8.W))
    val out = Output(UInt(32.W))
  })

  // Name the decoder result before driving the output port.
  private val decodedBits = FP8E4M3.toFloat32(io.in)
  io.out := decodedBits
}
/** Test-only wrapper that exposes `FP8E4M3MulToFloat32` through module IO.
  *
  * `io.a` and `io.b` are the two 8-bit operands; `io.out` carries the 32-bit
  * result returned by the multiplier helper.
  */
class FP8E4M3MulHarness extends Module {
  val io = IO(new Bundle {
    val a = Input(UInt(8.W))
    val b = Input(UInt(8.W))
    val out = Output(UInt(32.W))
  })

  // Name the product before driving the output port.
  private val productBits = FP8E4M3MulToFloat32(io.a, io.b)
  io.out := productBits
}
/** Chiseltest checks for the FP8 E4M3 decode, multiply and dot-product paths.
  *
  * Expected values are written as raw IEEE-754 single-precision bit patterns.
  */
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
  behavior of "FP8E4M3"

  /** Runs one vector through an 8-wide FP8 [[TensorDotProductUnit]].
    *
    * Every lane of `a` is driven with `aByte`, every lane of `b` with `bByte`,
    * and the accumulator input `c` with `cBits`. The input is held valid for a
    * single cycle. The output must not be valid before or just after that
    * cycle, and must be valid with `expectedBits` four cycles later.
    */
  private def checkDotProduct(aByte: Int, bByte: Int, cBits: Long, expectedBits: Long) =
    test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { dut =>
      dut.io.in.valid.poke(true.B)
      dut.io.stall.poke(false.B)
      (0 until 8).foreach { lane =>
        dut.io.in.bits.a(lane).poke(aByte.U(8.W))
        dut.io.in.bits.b(lane).poke(bByte.U(8.W))
      }
      dut.io.in.bits.c.poke(cBits.U(32.W))
      dut.io.out.valid.expect(false.B)

      // Accept the operands, then drop valid so only one vector is in flight.
      dut.clock.step()
      dut.io.in.valid.poke(false.B)
      dut.io.out.valid.expect(false.B)

      // The result is checked four further cycles after the input was accepted.
      (0 until 4).foreach(_ => dut.clock.step())
      dut.io.out.valid.expect(true.B)
      dut.io.out.bits.data.expect(expectedBits.U)
    }

  it should "decode representative E4M3 values to FP32 bits" in {
    val decodeCases = Seq(
      0x00 -> 0x00000000L, // +0.0
      0x80 -> 0x80000000L, // -0.0
      0x38 -> 0x3f800000L, // 1.0
      0x40 -> 0x40000000L, // 2.0
      0x30 -> 0x3f000000L, // 0.5
      0x3c -> 0x3fc00000L  // 1.5
    )
    test(new FP8E4M3DecodeHarness) { dut =>
      for ((encoded, expectedBits) <- decodeCases) {
        dut.io.in.poke(encoded.U)
        dut.clock.step()
        dut.io.out.expect(expectedBits.U)
      }
    }
  }

  it should "multiply E4M3 operands with FP8-width significands and return FP32 bits" in {
    val mulCases = Seq(
      (0x38, 0x40, 0x40000000L), // 1.0 * 2.0 = 2.0
      (0x30, 0x3c, 0x3f400000L), // 0.5 * 1.5 = 0.75
      (0xb8, 0x40, 0xc0000000L), // -1.0 * 2.0 = -2.0
      (0x00, 0x40, 0x00000000L), // 0.0 * 2.0 = 0.0
      (0x80, 0x40, 0x80000000L)  // -0.0 * 2.0 = -0.0
    )
    test(new FP8E4M3MulHarness) { dut =>
      for ((lhs, rhs, expectedBits) <- mulCases) {
        dut.io.a.poke(lhs.U)
        dut.io.b.poke(rhs.U)
        dut.clock.step()
        dut.io.out.expect(expectedBits.U)
      }
    }
  }

  it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
    // 8 * (1.0 * 2.0) + 1.0 = 17.0
    checkDotProduct(aByte = 0x38, bByte = 0x40, cBits = 0x3f800000L, expectedBits = 0x41880000L)
  }

  it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
    // 8 * (0.5 * 1.5) + 2.0 = 8.0
    checkDotProduct(aByte = 0x30, bByte = 0x3c, cBits = 0x40000000L, expectedBits = 0x41000000L)
  }
}

View File

@@ -26,7 +26,11 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.reqB.ready.poke(false.B) c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U) c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B) c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U) c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
} }
private def packWords(words: Seq[BigInt], width: Int): BigInt = { private def packWords(words: Seq[BigInt], width: Int): BigInt = {
@@ -39,13 +43,17 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0)) private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = { private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) { if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U) c.io.tmemC.aRdata.poke(tmem(addr).U)
} }
if (c.io.tmemC.wen.peek().litToBoolean) { if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue val addr = c.io.tmemC.cRaddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
} }
} }
@@ -60,22 +68,22 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A) val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
val bBase = BigInt(0x800) val bBase = BigInt(0x800)
val fp16One = BigInt(0x3c00) val fp8One = BigInt(0x38)
val fp32Zero = BigInt(0) val fp32Zero = BigInt(0)
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f // 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
val fp32SixtyFour = BigInt(0x42800000L) val fp32SixtyFour = BigInt(0x42800000L)
// Populate TMEM A at offset aBase (all 32 frags) // Populate TMEM A at offset aBase (all 32 frags)
val aFrag = packWords(Seq.fill(16)(fp16One), 16) val aFrag = packWords(Seq.fill(32)(fp8One), 8)
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32) val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
for (i <- 0 until 32) { for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag tmem(cBase / fragBytes + i) = cFrag
} }
// SMEM B with fp16 2.0 // SMEM B with packed FP8 E4M3 2.0
val fp16Two = BigInt(0x4000) val fp8Two = BigInt(0x40)
val bFrag = packWords(Seq.fill(16)(fp16Two), 16) val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag) val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
@@ -154,9 +162,9 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
// cpWrite: respA fires, tmemC written // cpWrite: respA fires, tmemC written
c.io.respA.valid.poke(true.B) c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U) c.io.respA.bits.data.poke(cpData.U)
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U) c.io.tmemC.cWdata.expect(cpData.U)
stepTmem(c, tmem) stepTmem(c, tmem)
c.clock.step() c.clock.step()
c.io.respA.valid.poke(false.B) c.io.respA.valid.poke(false.B)
@@ -171,10 +179,10 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
// ldReq: ren asserted, serve from tmem model // ldReq: ren asserted, serve from tmem model
c.io.tmemC.ren.expect(true.B) c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step() c.clock.step()
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step() c.clock.step()
// writeback should carry cpData // writeback should carry cpData
@@ -206,8 +214,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.clock.step() c.clock.step()
// stWrite: tmemC written // stWrite: tmemC written
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.wdata.expect(stData.U) c.io.tmemC.cWdata.expect(stData.U)
stepTmem(c, tmem) stepTmem(c, tmem)
c.clock.step() c.clock.step()
@@ -217,13 +225,15 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.bits.addressA.poke(tmemAddr.U) c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h20000000".U) c.io.initiate.bits.addressB.poke("h20000000".U)
c.io.reqA.ready.poke(true.B) c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step() c.clock.step()
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
// cbRead: ren asserted // cbRead: ren asserted
c.io.tmemC.ren.expect(true.B) c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step() c.clock.step()
// cbWrite: reqA write with stData // cbWrite: reqA write with stData
@@ -280,7 +290,7 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.clock.step() c.clock.step()
c.io.initiate.ready.expect(false.B) c.io.initiate.ready.expect(false.B)
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.clock.step() c.clock.step()
c.io.initiate.ready.expect(true.B) c.io.initiate.ready.expect(true.B)
} }
@@ -309,8 +319,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B) c.io.reqC.valid.expect(true.B)
c.clock.step() c.clock.step()
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U) c.io.tmemC.cWaddr.expect((warp0TmemA / fragBytes).U)
stepTmem(c, tmem) stepTmem(c, tmem)
c.clock.step() c.clock.step()
@@ -324,8 +334,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B) c.io.reqC.valid.expect(true.B)
c.clock.step() c.clock.step()
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U) c.io.tmemC.cWaddr.expect((warp3TmemA / fragBytes).U)
stepTmem(c, tmem) stepTmem(c, tmem)
c.clock.step() c.clock.step()

View File

@@ -0,0 +1,76 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
/** Test-only wrapper that feeds one FP8 E4M3 "microstep" into a
  * [[TensorDotProductUnit]].
  *
  * For each of the 8 dot-product positions, one A element is picked out of
  * the packed `operandA` word via `TensorCoreBlackwellFP8Packing.selectA`
  * (steered by `elemM`) and one B element out of `operandB` via
  * `TensorCoreBlackwellFP8Packing.selectB` (steered by `elemN`). The scalar
  * `c` is forwarded unchanged as the unit's accumulator input.
  *
  * NOTE(review): the exact byte layout that `selectA`/`selectB` assume is
  * defined in `TensorCoreBlackwellFP8Packing`, which is not part of this
  * file — confirm there before relying on a particular packing.
  */
class TensorCoreBlackwellFP8MicrostepHarness extends Module {
  val io = IO(new Bundle {
    // Qualifies the operands presented on this cycle.
    val valid = Input(Bool())
    // Packed A operand (512 bits).
    val operandA = Input(UInt(512.W))
    // Packed B operand (256 bits).
    val operandB = Input(UInt(256.W))
    // 32-bit accumulator input, passed straight to the dot-product unit.
    val c = Input(UInt(32.W))
    // Selector forwarded to selectA.
    val elemM = Input(UInt(2.W))
    // Selector forwarded to selectB.
    val elemN = Input(UInt(1.W))
    // Mirrors the dot-product unit's output valid.
    val outValid = Output(Bool())
    // Mirrors the dot-product unit's 32-bit result.
    val out = Output(UInt(32.W))
  })

  val dpu = Module(
    new TensorDotProductUnit(dim = 8, half = false, inputType = TensorInputType.FP8E4M3)
  )

  // The harness never stalls the unit.
  dpu.io.stall := false.B

  // Input side: valid flag, accumulator, then the per-position A/B elements.
  dpu.io.in.valid := io.valid
  dpu.io.in.bits.c := io.c
  (0 until 8).foreach { k =>
    dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
    dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
  }

  // Output side: expose the unit's result unchanged.
  io.out := dpu.io.out.bits.data
  io.outValid := dpu.io.out.valid
}
/** Checks the FP8 operand-selection + dot-product path exposed by
  * [[TensorCoreBlackwellFP8MicrostepHarness]] with a single hand-computed case.
  */
class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
  behavior of "TensorCoreBlackwell FP8 operand microstep"

  /** Packs `words` into one integer, element 0 in the least-significant
    * `width` bits. Each word is truncated to `width` bits before placement.
    */
  private def packWords(words: Seq[BigInt], width: Int): BigInt = {
    val laneMask = (BigInt(1) << width) - 1
    words.zipWithIndex
      .map { case (word, idx) => (word & laneMask) << (idx * width) }
      .foldLeft(BigInt(0))(_ | _)
  }

  it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
    test(new TensorCoreBlackwellFP8MicrostepHarness) { dut =>
      // FP8 E4M3 encodings of 1.0 and 2.0.
      val fp8One = BigInt(0x38)
      val fp8Two = BigInt(0x40)
      // IEEE-754 single-precision encodings of 1.0 and 17.0.
      val fp32One = BigInt(0x3f800000L)
      val fp32Seventeen = BigInt(0x41880000L)

      // 64 A bytes fill the 512-bit port; 32 B bytes fill the 256-bit port.
      // NOTE(review): every A byte is identical and every B byte is identical,
      // so this case cannot tell apart *which* elements elemM/elemN select.
      val packedA = packWords(Seq.fill(64)(fp8One), 8)
      val packedB = packWords(Seq.fill(32)(fp8Two), 8)

      // Present one microstep: 8 x (1.0 * 2.0) + 1.0 is expected to give 17.0.
      dut.io.valid.poke(true.B)
      dut.io.operandA.poke(packedA.U)
      dut.io.operandB.poke(packedB.U)
      dut.io.c.poke(fp32One.U)
      dut.io.elemM.poke(0.U)
      dut.io.elemN.poke(0.U)
      dut.io.outValid.expect(false.B)

      // Drop valid after one cycle; no result is expected yet.
      dut.clock.step()
      dut.io.valid.poke(false.B)
      dut.io.outValid.expect(false.B)

      // The test expects the result four further cycles later.
      dut.clock.step(4)
      dut.io.outValid.expect(true.B)
      dut.io.out.expect(fp32Seventeen.U)
    }
  }
}

View File

@@ -25,7 +25,11 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.reqB.ready.poke(false.B) c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U) c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B) c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U) c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
} }
private def packWords(words: Seq[BigInt], width: Int): BigInt = { private def packWords(words: Seq[BigInt], width: Int): BigInt = {
@@ -38,15 +42,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
// Simple TMEM model: address → 256-bit row // Simple TMEM model: address → 256-bit row
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0)) private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
// Drive tmemC read response from model, handle write // Drive TMEM read responses from model, handle C-port writes.
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = { private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) { if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U) c.io.tmemC.aRdata.poke(tmem(addr).U)
} }
if (c.io.tmemC.wen.peek().litToBoolean) { if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue val addr = c.io.tmemC.cRaddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
} }
} }
@@ -65,19 +73,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.initiate.bits.rd.poke(3.U) c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U) c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B) c.io.writeback.ready.poke(true.B)
c.io.tmemC.rdata.poke(testData.U) c.io.tmemC.cRdata.poke(testData.U)
c.clock.step() c.clock.step()
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B) c.io.initiate.ready.expect(false.B)
// ldReq: tmemC.ren asserted; rdata must be valid before next step // ldReq: tmemC.ren asserted; rdata must be valid before next step
c.io.tmemC.ren.expect(true.B) c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U) c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.rdata.poke(testData.U) c.io.tmemC.cRdata.poke(testData.U)
c.clock.step() c.clock.step()
// waitWb: wbValid gets set this cycle, step to let it register // waitWb: wbValid gets set this cycle, step to let it register
c.io.tmemC.rdata.poke(testData.U) c.io.tmemC.cRdata.poke(testData.U)
c.clock.step() c.clock.step()
// idle: writeback.valid now true // idle: writeback.valid now true
@@ -91,6 +99,38 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
} }
} }
// Exercises a tcgen05Ld on a TensorCoreBlackwell built with 4 lanes instead of
// the suite-wide numLanes: the C read port must be addressed in 16-byte
// fragment units and the four 32-bit lanes must come back on writeback.
it should "tcgen05_ld: support 4-lane 16-byte fragments" in {
val lanes = 4
test(new TensorCoreBlackwell(numWarps, lanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
// 4 lanes x 32 bits = 16 bytes per fragment.
val fragBytes = 16
val tmemAddr = BigInt(0x40)
// One distinct 32-bit word per lane (0x2000 + lane index), lane 0 in the low bits.
val testData = packWords(Seq.tabulate(lanes)(i => BigInt(0x2000 + i)), 32)
// Issue the load for warp 0 with destination register 3.
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// One cycle after initiate: the C read port is expected to be enabled with
// the byte address converted to a fragment index (0x40 / 16 = 4).
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
// Hold the read data stable across the next two cycles.
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
// Writeback must target rd = 3 and carry each lane's word unchanged.
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(3.U)
for (i <- 0 until lanes) {
c.io.writeback.bits.data(i).expect((0x2000 + i).U)
}
}
}
it should "tcgen05_st: write from respC to TMEM" in { it should "tcgen05_st: write from respC to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c) idleIO(c)
@@ -114,9 +154,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.clock.step() c.clock.step()
// stWrite: tmemC.wen asserted with storeData // stWrite: tmemC.wen asserted with storeData
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(storeData.U) c.io.tmemC.cWdata.expect(storeData.U)
c.clock.step() c.clock.step()
c.io.initiate.ready.expect(true.B) c.io.initiate.ready.expect(true.B)
} }
@@ -151,9 +191,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.respA.bits.data.poke(cpData.U) c.io.respA.bits.data.poke(cpData.U)
// tmemC write happens combinatorially when respA fires // tmemC write happens combinatorially when respA fires
c.io.tmemC.wen.expect(true.B) c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U) c.io.tmemC.cWdata.expect(cpData.U)
c.clock.step() c.clock.step()
c.io.initiate.ready.expect(true.B) c.io.initiate.ready.expect(true.B)
} }
@@ -172,14 +212,16 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.initiate.bits.addressA.poke(tmemAddr.U) c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U) c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B) c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(cbData.U) c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step() c.clock.step()
c.io.initiate.valid.poke(false.B) c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B) c.io.initiate.ready.expect(false.B)
// cbRead: tmemC.ren asserted // cbRead: tmemC.ren asserted
c.io.tmemC.ren.expect(true.B) c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U) c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step() c.clock.step()
c.io.initiate.ready.expect(false.B) c.io.initiate.ready.expect(false.B)
@@ -207,13 +249,13 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
val bBase = BigInt(0x800) val bBase = BigInt(0x800)
val cBase = BigInt(0x1000) val cBase = BigInt(0x1000)
// A: all fp16 1.0 (0x3c00), 16 halves per frag // A/B: packed FP8 E4M3 bytes, 32 elements per 256-bit frag
val fp16One = BigInt(0x3c00) val fp8One = BigInt(0x38)
val fp16Two = BigInt(0x4000) val fp8Two = BigInt(0x40)
val fp32One = BigInt(0x3f800000) val fp32One = BigInt(0x3f800000)
val fp32SixtyFive = BigInt(0x42820000) val fp32SixtyFive = BigInt(0x42820000)
val aFrag = packWords(Seq.fill(16)(fp16One), 16) val aFrag = packWords(Seq.fill(32)(fp8One), 8)
val bFrag = packWords(Seq.fill(16)(fp16Two), 16) val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32) val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32) val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)