13 Commits

Author SHA1 Message Date
Zhongdi LUO
ca4c48251d feat: use fp8 multiply for blackwell bwgmma 2026-07-03 08:40:10 +00:00
Zhongdi LUO
12f5c6d92d feat: switch blackwell bwgmma inputs to fp8 2026-07-02 08:57:59 +00:00
Zhongdi LUO
68a7f66046 feat: add fp8 tensor dot product path 2026-07-02 08:06:42 +00:00
Zhongdi LUO
2afb96bb14 feat: add fp8 e4m3 decode support 2026-07-02 07:59:01 +00:00
Zhongdi LUO
007350fd5a feat: include vortex fexp RTL 2026-07-02 07:25:32 +00:00
Zhongdi LUO
47d6585896 Wire scalar TMEM through Radiance tile 2026-06-24 06:25:10 +00:00
Zhongdi LUO
f88085331e Save pre-TMEM-bank Radiance changes 2026-06-21 08:20:21 +00:00
Zhongdi LUO
1e78574113 Fix Blackwell SMEM fragment alignment 2026-05-27 08:43:36 +00:00
Zhongdi LUO
c6c30ec0dc Add 4-lane Blackwell tensor core support 2026-05-27 05:54:39 +00:00
126523c5d2 Track vortex masked barrier fix 2026-05-27 09:08:03 +08:00
87a4bbc757 Integrate WU architecture in Radiance 2026-05-25 19:25:59 +08:00
5112f3665a Add Blackwell tensor core implementation and tests
- Implement TensorCoreBlackwell.scala with BWGMMA and TCGEN05 instructions
- Update TensorDPU, RadianceTile, and VortexCore for Blackwell integration
- Add TensorCoreBlackwellExtendedTest for comprehensive testing
- Update vortex submodule with Blackwell ISA support
2026-05-06 14:51:11 +08:00
136cf70a58 Add Blackwell tensor core baseline plumbing 2026-04-25 10:15:31 +08:00
14 changed files with 2077 additions and 80 deletions

5
.gitmodules vendored
View File

@@ -1,6 +1,3 @@
[submodule "src/main/resources/vsrc/vortex"] [submodule "src/main/resources/vsrc/vortex"]
path = src/main/resources/vsrc/vortex path = src/main/resources/vsrc/vortex
url = https://github.com/hansungk/vortex.git url = https://git.nudt.space/wu-arch/vortex.git
[submodule "cyclotron"]
path = cyclotron
url = https://github.com/hansungk/cyclotron.git

Submodule cyclotron deleted from ca6933c4ec

View File

@@ -23,6 +23,9 @@ endif
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
endif endif
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL
endif
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
endif endif

View File

@@ -0,0 +1,478 @@
// See LICENSE.SiFive for license details.
// See LICENSE.Berkeley for license details.
package radiance.core
import chisel3._
import chisel3.util._
object TensorCoreBlackwellFP8Packing {
def fp8Byte(x: UInt, idx: Int): UInt = {
x((idx + 1) * 8 - 1, idx * 8)
}
def selectA(operandA: UInt, k: Int, elemM: UInt, numLanes: Int): UInt = {
if (numLanes == 4) {
Mux(elemM.asBool, fp8Byte(operandA, 8 + k), fp8Byte(operandA, k))
} else {
MuxLookup(elemM, fp8Byte(operandA, k))(Seq(
0.U -> fp8Byte(operandA, k),
1.U -> fp8Byte(operandA, 8 + k),
2.U -> fp8Byte(operandA, 16 + k),
3.U -> fp8Byte(operandA, 24 + k)
))
}
}
def selectB(operandB: UInt, k: Int, elemN: UInt): UInt = {
Mux(elemN.asBool, fp8Byte(operandB, 8 + k), fp8Byte(operandB, k))
}
}
class TensorCoreBlackwell(
val numWarps: Int,
val numLanes: Int,
val half: Boolean,
val numSourceIds: Int = 16,
val numFPRegs: Int = 32
) extends Module {
require(half, "Blackwell MMA compatibility flag must remain true; BWGMMA inputs are FP8 E4M3 on this branch")
require(numLanes == 4 || numLanes == 8,
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
val numWarpBits = log2Ceil(numWarps)
val sourceWidth = log2Ceil(numSourceIds)
val laneWidth = 4 * 8
val memWidth = numLanes * laneWidth
val numFPRegBits = log2Ceil(numFPRegs)
val addressWidth = 32
val maskWidth = memWidth / 8
val fragOffsetBits = log2Ceil(memWidth / 8)
val numSets = 4
val numBGroups = 4
val numSubsteps = 2
val mElemsPerFrag = if (numLanes == 4) 2 else 4
val numMGroups = 16 / mElemsPerFrag
val numAFragsPerMGroup = 2
val numAFragsPerSet = numMGroups * numAFragsPerMGroup
val numBFragsPerSubstep = if (numLanes == 4) 2 else 1
val numBFragsPerGroup = numSubsteps * numBFragsPerSubstep
val numBFragsPerSet = numBGroups * numBFragsPerGroup
val numCFrags = numBGroups * numMGroups * numSubsteps
object Ops {
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
}
class TensorMemReq(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val rw = Bool()
val byteen = UInt((dataWidth / 8).W)
val source = UInt(sourceWidth.W)
val address = UInt(addressWidth.W)
val data = UInt(dataWidth.W)
}
class TensorMemResp(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val data = UInt(dataWidth.W)
}
// Direct SRAM port for TMEM (no TileLink overhead)
class TmemSramPort extends Bundle {
val aRen = Output(Bool())
val aRready = Input(Bool())
val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val aRdata = Input(UInt(memWidth.W))
val cRen = Output(Bool())
val cRready = Input(Bool())
val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cRdata = Input(UInt(memWidth.W))
val cWen = Output(Bool())
val cWready = Input(Bool())
val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cWdata = Output(UInt(memWidth.W))
val cMask = Output(UInt(maskWidth.W))
}
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
val op = UInt(3.W)
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val addressA = UInt(addressWidth.W)
val addressB = UInt(addressWidth.W)
val addressC = UInt(addressWidth.W)
}))
val writeback = Decoupled(new Bundle {
val last = Bool()
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt(laneWidth.W))
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respC = Input(UInt(memWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
val tmemC = new TmemSramPort // direct SRAM for C matrix (replaces reqCmem/respCmem)
})
object State extends ChiselEnum {
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
cbRead, cbCapture, cbWrite = Value
}
val state = RegInit(State.idle)
val opReg = RegInit(0.U(3.W))
val widReg = RegInit(0.U(numWarpBits.W))
val rdReg = RegInit(0.U(numFPRegBits.W))
val addrAReg = RegInit(0.U(addressWidth.W))
val addrBReg = RegInit(0.U(addressWidth.W))
val addrCReg = RegInit(0.U(addressWidth.W))
val sourceCounter = RegInit(0.U(sourceWidth.W))
val setReg = RegInit(0.U(log2Ceil(numSets).W))
val aIndexReg = RegInit(0.U(log2Ceil(numAFragsPerSet).W))
val bGroupReg = RegInit(0.U(log2Ceil(numBGroups).W))
val bIndexReg = RegInit(0.U(log2Ceil(numBFragsPerGroup).W))
val mGroupReg = RegInit(0.U(log2Ceil(numMGroups).W))
val substepReg = RegInit(0.U(1.W))
val elemReg = RegInit(0.U(log2Ceil(numLanes).W))
val waitCounter = RegInit(0.U(3.W))
val aBuf = Reg(Vec(numAFragsPerSet, UInt(memWidth.W)))
val bBuf = Reg(Vec(numBFragsPerGroup, UInt(memWidth.W)))
val cDataReg = Reg(UInt(memWidth.W))
val mmaDataReg = Reg(Vec(numLanes, UInt(laneWidth.W)))
private def bumpSource(): Unit = {
sourceCounter := sourceCounter + 1.U
}
private def byteAddress(base: UInt, fragIndex: UInt): UInt = {
base + (fragIndex << fragOffsetBits).asUInt
}
val aFragIndex = (setReg * numAFragsPerSet.U) + aIndexReg
val bFragIndex =
(setReg * numBFragsPerSet.U) + (bGroupReg * numBFragsPerGroup.U) + bIndexReg
val cFragIndex =
(((bGroupReg * numMGroups.U) + mGroupReg) * numSubsteps.U) + substepReg
val aReqAddress = byteAddress(addrAReg, aFragIndex)
val bReqAddress = byteAddress(addrBReg, bFragIndex)
val cReqAddress = byteAddress(addrCReg, cFragIndex)
val tmemABase = (addrAReg >> fragOffsetBits.U).asUInt
val tmemCBase = (addrCReg >> fragOffsetBits.U).asUInt
val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
reqA.valid := false.B
reqA.bits := 0.U.asTypeOf(reqA.bits)
reqB.valid := false.B
reqB.bits := 0.U.asTypeOf(reqB.bits)
io.reqA <> reqA
io.reqB <> reqB
io.tmemC.aRen := false.B
io.tmemC.aRaddr := 0.U
io.tmemC.cRen := false.B
io.tmemC.cRaddr := 0.U
io.tmemC.cWen := false.B
io.tmemC.cWaddr := 0.U
io.tmemC.cWdata := 0.U
io.tmemC.cMask := 0.U
val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
io.writeback.valid := wbValid
io.writeback.bits.last := true.B
io.writeback.bits.wid := widReg
io.writeback.bits.rd := rdReg
io.writeback.bits.data := wbData
io.reqC.valid := false.B
io.reqC.bits := rdReg
// drain stale write-ack from TMEM so TLRAM doesn't stall on r_full
io.respA.ready := state === State.idle
io.respB.ready := false.B
io.initiate.ready := state === State.idle && !wbValid
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
val operandB =
if (numLanes == 4) {
Cat(bBuf((substepReg << 1) + 1.U), bBuf(substepReg << 1))
} else {
bBuf(substepReg)
}
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val dpuInValid = WireDefault(false.B)
val dpu = Module(new TensorDotProductUnit(
dim = 8,
half = false,
inputType = TensorInputType.FP8E4M3
))
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
dpu.io.in.valid := dpuInValid
for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(operandA, k, elemM, numLanes)
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(operandB, k, elemN)
}
dpu.io.in.bits.c := cWords(elemReg)
dpu.io.stall := false.B
val dpuValid = dpu.io.out.valid
when(io.writeback.fire) {
wbValid := false.B
}
when(io.initiate.fire) {
opReg := io.initiate.bits.op
widReg := io.initiate.bits.wid
rdReg := io.initiate.bits.rd
addrAReg := io.initiate.bits.addressA
addrBReg := io.initiate.bits.addressB
addrCReg := io.initiate.bits.addressC
setReg := 0.U
aIndexReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
mGroupReg := 0.U
substepReg := 0.U
elemReg := 0.U
switch(io.initiate.bits.op) {
is(Ops.bwgmma) { state := State.bwLoadAReq }
is(Ops.tcgen05Cp) { state := State.cpRead }
is(Ops.tcgen05Ld) { state := State.ldReq }
is(Ops.tcgen05St) { state := State.stReq }
is(Ops.bwgmmaWait) { state := State.idle }
is(Ops.tcgen05CpWait) { state := State.idle }
is(Ops.tcgen05Cb) { state := State.cbRead }
}
}
when(state === State.bwLoadAReq) {
io.tmemC.aRen := true.B
io.tmemC.aRaddr := tmemABase + aFragIndex
when(io.tmemC.aRready) {
state := State.bwLoadAResp
}
}
when(state === State.bwLoadAResp) {
aBuf(aIndexReg) := io.tmemC.aRdata
when(aIndexReg === (numAFragsPerSet - 1).U) {
bGroupReg := 0.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.otherwise {
aIndexReg := aIndexReg + 1.U
state := State.bwLoadAReq
}
}
when(state === State.bwLoadBReq) {
reqB.valid := true.B
reqB.bits.rw := false.B
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqB.bits.address := bReqAddress
reqB.bits.source := sourceCounter
when(reqB.fire) {
bumpSource()
state := State.bwLoadBResp
}
}
when(state === State.bwLoadBResp) {
io.respB.ready := true.B
when(io.respB.fire) {
bBuf(bIndexReg) := io.respB.bits.data
when(bIndexReg === (numBFragsPerGroup - 1).U) {
mGroupReg := 0.U
substepReg := 0.U
state := State.bwReadCReq
}.otherwise {
bIndexReg := bIndexReg + 1.U
state := State.bwLoadBReq
}
}
}
when(state === State.bwReadCReq) {
io.tmemC.cRen := true.B
io.tmemC.cRaddr := tmemCBase + cFragIndex
when(io.tmemC.cRready) {
state := State.bwReadCResp
}
}
when(state === State.bwReadCResp) {
cDataReg := io.tmemC.cRdata
elemReg := 0.U
state := State.bwCompute
}
when(state === State.bwCompute) {
dpuInValid := true.B
state := State.bwDpuResp
}
when(state === State.bwDpuResp) {
when(dpuValid) {
mmaDataReg(elemReg) := dpu.io.out.bits.data
when(elemReg === (numLanes - 1).U) {
state := State.bwWriteCReq
}.otherwise {
elemReg := elemReg + 1.U
state := State.bwCompute
}
}
}
when(state === State.bwWriteCReq) {
io.tmemC.cWen := true.B
io.tmemC.cWaddr := tmemCBase + cFragIndex
io.tmemC.cWdata := mmaDataReg.asUInt
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
when(substepReg === 0.U) {
substepReg := 1.U
state := State.bwReadCReq
}.elsewhen(mGroupReg =/= (numMGroups - 1).U) {
substepReg := 0.U
mGroupReg := mGroupReg + 1.U
state := State.bwReadCReq
}.elsewhen(bGroupReg =/= (numBGroups - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := bGroupReg + 1.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.elsewhen(setReg =/= (numSets - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
setReg := setReg + 1.U
aIndexReg := 0.U
state := State.bwLoadAReq
}.otherwise {
waitCounter := 7.U
state := State.bwWriteCWait
}
}
}
when(state === State.bwWriteCWait) {
when(waitCounter === 0.U) {
state := State.bwDone
}.otherwise {
waitCounter := waitCounter - 1.U
}
}
when(state === State.bwDone) {
wbData := mmaDataReg
wbValid := true.B
state := State.idle
}
when(state === State.cpRead) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter
when(reqA.fire) {
bumpSource()
state := State.cpWrite
}
}
when(state === State.cpWrite) {
io.respA.ready := io.tmemC.cWready
io.tmemC.cWen := io.respA.valid
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respA.bits.data
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.respA.fire) {
state := State.idle
}
}
when(state === State.ldReq) {
io.tmemC.cRen := true.B
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.waitWb
}
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B
state := State.idle
}
when(state === State.stReq) {
io.reqC.valid := true.B
state := State.stWrite
}
when(state === State.stWrite) {
io.tmemC.cWen := true.B
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respC
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
state := State.idle
}
}
when(state === State.cbRead) {
io.tmemC.cRen := true.B
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.cbCapture
}
}
when(state === State.cbCapture) {
cDataReg := io.tmemC.cRdata
state := State.cbWrite
}
when(state === State.cbWrite) {
reqA.valid := true.B
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter
reqA.bits.data := cDataReg
when(reqA.fire) {
bumpSource()
state := State.waitWb
}
}
when(state === State.waitWb && opReg === Ops.tcgen05Cb) {
io.respA.ready := true.B
when(io.respA.fire) {
state := State.idle
}
}
}

View File

@@ -7,17 +7,116 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import freechips.rocketchip.tile import freechips.rocketchip.tile
object FP8E4M3 {
private val Bias = 7
private def decodeToFloat(bits: Int): Float = {
val sign = (bits >> 7) & 0x1
val exp = (bits >> 3) & 0xf
val frac = bits & 0x7
val magnitude =
if (exp == 0) {
if (frac == 0) 0.0
else (frac.toDouble / 8.0) * Math.pow(2.0, 1 - Bias)
} else {
(1.0 + frac.toDouble / 8.0) * Math.pow(2.0, exp - Bias)
}
val value = if (sign == 1) -magnitude else magnitude
value.toFloat
}
private def fp32Bits(bits: Int): BigInt = {
BigInt(java.lang.Float.floatToRawIntBits(decodeToFloat(bits)).toLong & 0xffffffffL)
}
def toFloat32(x: UInt): UInt = {
MuxLookup(x, 0.U(32.W))((0 until 256).map { bits =>
bits.U(8.W) -> fp32Bits(bits).U(32.W)
})
}
}
object FP8E4M3MulToFloat32 {
private val Bias = 7
def apply(a: UInt, b: UInt): UInt = {
val sign = a(7) ^ b(7)
val expA = a(6, 3)
val expB = b(6, 3)
val fracA = a(2, 0)
val fracB = b(2, 0)
val zeroA = expA === 0.U && fracA === 0.U
val zeroB = expB === 0.U && fracB === 0.U
val isZero = zeroA || zeroB
val sigA = Mux(expA === 0.U, Cat(0.U(1.W), fracA), Cat(1.U(1.W), fracA))
val sigB = Mux(expB === 0.U, Cat(0.U(1.W), fracB), Cat(1.U(1.W), fracB))
val prodSig = sigA * sigB
val scaleA = Wire(SInt(6.W))
val scaleB = Wire(SInt(6.W))
scaleA := Mux(expA === 0.U, -9.S(6.W), expA.zext - (Bias + 3).S(6.W))
scaleB := Mux(expB === 0.U, -9.S(6.W), expB.zext - (Bias + 3).S(6.W))
val msb = Wire(UInt(3.W))
when(prodSig(7)) {
msb := 7.U
}.elsewhen(prodSig(6)) {
msb := 6.U
}.elsewhen(prodSig(5)) {
msb := 5.U
}.elsewhen(prodSig(4)) {
msb := 4.U
}.elsewhen(prodSig(3)) {
msb := 3.U
}.elsewhen(prodSig(2)) {
msb := 2.U
}.elsewhen(prodSig(1)) {
msb := 1.U
}.otherwise {
msb := 0.U
}
val normalized = (prodSig << (7.U - msb))(7, 0)
val exponent = (scaleA + scaleB + msb.zext + 127.S(10.W)).asUInt(7, 0)
val fraction = Cat(normalized(6, 0), 0.U(16.W))
Mux(isZero, Cat(sign, 0.U(31.W)), Cat(sign, exponent, fraction))
}
}
object TensorInputType extends Enumeration {
val FP16, FP32, FP8E4M3 = Value
def fromHalf(half: Boolean): Value = {
if (half) FP16 else FP32
}
}
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
// `half`: if True, generate fp16 MACs; if False fp32. // `half`: if True, generate fp16 MACs; if False fp32.
class TensorDotProductUnit( class TensorDotProductUnit(
val dim: Int = 4, val dim: Int,
val half: Boolean val half: Boolean,
val inputType: TensorInputType.Value
) extends Module with tile.HasFPUParameters { ) extends Module with tile.HasFPUParameters {
val tIn = if (half) tile.FType.H else tile.FType.S def this(dim: Int = 4, half: Boolean) = {
this(dim, half, TensorInputType.fromHalf(half))
}
val tIn = inputType match {
case TensorInputType.FP16 => tile.FType.H
case TensorInputType.FP32 => tile.FType.S
case TensorInputType.FP8E4M3 => tile.FType.S
}
// output datatype fixed to single-precision // output datatype fixed to single-precision
val tOut = tile.FType.S val tOut = tile.FType.S
val inFLen = tIn.ieeeWidth val inFLen = inputType match {
case TensorInputType.FP8E4M3 => 8
case _ => tIn.ieeeWidth
}
val outFLen = tOut.ieeeWidth val outFLen = tOut.ieeeWidth
val fLen = outFLen // needed for HasFPUParameters val fLen = outFLen // needed for HasFPUParameters
val minFLen = 16 // fp16 val minFLen = 16 // fp16
@@ -40,9 +139,27 @@ class TensorDotProductUnit(
// [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE] // [IEEE] -> recode() -> unbox() -> [Hardfloat] -> box() -> ieee() -> [IEEE]
// make sure recoding/uncoding happens only at the edge, not at every // make sure recoding/uncoding happens only at the edge, not at every
// pipeline stage inside the dpu // pipeline stage inside the dpu
val tag = if (half) H else S val tag = inputType match {
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn))) case TensorInputType.FP16 => H
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn))) case TensorInputType.FP32 => S
case TensorInputType.FP8E4M3 => S
}
if (inputType == TensorInputType.FP8E4M3) {
val dpu = Module(new DotProductPipeFP8E4M3(dim))
dpu.io.in.valid := io.in.valid
dpu.io.in.bits.a := io.in.bits.a
dpu.io.in.bits.b := io.in.bits.b
dpu.io.in.bits.c := io.in.bits.c
dpu.io.stall := io.stall
io.out.valid := dpu.io.out.valid
io.out.bits.data := dpu.io.out.bits.data
} else {
def recodeInput(x: Bits): UInt = {
unbox(recode(x.asUInt, tag), tag, Some(tIn))
}
val in1 = io.in.bits.a.map(recodeInput)
val in2 = io.in.bits.b.map(recodeInput)
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val dpu = Module(new DotProductPipe(dim, tIn, tOut)) val dpu = Module(new DotProductPipe(dim, tIn, tOut))
@@ -54,6 +171,7 @@ class TensorDotProductUnit(
io.out.valid := dpu.io.out.valid io.out.valid := dpu.io.out.valid
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
}
} }
// An implementation of chisel3.util.Pipe that supports stalls. // An implementation of chisel3.util.Pipe that supports stalls.
@@ -201,8 +319,10 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// pipeline and connect outputs to the next stage // pipeline and connect outputs to the next stage
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts)) outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
outC := StallingPipe(io.stall, inputs.valid, inC.bits) outC := StallingPipe(io.stall, inputs.valid, inC.bits)
assert(inputs.valid === inC.valid, when (inputs.valid =/= inC.valid) {
"adder inputs valid and C pipe valid went out-of-sync") printf("WARN: DotProductPipe input/C valid mismatch: inputs=%d c=%d\n",
inputs.valid, inC.valid)
}
(outputs, outC) (outputs, outC)
} }
@@ -234,6 +354,89 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
io.out.bits.data := accStageOut.bits io.out.bits.data := accStageOut.bits
} }
class DotProductPipeFP8E4M3(dim: Int) extends Module with tile.HasFPUParameters {
val tOut = tile.FType.S
val outExpWidth = tOut.exp
val outSigWidth = tOut.sig
val recOutFLen = outExpWidth + outSigWidth + 1
val fLen = tOut.ieeeWidth
val minFLen = 16
def xLen = 32
val io = IO(new Bundle {
val in = Flipped(Valid(new Bundle {
val a = Vec(dim, Bits(8.W))
val b = Vec(dim, Bits(8.W))
val c = Bits(32.W)
}))
val stall = Input(Bool())
val out = Valid(new Bundle {
val data = Bits(32.W)
})
})
val productRecoded = io.in.bits.a.zip(io.in.bits.b).map { case (a, b) =>
unbox(recode(FP8E4M3MulToFloat32(a.asUInt, b.asUInt), S), S, Some(tOut))
}
val inC = unbox(recode(io.in.bits.c, S), S, Some(tOut))
val productStageOut = StallingPipe(io.stall, io.in.valid, VecInit(productRecoded))
val productStageC = StallingPipe(io.stall, io.in.valid, inC)
val log2Dim = log2Ceil(dim)
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
val interim = (log2Dim to 0 by -1).map { i =>
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
}
val interimC = (log2Dim to 0 by -1).map(_ => Wire(Valid(Bits(recOutFLen.W))))
interim(0) := productStageOut
interimC(0) := productStageC
val (addStageOut, addStageC) = (interim zip interimC).reduce {
(inputsAndC, outputsAndC) => {
val (inputs, inC) = inputsAndC
val (outputs, outC) = outputsAndC
require(inputs.bits.length == 2 * outputs.bits.length)
val thisDim = inputs.bits.length
val adders = Seq.fill(thisDim / 2)(
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
)
val addOuts = adders.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U
a.io.a := inputs.bits(2 * i + 0)
a.io.b := inputs.bits(2 * i + 1)
a.io.roundingMode := hardfloat.consts.round_near_even
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
a.io.out
}
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
when(inputs.valid =/= inC.valid) {
printf("WARN: DotProductPipeFP8E4M3 input/C valid mismatch: inputs=%d c=%d\n",
inputs.valid, inC.valid)
}
(outputs, outC)
}
}
require(addStageOut.bits.length == 1)
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U
acc.io.a := addStageOut.bits(0)
acc.io.b := addStageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
io.out.valid := accStageOut.valid
io.out.bits.data := ieee(box(accStageOut.bits, S))
}
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
require(latency <= 2) require(latency <= 2)

View File

@@ -50,6 +50,9 @@ class WithRadianceCores(
crossing: RocketCrossingParams, crossing: RocketCrossingParams,
tensorCoreFP16: Boolean, tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean, tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean,
numTensorWarps: Int,
startupAddress: BigInt,
useVxCache: Boolean useVxCache: Boolean
) extends Config((site, _, up) => { ) extends Config((site, _, up) => {
case TilesLocated(`location`) => { case TilesLocated(`location`) => {
@@ -59,7 +62,10 @@ class WithRadianceCores(
val vortex = RadianceTileParams( val vortex = RadianceTileParams(
core = VortexCoreParams( core = VortexCoreParams(
tensorCoreFP16 = tensorCoreFP16, tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell,
numTensorWarps = numTensorWarps,
startupAddress = startupAddress
), ),
btb = None, btb = None,
useVxCache = useVxCache, useVxCache = useVxCache,
@@ -96,6 +102,9 @@ class WithRadianceCores(
// constructor override that omits `crossing` // constructor override that omits `crossing`
def this(n: Int, location: HierarchicalLocation = InSubsystem, def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false, tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false,
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16),
useVxCache: Boolean = false) useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams( = this(n, location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location), master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -104,9 +113,23 @@ class WithRadianceCores(
case InSubsystem => CBUS case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId) case InCluster(clusterId) => CCBUS(clusterId)
} }
), tensorCoreFP16, tensorCoreDecoupled, useVxCache) ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache)
} }
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {
case TilesLocated(`location`) =>
up(TilesLocated(`location`)).map {
case r: RadianceTileAttachParams =>
r.copy(tileParams = r.tileParams.copy(
core = r.tileParams.core.copy(
tensorCoreBlackwell = true,
tensorCoreDecoupled = false
)
))
case other => other
}
})
class WithEmulatorCores( class WithEmulatorCores(
n: Int, n: Int,
useVxCache: Boolean useVxCache: Boolean

View File

@@ -216,7 +216,6 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U, val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U) _.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst) val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
val nopInst = ciscInstT.Lit(_.inst -> 0.U, _.rs1 -> 0.U, _.rs2 -> 0.U)
def genStrideInst(tileA: UInt, tileB: UInt) = { def genStrideInst(tileA: UInt, tileB: UInt) = {
val inst = Wire(ciscInstT) val inst = Wire(ciscInstT)
@@ -250,9 +249,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U) val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U)
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} }
is (2.U) { is (2.U) {} // no actual invocation, fake job placeholder
ciscInst := microcodeEntry(Seq(nopInst))
} // no actual invocation, fake job placeholder
is (8.U) { // set a, b stride is (8.U) { // set a, b stride
val inst = Wire(ciscInstT) val inst = Wire(ciscInstT)
inst.inst := 0x1820b07b.U inst.inst := 0x1820b07b.U
@@ -340,7 +337,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
gemminiIO.bits.inst := Mux(ciscValid, ciscInst.inst.asTypeOf(gemminiIO.bits.inst), regCommand) gemminiIO.bits.inst := Mux(ciscValid, ciscInst.inst.asTypeOf(gemminiIO.bits.inst), regCommand)
gemminiIO.bits.rs1 := Mux(ciscValid, ciscInst.rs1, Cat(gemminiRs1RegMSB, gemminiRs1RegLSB)) gemminiIO.bits.rs1 := Mux(ciscValid, ciscInst.rs1, Cat(gemminiRs1RegMSB, gemminiRs1RegLSB))
gemminiIO.bits.rs2 := Mux(ciscValid, ciscInst.rs2, Cat(gemminiRs2RegMSB, gemminiRs2RegLSB)) gemminiIO.bits.rs2 := Mux(ciscValid, ciscInst.rs2, Cat(gemminiRs2RegMSB, gemminiRs2RegLSB))
gemminiIO.valid := (ciscValid && (ciscInst.inst =/= 0.U)) || regValid gemminiIO.valid := ciscValid || regValid
assert(gemminiIO.ready || !gemminiIO.valid) assert(gemminiIO.ready || !gemminiIO.valid)
accSlave.status := RegNext(outer.gemmini.module.io.busy).asUInt accSlave.status := RegNext(outer.gemmini.module.io.busy).asUInt

View File

@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
import org.chipsalliance.cde.config._ import org.chipsalliance.cde.config._
import radiance.core._ import radiance.core._
import radiance.memory._ import radiance.memory._
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs} import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
/** For determining radiance core id. This may be different from /** For determining radiance core id. This may be different from
* RadianceTileParams.tileId, when a cluster contains non-core tiles */ * RadianceTileParams.tileId, when a cluster contains non-core tiles */
@@ -101,6 +101,9 @@ case class VortexCoreParams(
fpu: Option[FPUParams] = None, fpu: Option[FPUParams] = None,
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction haveCease: Boolean = true, // non-standard CEASE instruction
haveSimTimeout: Boolean = true // add plusarg for simulation timeout haveSimTimeout: Boolean = true // add plusarg for simulation timeout
@@ -152,6 +155,10 @@ class RadianceTile private (
p(SIMTCoreKey).isDefined, p(SIMTCoreKey).isDefined,
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile" "SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
) )
require(
!(radianceParams.core.tensorCoreDecoupled && radianceParams.core.tensorCoreBlackwell),
"tensorCoreDecoupled and tensorCoreBlackwell are mutually exclusive"
)
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in // NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
// radiance.mk as well! // radiance.mk as well!
@@ -204,7 +211,9 @@ class RadianceTile private (
case Some(false) => 1 case Some(false) => 1
case None => 1 case None => 1
} }
val imemTagWidth = UUID_WIDTH + NW_WIDTH // Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid.
val imemDomainWidth = 1
val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH
require(numWarps >= numLsuLanes, require(numWarps >= numLsuLanes,
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})") s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
@@ -279,21 +288,60 @@ class RadianceTile private (
) )
} }
val tcSmemSize = 32 val tcSmemSize = numLsuLanes * 4
val tcSmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreDecoupled) 2 else 0) { i => val tcSmemLineSize = p(RadianceSharedMemKey)
.map(k => k.numWords * k.wordSize)
.getOrElse(tcSmemSize)
val tcSmemClientMaxSize =
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
val numTensorWarps = radianceParams.core.numTensorWarps
val numScalarWarps = numWarps - numTensorWarps
require(numTensorWarps > 0 && numTensorWarps < numWarps,
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
if (radianceParams.core.tensorCoreBlackwell) {
require(numCoreLanes == numLsuLanes,
s"Wu Blackwell binding requires matching core lanes (${numCoreLanes}) and memory lanes (${numLsuLanes})")
require(numLsuLanes == 4 || numLsuLanes == 8,
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
}
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2( TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2( masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}_$i", name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth), sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize), probe = TransferSizes(1, tcSmemClientMaxSize),
get = TransferSizes(1, tcSmemSize), get = TransferSizes(1, tcSmemClientMaxSize),
putFull = TransferSizes(1, tcSmemClientMaxSize),
), ),
requestFifo = true requestFifo = true
)) ))
))) )))
} }
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
// tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) {
Seq.tabulate(numTensorCores) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
),
requestFifo = true
)))))
}
} else Seq.empty
// combine outgoing per-lane dmemNode into 1 idenity node // combine outgoing per-lane dmemNode into 1 idenity node
// //
// NOTE: We need TLWidthWidget here because there might be a data width // NOTE: We need TLWidthWidget here because there might be a data width
@@ -382,6 +430,7 @@ class RadianceTile private (
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ } // imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
} }
/* below are copied from rocket */ /* below are copied from rocket */
@@ -743,12 +792,18 @@ class RadianceTileModuleImp(outer: RadianceTile)
val tcb0 = new { val tcb0 = new {
val addr = core.io.tc_a_bits_address(31, 0) val addr = core.io.tc_a_bits_address(31, 0)
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
val write = core.io.tc_a_bits_write(0)
val mask = core.io.tc_a_bits_mask(31, 0)
val data = core.io.tc_a_bits_data(255, 0)
val aValid = core.io.tc_a_valid(0) val aValid = core.io.tc_a_valid(0)
val dReady = core.io.tc_d_ready(0) val dReady = core.io.tc_d_ready(0)
} }
val tcb1 = new { val tcb1 = new {
val addr = core.io.tc_a_bits_address(63, 32) val addr = core.io.tc_a_bits_address(63, 32)
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4) val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
val write = core.io.tc_a_bits_write(1)
val mask = core.io.tc_a_bits_mask(63, 32)
val data = core.io.tc_a_bits_data(511, 256)
val aValid = core.io.tc_a_valid(1) val aValid = core.io.tc_a_valid(1)
val dReady = core.io.tc_d_ready(1) val dReady = core.io.tc_d_ready(1)
} }
@@ -770,26 +825,320 @@ class RadianceTileModuleImp(outer: RadianceTile)
adapter.io.inReq.bits.address := bundle.addr adapter.io.inReq.bits.address := bundle.addr
adapter.io.inReq.bits.source := bundle.tag adapter.io.inReq.bits.source := bundle.tag
adapter.io.inReq.bits.size := 5.U // 256 bits adapter.io.inReq.bits.size := 5.U // 256 bits
adapter.io.inReq.bits.opcode := TLMessages.Get adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := x"ffffffff".U adapter.io.inReq.bits.mask := bundle.mask
adapter.io.inReq.bits.data := bundle.data
adapter.io.inResp.ready := bundle.dReady adapter.io.inResp.ready := bundle.dReady
client._1.a <> adapter.io.outReq client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d adapter.io.outResp <> client._1.d
adapter adapter
} }
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready) core.io.tc_a_ready := Cat(0.U(1.W), adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid) core.io.tc_d_valid := Cat(0.U(1.W), adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data) core.io.tc_d_bits_data := Cat(0.U((32 * 8).W), adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source) core.io.tc_d_bits_tag := Cat(0.U(outer.tensorTagWidth.W), adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 2) require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 3)
require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 2) require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 3)
} else { } else {
core.io.tc_a_ready := false.B core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare core.io.tc_d_bits_tag := DontCare
} }
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
}
def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tcSmemNodes.nonEmpty)
require(outer.tcSmemNodes.length == outer.numTensorCores)
require(outer.tcGmemNodes.length == outer.numTensorCores)
val nTC = outer.numTensorCores
val tcPorts = 3
val tcCoreDataBits = 32 * 8
val tcDataBits = outer.tcSmemSize * 8
val tcSmemLineBits = outer.tcSmemLineSize * 8
val tmemAddrBits = 9
val tmemDataBits = tcDataBits
val tmemMaskBits = outer.tcSmemSize
val tcTlSize = log2Ceil(outer.tcSmemSize)
val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
def port(tc: Int, p: Int): Int = tc * tcPorts + p
def padToCoreData(u: UInt): UInt = {
if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
}
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
tcAReady.foreach(_ := false.B)
tcDValid.foreach(_ := false.B)
tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U)
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
// reads can proceed together when bank placement avoids conflicts.
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
val tmemBytesPerWarp = 2048
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
val tmemBanks = 4
val tmemBankBits = log2Ceil(tmemBanks)
val tmemBankDepth = tmemDepth / tmemBanks
require(isPow2(tmemBanks))
require(tmemDepth % tmemBanks == 0)
val tmem = Seq.fill(tmemBanks) {
Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
}
class TmemReadReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val src = UInt(2.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
class TmemWriteReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val data = UInt(tmemDataBits.W)
val mask = UInt(tmemMaskBits.W)
val src = UInt(1.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
val aReady = Wire(Vec(nTC, Bool()))
val cReady = Wire(Vec(nTC, Bool()))
val wReady = Wire(Vec(nTC, Bool()))
val scReadReady = Wire(Bool())
val scWriteReady = Wire(Bool())
aReady.foreach(_ := false.B)
cReady.foreach(_ := false.B)
wReady.foreach(_ := false.B)
scReadReady := false.B
scWriteReady := false.B
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read0Valid = Wire(Vec(tmemBanks, Bool()))
val read1Valid = Wire(Vec(tmemBanks, Bool()))
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
val writeValid = Wire(Vec(tmemBanks, Bool()))
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read0Valid.foreach(_ := false.B)
read1Valid.foreach(_ := false.B)
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
writeValid.foreach(_ := false.B)
(0 until tmemBanks).foreach { b =>
val requests = (0 until nTC).flatMap { tc =>
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
Seq(
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
)
} ++ Seq(
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
)
var used0 = false.B
var used1 = false.B
requests.foreach { case (valid, addr, src, tc) =>
val grant0 = valid && !used0
val grant1 = valid && used0 && !used1
when(grant0) {
read0Grant(b).addr := addr
read0Grant(b).src := src
read0Grant(b).tc := tc
}
when(grant1) {
read1Grant(b).addr := addr
read1Grant(b).src := src
read1Grant(b).tc := tc
}
used0 = used0 || grant0
used1 = used1 || grant1
when(grant0 || grant1) {
when(src === 0.U) { aReady(tc) := true.B }
when(src === 1.U) { cReady(tc) := true.B }
when(src === 2.U) { scReadReady := true.B }
}
}
read0Valid(b) := used0
read1Valid(b) := used1
var writeUsed = false.B
(0 until nTC).foreach { tc =>
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
val grant = valid && !writeUsed
when(grant) {
writeValid(b) := true.B
writeGrant(b).addr := addr
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
writeGrant(b).src := 0.U
writeGrant(b).tc := tc.U
wReady(tc) := true.B
}
writeUsed = writeUsed || grant
}
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
val scWGrant = scWValid && !writeUsed
when(scWGrant) {
writeValid(b) := true.B
writeGrant(b).addr := core.io.sc_tmem_waddr
writeGrant(b).data := core.io.sc_tmem_wdata
writeGrant(b).mask := core.io.sc_tmem_mask
writeGrant(b).src := 1.U
writeGrant(b).tc := 0.U
scWriteReady := true.B
}
tmem(b).io.ren0 := read0Valid(b)
tmem(b).io.raddr0 := row(read0Grant(b).addr)
tmem(b).io.ren1 := read1Valid(b)
tmem(b).io.raddr1 := row(read1Grant(b).addr)
tmem(b).io.wen := writeValid(b)
tmem(b).io.waddr := row(writeGrant(b).addr)
tmem(b).io.wdata := writeGrant(b).data
tmem(b).io.mask := writeGrant(b).mask
}
val read0GrantReg = RegNext(read0Grant)
val read1GrantReg = RegNext(read1Grant)
val read0ValidReg = RegNext(read0Valid)
val read1ValidReg = RegNext(read1Valid)
core.io.tc_tmem_A_rready := aReady.asUInt
core.io.tc_tmem_C_rready := cReady.asUInt
core.io.tc_tmem_C_wready := wReady.asUInt
core.io.sc_tmem_rready := scReadReady.asUInt
core.io.sc_tmem_wready := scWriteReady.asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc =>
val p2 = port(tc, 2)
val client = outer.tcSmemNodes(tc).out.head
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
client
))
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
adapter.io.inReq.bits.address := lineAddress
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
adapter.io.inReq.bits.size := tcSmemLineTlSize.U
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
val lineData = adapter.io.inResp.bits.data
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
lineData
} else {
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
log2Ceil(outer.tcSmemSize)).asUInt
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
when(adapter.io.inReq.fire) {
fragmentIndex := requestFragmentIndex
}
lineFragments(fragmentIndex)
}
tcAReady(p2) := adapter.io.inReq.ready
tcDValid(p2) := adapter.io.inResp.valid
tcDData(p2) := padToCoreData(fragmentData)
tcDTag(p2) := adapter.io.inResp.bits.source
}
// port 0: global memory (cp/cb), one TL client per tensor core.
(0 until nTC).foreach { tc =>
val p0 = port(tc, 0)
val gmemClient = outer.tcGmemNodes(tc).out.head
val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
gmemClient
))
gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
gmemAdapter.io.inReq.bits.size := tcTlSize.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d
tcAReady(p0) := gmemAdapter.io.inReq.ready
tcDValid(p0) := gmemAdapter.io.inResp.valid
tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
}
core.io.tc_a_ready := tcAReady.asUInt
core.io.tc_d_valid := tcDValid.asUInt
core.io.tc_d_bits_data := tcDData.asUInt
core.io.tc_d_bits_tag := tcDTag.asUInt
} else {
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
}
} }
def connectBarrier = { def connectBarrier = {
@@ -847,7 +1196,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
connectImem connectImem
connectDmem connectDmem
connectSmem connectSmem
if (outer.radianceParams.core.tensorCoreBlackwell) {
connectTensorBlackwell
} else {
connectTensor connectTensor
}
connectBarrier connectBarrier
connectAccelerator connectAccelerator
} }
@@ -874,6 +1227,27 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B tensor.io.writeback.ready := false.B
dontTouch(tensor.io)
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
val tensor = Module(new radiance.core.TensorCoreBlackwell(
outer.numWarps, outer.numLsuLanes, half = true, tensorNumSourceIds))
tensor.io.initiate.valid := false.B
tensor.io.initiate.bits := DontCare
tensor.io.respA.valid := false.B
tensor.io.respA.bits := DontCare
tensor.io.respB.valid := false.B
tensor.io.respB.bits := DontCare
tensor.io.respC := DontCare
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
tensor.io.tmemC.aRready := false.B
tensor.io.tmemC.aRdata := DontCare
tensor.io.tmemC.cRready := false.B
tensor.io.tmemC.cRdata := DontCare
tensor.io.tmemC.cWready := false.B
dontTouch(tensor.io)
} else { } else {
if (outer.radianceParams.core.tensorCoreFP16) { if (outer.radianceParams.core.tensorCoreFP16) {
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true)) val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
@@ -936,18 +1310,6 @@ class VortexTLAdapter(
val outResp = chiselTypeOf(outTL._1.d) val outResp = chiselTypeOf(outTL._1.d)
}) })
val (bundle, edge) = outTL val (bundle, edge) = outTL
val sourceGen = Module(
new SourceGenerator(
newSourceWidth,
Some(inReqT.source),
ignoreInUse = false
)
)
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
sourceGen.io.reclaim.valid := io.outResp.fire
sourceGen.io.reclaim.bits := io.outResp.bits.source
sourceGen.io.meta := io.inReq.bits.source
// io passthrough logic // io passthrough logic
// TLBundleA <> VortexBundleA // TLBundleA <> VortexBundleA
io.outReq.valid := io.inReq.valid io.outReq.valid := io.inReq.valid
@@ -956,29 +1318,70 @@ class VortexTLAdapter(
io.outReq.bits.size := io.inReq.bits.size io.outReq.bits.size := io.inReq.bits.size
io.outReq.bits.source := io.inReq.bits.source io.outReq.bits.source := io.inReq.bits.source
io.outReq.bits.address := io.inReq.bits.address io.outReq.bits.address := io.inReq.bits.address
// Get requires contiguous mask; only copy core's potentially-partial mask val outMaskWidth = io.outReq.bits.mask.getWidth
// when writing val inMaskWidth = io.inReq.bits.mask.getWidth
val outDataWidth = io.outReq.bits.data.getWidth
val inDataWidth = io.inReq.bits.data.getWidth
val byteOffset = io.inReq.bits.address(log2Ceil(outMaskWidth) - 1, 0)
val responseOffsetWidth = log2Ceil(outMaskWidth)
val responseSourceWidth = inReqT.source.getWidth
val sourceGen = Module(
new SourceGenerator(
newSourceWidth,
Some(UInt((responseSourceWidth + responseOffsetWidth).W)),
ignoreInUse = false
)
)
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
sourceGen.io.reclaim.valid := io.outResp.fire
sourceGen.io.reclaim.bits := io.outResp.bits.source
sourceGen.io.meta := Cat(byteOffset, io.inReq.bits.source)
val alignedMask = Wire(UInt(outMaskWidth.W))
val alignedData = Wire(UInt(outDataWidth.W))
if (outMaskWidth == inMaskWidth) {
alignedMask := io.inReq.bits.mask
} else {
val paddedMask = Wire(UInt(outMaskWidth.W))
paddedMask := io.inReq.bits.mask
alignedMask := (paddedMask << byteOffset)(outMaskWidth - 1, 0)
}
if (outDataWidth == inDataWidth) {
alignedData := io.inReq.bits.data
} else {
val paddedData = Wire(UInt(outDataWidth.W))
paddedData := io.inReq.bits.data
alignedData := (paddedData << (byteOffset << 3))(outDataWidth - 1, 0)
}
// PutFull requires the TL-canonical full mask for address+size; PutPartial
// can carry the core-provided byte mask.
io.outReq.bits.mask := Mux( io.outReq.bits.mask := Mux(
edge.hasData(io.outReq.bits), io.outReq.bits.opcode === TLMessages.PutPartialData,
io.inReq.bits.mask, alignedMask,
// generate TL-correct mask
edge.mask(io.inReq.bits.address, io.inReq.bits.size) edge.mask(io.inReq.bits.address, io.inReq.bits.size)
) )
io.outReq.bits.data := io.inReq.bits.data io.outReq.bits.data := alignedData
io.outReq.bits.corrupt := 0.U io.outReq.bits.corrupt := 0.U
io.inReq.ready := io.outReq.ready io.inReq.ready := io.outReq.ready
// VortexBundleD <> TLBundleD // VortexBundleD <> TLBundleD
io.inResp.valid := io.outResp.valid io.inResp.valid := io.outResp.valid
io.inResp.bits.opcode := io.outResp.bits.opcode io.inResp.bits.opcode := io.outResp.bits.opcode
io.inResp.bits.size := io.outResp.bits.size io.inResp.bits.size := io.outResp.bits.size
io.inResp.bits.source := io.outResp.bits.source val responseMeta = sourceGen.io.peek.asUInt
val responseSource = responseMeta(responseSourceWidth - 1, 0)
val responseByteOffset =
responseMeta(responseSourceWidth + responseOffsetWidth - 1, responseSourceWidth)
io.inResp.bits.source := responseSource
if (outDataWidth == inDataWidth) {
io.inResp.bits.data := io.outResp.bits.data io.inResp.bits.data := io.outResp.bits.data
} else {
io.inResp.bits.data := (io.outResp.bits.data >> (responseByteOffset << 3))(inDataWidth - 1, 0)
}
io.outResp.ready := io.inResp.ready io.outResp.ready := io.inResp.ready
// "man-in-the-middle" // "man-in-the-middle"
io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid
io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid
io.outReq.bits.source := sourceGen.io.id.bits io.outReq.bits.source := sourceGen.io.id.bits
// translate upstream response back to its old sourceId
io.inResp.bits.source := sourceGen.io.peek
} }

View File

@@ -90,14 +90,45 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val tc_a_valid = Output(UInt(2.W)) val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1
val tc_a_bits_address = Output(UInt((2 * 32).W)) val tcPortCount = 3
val tc_a_bits_tag = Output(UInt((2 * 4).W)) val tcFlatPortCount = tcPortCount * numTensorCores
val tc_a_ready = Input(UInt(2.W)) val tc_a_valid = Output(UInt(tcFlatPortCount.W))
val tc_d_valid = Input(UInt(2.W)) val tc_a_bits_write = Output(UInt(tcFlatPortCount.W))
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W)) val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W))
val tc_d_bits_tag = Input(UInt((2 * 4).W)) val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W))
val tc_d_ready = Output(UInt(2.W)) val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W))
val tc_a_ready = Input(UInt(tcFlatPortCount.W))
val tc_d_valid = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W))
val tc_d_ready = Output(UInt(tcFlatPortCount.W))
// Direct SRAM ports for shared TMEM (bypasses TileLink)
val numLanes = tile.numLsuLanes
val tc_tmem_A_ren = Output(UInt(numTensorCores.W))
val tc_tmem_A_rready = Input(UInt(numTensorCores.W))
val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wen = Output(UInt(numTensorCores.W))
val tc_tmem_C_wready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
val sc_tmem_ren = Output(UInt(1.W))
val sc_tmem_rready = Input(UInt(1.W))
val sc_tmem_raddr = Output(UInt(9.W))
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
val sc_tmem_wen = Output(UInt(1.W))
val sc_tmem_wready = Input(UInt(1.W))
val sc_tmem_waddr = Output(UInt(9.W))
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
// FIXME: hardcoded // FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -132,9 +163,9 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
Map( Map(
"CORE_ID" -> tile.radianceParams.coreId, "CORE_ID" -> tile.radianceParams.coreId,
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0), "TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
// TODO: can we get this as a parameter? "STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
"BOOTROM_HANG100" -> 0x10100, "NUM_THREADS" -> tile.numLsuLanes,
"NUM_THREADS" -> tile.numLsuLanes "NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1)
) )
) )
with HasBlackBoxResource with HasBlackBoxPath { with HasBlackBoxResource with HasBlackBoxPath {
@@ -198,6 +229,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh") addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")
@@ -328,6 +360,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_exp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")
@@ -411,6 +444,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
// hopper-style SMEM operand decoupling // hopper-style SMEM operand decoupling
if (tile.radianceParams.core.tensorCoreDecoupled) { if (tile.radianceParams.core.tensorCoreDecoupled) {
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
} else if (tile.radianceParams.core.tensorCoreBlackwell) {
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_blackwell_core.sv")
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh") // addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
def addHopperTensorCore = { def addHopperTensorCore = {
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv") addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
@@ -444,7 +479,9 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv")
if (!tile.radianceParams.core.tensorCoreBlackwell) {
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv") addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
}
if (tile.radianceParams.useVxCache) { if (tile.radianceParams.useVxCache) {
addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv") addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv")

View File

@@ -0,0 +1,110 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class FP8E4M3DecodeHarness extends Module {
val io = IO(new Bundle {
val in = Input(UInt(8.W))
val out = Output(UInt(32.W))
})
io.out := FP8E4M3.toFloat32(io.in)
}
class FP8E4M3MulHarness extends Module {
val io = IO(new Bundle {
val a = Input(UInt(8.W))
val b = Input(UInt(8.W))
val out = Output(UInt(32.W))
})
io.out := FP8E4M3MulToFloat32(io.a, io.b)
}
class FP8E4M3Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "FP8E4M3"
it should "decode representative E4M3 values to FP32 bits" in {
test(new FP8E4M3DecodeHarness) { c =>
Seq(
0x00 -> 0x00000000L,
0x80 -> 0x80000000L,
0x38 -> 0x3f800000L,
0x40 -> 0x40000000L,
0x30 -> 0x3f000000L,
0x3c -> 0x3fc00000L
).foreach { case (fp8, fp32) =>
c.io.in.poke(fp8.U)
c.clock.step()
c.io.out.expect(fp32.U)
}
}
}
it should "multiply E4M3 operands with FP8-width significands and return FP32 bits" in {
test(new FP8E4M3MulHarness) { c =>
Seq(
(0x38, 0x40, 0x40000000L), // 1.0 * 2.0 = 2.0
(0x30, 0x3c, 0x3f400000L), // 0.5 * 1.5 = 0.75
(0xb8, 0x40, 0xc0000000L), // -1.0 * 2.0 = -2.0
(0x00, 0x40, 0x00000000L), // 0.0 * 2.0 = 0.0
(0x80, 0x40, 0x80000000L) // -0.0 * 2.0 = -0.0
).foreach { case (a, b, out) =>
c.io.a.poke(a.U)
c.io.b.poke(b.U)
c.clock.step()
c.io.out.expect(out.U)
}
}
}
it should "run an 8-wide FP8 dot product with FP32 accumulation" in {
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
c.io.in.valid.poke(true.B)
c.io.stall.poke(false.B)
for (i <- 0 until 8) {
c.io.in.bits.a(i).poke(0x38.U(8.W))
c.io.in.bits.b(i).poke(0x40.U(8.W))
}
c.io.in.bits.c.poke(0x3f800000L.U(32.W))
c.io.out.valid.expect(false.B)
c.clock.step()
c.io.in.valid.poke(false.B)
c.io.out.valid.expect(false.B)
c.clock.step()
c.clock.step()
c.clock.step()
c.clock.step()
c.io.out.valid.expect(true.B)
c.io.out.bits.data.expect(0x41880000L.U)
}
}
it should "run an 8-wide fractional FP8 dot product with FP32 accumulation" in {
test(new TensorDotProductUnit(8, half = false, inputType = TensorInputType.FP8E4M3)) { c =>
c.io.in.valid.poke(true.B)
c.io.stall.poke(false.B)
for (i <- 0 until 8) {
c.io.in.bits.a(i).poke(0x30.U(8.W))
c.io.in.bits.b(i).poke(0x3c.U(8.W))
}
c.io.in.bits.c.poke(0x40000000L.U(32.W))
c.io.out.valid.expect(false.B)
c.clock.step()
c.io.in.valid.poke(false.B)
c.io.out.valid.expect(false.B)
c.clock.step()
c.clock.step()
c.clock.step()
c.clock.step()
c.io.out.valid.expect(true.B)
c.io.out.bits.data.expect(0x41000000L.U)
}
}
}

View File

@@ -0,0 +1,348 @@
package radiance.core
import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell Extended Tests"
private val numWarps = 4
private val numLanes = 8
private val fragBytes = 32
private def idleIO(c: TensorCoreBlackwell): Unit = {
c.io.initiate.valid.poke(false.B)
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.io.respA.bits.source.poke(0.U)
c.io.respB.bits.source.poke(0.U)
c.io.respA.bits.data.poke(0.U)
c.io.respB.bits.data.poke(0.U)
c.io.reqA.ready.poke(false.B)
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
}
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.aRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.cRaddr.peek().litValue
c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
}
}
it should "verify bwgmma address offset with non-zero base addresses" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
idleIO(c)
val tmem = makeTmem()
// Use non-zero base addresses to verify offset calculation
val aBase = BigInt(0x200) // row 16, A tile rows 16~47
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
val bBase = BigInt(0x800)
val fp8One = BigInt(0x38)
val fp32Zero = BigInt(0)
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
val fp32SixtyFour = BigInt(0x42800000L)
// Populate TMEM A at offset aBase (all 32 frags)
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag
}
// SMEM B with packed FP8 E4M3 2.0
val fp8Two = BigInt(0x40)
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
c.io.reqB.ready.poke(true.B)
c.io.writeback.ready.poke(true.B)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(aBase.U)
c.io.initiate.bits.addressB.poke(bBase.U)
c.io.initiate.bits.addressC.poke(cBase.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
var pendingB = Option.empty[(BigInt, BigInt)]
var sawWriteback = false
for (_ <- 0 until 50000 if !sawWriteback) {
stepTmem(c, tmem)
pendingB.foreach { case (src, data) =>
c.io.respB.valid.poke(true.B)
c.io.respB.bits.source.poke(src.U)
c.io.respB.bits.data.poke(data.U)
}
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
if (c.io.writeback.valid.peek().litToBoolean) {
sawWriteback = true
} else {
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
val addr = c.io.reqB.bits.address.peek().litValue
val src = c.io.reqB.bits.source.peek().litValue
Some((src, bMem(addr)))
} else None
c.clock.step()
pendingB = nextB
}
}
assert(sawWriteback, "BWGMMA did not complete")
val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32)
for (i <- 0 until 8) {
val row = cBase / fragBytes + i
assert(tmem(row) == expectedC,
s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}")
}
for (i <- 0 until 8) {
assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged")
}
}
}
it should "cp then ld round-trip: data written via cp is readable via ld" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val tmemAddr = BigInt(0x100)
val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32)
// Issue cp: global mem -> tmem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h10000000".U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cpRead: reqA issued
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(false.B)
c.clock.step()
// cpWrite: respA fires, tmemC written
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(cpData.U)
stepTmem(c, tmem)
c.clock.step()
c.io.respA.valid.poke(false.B)
// Now issue ld from same tmem address
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.rd.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// ldReq: ren asserted, serve from tmem model
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// writeback should carry cpData
c.io.writeback.valid.expect(true.B)
for (i <- 0 until numLanes) {
c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U)
}
}
}
it should "st then cb round-trip: data written via st is readable via cb" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val tmemAddr = BigInt(0x140)
val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32)
// Issue st: respC -> tmem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.rd.poke(4.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.respC.poke(stData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// stReq: reqC valid
c.io.reqC.valid.expect(true.B)
c.clock.step()
// stWrite: tmemC written
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWdata.expect(stData.U)
stepTmem(c, tmem)
c.clock.step()
// Issue cb: tmem -> global mem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(6.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h20000000".U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cbRead: ren asserted
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// cbWrite: reqA write with stData
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect("h20000000".U)
c.io.reqA.bits.data.expect(stData.U)
}
}
it should "wait ops are no-ops and do not stall pipeline" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
// bwgmmaWait: should accept immediately and stay idle
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(1.U) // bwgmmaWait
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.writeback.valid.expect(false.B)
c.io.reqA.valid.expect(false.B)
c.io.reqB.valid.expect(false.B)
// tcgen05CpWait: same
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(3.U) // tcgen05CpWait
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.writeback.valid.expect(false.B)
c.io.reqA.valid.expect(false.B)
}
}
it should "not accept a second tensor op until the first one completes" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val firstAddr = BigInt(0x180)
val secondAddr = BigInt(0x1a0)
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xCAFE0000L + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.addressA.poke(firstAddr.U)
c.io.respC.poke(storeData.U)
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.addressA.poke(secondAddr.U)
c.io.initiate.bits.rd.poke(2.U)
c.io.initiate.ready.expect(false.B)
c.clock.step()
c.io.initiate.ready.expect(false.B)
c.io.tmemC.cWen.expect(true.B)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
// warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0
val warp0TmemA = BigInt(0x000)
val warp0Data = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32)
// warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800
val warp3TmemA = BigInt(0x1800)
val warp3Data = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32)
// Write warp 0 data via st
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.addressA.poke(warp0TmemA.U)
c.io.respC.poke(warp0Data.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((warp0TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()
// Write warp 3 data via st
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(3.U)
c.io.initiate.bits.addressA.poke(warp3TmemA.U)
c.io.respC.poke(warp3Data.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((warp3TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()
// Verify no aliasing: warp 0 row != warp 3 row
assert(warp0TmemA / fragBytes != warp3TmemA / fragBytes)
assert(tmem(warp0TmemA / fragBytes) == warp0Data)
assert(tmem(warp3TmemA / fragBytes) == warp3Data)
}
}
}

View File

@@ -0,0 +1,76 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class TensorCoreBlackwellFP8MicrostepHarness extends Module {
val io = IO(new Bundle {
val valid = Input(Bool())
val operandA = Input(UInt(512.W))
val operandB = Input(UInt(256.W))
val c = Input(UInt(32.W))
val elemM = Input(UInt(2.W))
val elemN = Input(UInt(1.W))
val outValid = Output(Bool())
val out = Output(UInt(32.W))
})
val dpu = Module(new TensorDotProductUnit(
dim = 8,
half = false,
inputType = TensorInputType.FP8E4M3
))
dpu.io.in.valid := io.valid
for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := TensorCoreBlackwellFP8Packing.selectA(io.operandA, k, io.elemM, numLanes = 8)
dpu.io.in.bits.b(k) := TensorCoreBlackwellFP8Packing.selectB(io.operandB, k, io.elemN)
}
dpu.io.in.bits.c := io.c
dpu.io.stall := false.B
io.outValid := dpu.io.out.valid
io.out := dpu.io.out.bits.data
}
class TensorCoreBlackwellFP8Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell FP8 operand microstep"
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
it should "select packed FP8 E4M3 operands and accumulate into FP32" in {
test(new TensorCoreBlackwellFP8MicrostepHarness) { c =>
val fp8One = BigInt(0x38)
val fp8Two = BigInt(0x40)
val fp32One = BigInt(0x3f800000L)
val fp32Seventeen = BigInt(0x41880000L)
val operandA = packWords(Seq.fill(64)(fp8One), 8)
val operandB = packWords(Seq.fill(32)(fp8Two), 8)
c.io.valid.poke(true.B)
c.io.operandA.poke(operandA.U)
c.io.operandB.poke(operandB.U)
c.io.c.poke(fp32One.U)
c.io.elemM.poke(0.U)
c.io.elemN.poke(0.U)
c.io.outValid.expect(false.B)
c.clock.step()
c.io.valid.poke(false.B)
c.io.outValid.expect(false.B)
c.clock.step()
c.clock.step()
c.clock.step()
c.clock.step()
c.io.outValid.expect(true.B)
c.io.out.expect(fp32Seventeen.U)
}
}
}

View File

@@ -0,0 +1,323 @@
package radiance.core
import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell"
private val numWarps = 4
private val numLanes = 8
private def idleIO(c: TensorCoreBlackwell): Unit = {
c.io.initiate.valid.poke(false.B)
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.io.respA.bits.source.poke(0.U)
c.io.respB.bits.source.poke(0.U)
c.io.respA.bits.data.poke(0.U)
c.io.respB.bits.data.poke(0.U)
c.io.reqA.ready.poke(false.B)
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
}
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
// Simple TMEM model: address → 256-bit row
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
// Drive TMEM read responses from model, handle C-port writes.
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.aRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.cRaddr.peek().litValue
c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
}
}
it should "tcgen05_ld: read from TMEM to writeback" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val fragBytes = 32
val tmemAddr = BigInt(0x40) // row 2 (0x40 / 32 = 2)
val testData = packWords(Seq.tabulate(numLanes)(i => BigInt(0x1000 + i)), 32)
tmem(tmemAddr / fragBytes) = testData
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// ldReq: tmemC.ren asserted; rdata must be valid before next step
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
// waitWb: wbValid gets set this cycle, step to let it register
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
// idle: writeback.valid now true
c.io.writeback.valid.expect(true.B)
c.io.initiate.ready.expect(false.B)
c.io.writeback.bits.rd.expect(3.U)
c.io.writeback.bits.wid.expect(0.U)
for (i <- 0 until numLanes) {
c.io.writeback.bits.data(i).expect((0x1000 + i).U)
}
}
}
it should "tcgen05_ld: support 4-lane 16-byte fragments" in {
val lanes = 4
test(new TensorCoreBlackwell(numWarps, lanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 16
val tmemAddr = BigInt(0x40)
val testData = packWords(Seq.tabulate(lanes)(i => BigInt(0x2000 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(3.U)
for (i <- 0 until lanes) {
c.io.writeback.bits.data(i).expect((0x2000 + i).U)
}
}
}
it should "tcgen05_st: write from respC to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0x60)
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xAB00 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U) // tcgen05St
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(7.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.respC.poke(storeData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// stReq: reqC.valid asserted
c.io.reqC.valid.expect(true.B)
c.io.reqC.bits.expect(7.U)
c.clock.step()
// stWrite: tmemC.wen asserted with storeData
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(storeData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "tcgen05_cp: read from global mem (reqA) and write to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0x80)
val gmemAddr = "ha0001000"
val cpData = packWords(Seq.fill(numLanes)(BigInt(0xdeadbeefL)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U) // tcgen05Cp
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// cpRead: reqA issued to global mem
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(false.B)
c.io.reqA.bits.address.expect(gmemAddr.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
// cpWrite: respA fires → tmemC.wen in same cycle
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
// tmemC write happens combinatorially when respA fires
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(cpData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "tcgen05_cb: read from TMEM and write to global mem (reqA)" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0xa0)
val gmemAddr = "ha2000000"
val cbData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xC000 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(6.U) // tcgen05Cb
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// cbRead: tmemC.ren asserted
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
// cbWrite: reqA write to global mem
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect(gmemAddr.U)
c.io.reqA.bits.data.expect(cbData.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(0.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "run bwgmma: TMEM_C = TMEM_A * SMEM_B + TMEM_C" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
idleIO(c)
val fragBytes = 32
val aBase = BigInt(0x100)
val bBase = BigInt(0x800)
val cBase = BigInt(0x1000)
// A/B: packed FP8 E4M3 bytes, 32 elements per 256-bit frag
val fp8One = BigInt(0x38)
val fp8Two = BigInt(0x40)
val fp32One = BigInt(0x3f800000)
val fp32SixtyFive = BigInt(0x42820000)
val aFrag = packWords(Seq.fill(32)(fp8One), 8)
val bFrag = packWords(Seq.fill(32)(fp8Two), 8)
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
// Populate TMEM with A and C tiles
val tmem = makeTmem()
for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag
}
val bMem = mutable.Map[BigInt, BigInt]()
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
c.io.reqB.ready.poke(true.B)
c.io.writeback.ready.poke(true.B)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U) // bwgmma
c.io.initiate.bits.wid.poke(1.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(aBase.U)
c.io.initiate.bits.addressB.poke(bBase.U)
c.io.initiate.bits.addressC.poke(cBase.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
var pendingB = Option.empty[(BigInt, BigInt)]
var sawWriteback = false
for (_ <- 0 until 20000 if !sawWriteback) {
// Drive TMEM reads/writes
stepTmem(c, tmem)
// Drive SMEM B responses
pendingB.foreach { case (src, data) =>
c.io.respB.valid.poke(true.B)
c.io.respB.bits.source.poke(src.U)
c.io.respB.bits.data.poke(data.U)
}
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
if (c.io.writeback.valid.peek().litToBoolean) {
sawWriteback = true
} else {
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
val addr = c.io.reqB.bits.address.peek().litValue
val src = c.io.reqB.bits.source.peek().litValue
Some((src, bMem(addr)))
} else None
c.clock.step()
pendingB = nextB
}
}
assert(sawWriteback, "BWGMMA did not complete")
c.io.writeback.bits.wid.expect(1.U)
// Verify all 32 C frags in TMEM
for (i <- 0 until 32) {
val row = cBase / fragBytes + i
assert(tmem(row) == expectedCFrag,
s"C frag $i mismatch: got 0x${tmem(row).toString(16)}, expected 0x${expectedCFrag.toString(16)}")
}
}
}
}