Add Blackwell tensor core implementation and tests

- Implement TensorCoreBlackwell.scala with BWGMMA and TCGEN05 instructions
- Update TensorDPU, RadianceTile, and VortexCore for Blackwell integration
- Add TensorCoreBlackwellExtendedTest for comprehensive testing
- Update vortex submodule with Blackwell ISA support
This commit is contained in:
2026-05-06 14:51:11 +08:00
parent 136cf70a58
commit 5112f3665a
8 changed files with 960 additions and 262 deletions

View File

@@ -13,6 +13,9 @@ class TensorCoreBlackwell(
val numSourceIds: Int = 16,
val numFPRegs: Int = 32
) extends Module {
require(half, "Blackwell MMA currently supports FP16 inputs only")
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes")
val numWarpBits = log2Ceil(numWarps)
val sourceWidth = log2Ceil(numSourceIds)
val laneWidth = 4 * 8
@@ -20,9 +23,17 @@ class TensorCoreBlackwell(
val numFPRegBits = log2Ceil(numFPRegs)
val addressWidth = 32
val maskWidth = memWidth / 8
val fragOffsetBits = log2Ceil(memWidth / 8)
val numSets = 4
val numAFragsPerSet = 8
val numBGroups = 4
val numBFragsPerGroup = 2
val numMGroups = 4
val numCFrags = 32
object Ops {
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: Nil = Enum(6)
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
}
class TensorMemReq(
@@ -44,6 +55,17 @@ class TensorCoreBlackwell(
val data = UInt(dataWidth.W)
}
// Direct SRAM port for TMEM (no TileLink overhead)
class TmemSramPort extends Bundle {
val wen = Output(Bool())
val ren = Output(Bool())
val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val wdata = Output(UInt(memWidth.W))
val mask = Output(UInt(maskWidth.W))
val rdata = Input(UInt(memWidth.W))
}
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
val op = UInt(3.W)
@@ -51,6 +73,7 @@ class TensorCoreBlackwell(
val rd = UInt(numFPRegBits.W)
val addressA = UInt(addressWidth.W)
val addressB = UInt(addressWidth.W)
val addressC = UInt(addressWidth.W)
}))
val writeback = Decoupled(new Bundle {
val last = Bool()
@@ -64,10 +87,14 @@ class TensorCoreBlackwell(
val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
val tmemC = new TmemSramPort // direct SRAM for C matrix (replaces reqCmem/respCmem)
})
object State extends ChiselEnum {
val idle, bwReq, bwResp, cpRead, cpWrite, ldReq, stReq, waitWb = Value
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
cbRead, cbWrite = Value
}
val state = RegInit(State.idle)
@@ -76,16 +103,41 @@ class TensorCoreBlackwell(
val rdReg = RegInit(0.U(numFPRegBits.W))
val addrAReg = RegInit(0.U(addressWidth.W))
val addrBReg = RegInit(0.U(addressWidth.W))
val aDataReg = Reg(UInt(memWidth.W))
val bDataReg = Reg(UInt(memWidth.W))
val haveA = RegInit(false.B)
val haveB = RegInit(false.B)
val addrCReg = RegInit(0.U(addressWidth.W))
val sourceCounter = RegInit(0.U(sourceWidth.W))
val setReg = RegInit(0.U(log2Ceil(numSets).W))
val aIndexReg = RegInit(0.U(log2Ceil(numAFragsPerSet).W))
val bGroupReg = RegInit(0.U(log2Ceil(numBGroups).W))
val bIndexReg = RegInit(0.U(log2Ceil(numBFragsPerGroup).W))
val mGroupReg = RegInit(0.U(log2Ceil(numMGroups).W))
val substepReg = RegInit(0.U(1.W))
val elemReg = RegInit(0.U(log2Ceil(numLanes).W))
val waitCounter = RegInit(0.U(3.W))
val aBuf = Reg(Vec(numAFragsPerSet, UInt(memWidth.W)))
val bBuf = Reg(Vec(numBFragsPerGroup, UInt(memWidth.W)))
val cDataReg = Reg(UInt(memWidth.W))
val mmaDataReg = Reg(Vec(numLanes, UInt(laneWidth.W)))
private def bumpSource(): Unit = {
sourceCounter := sourceCounter + 1.U
}
private def byteAddress(base: UInt, fragIndex: UInt): UInt = {
base + (fragIndex << fragOffsetBits).asUInt
}
val aFragIndex = (setReg << 3) + aIndexReg
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg
val stepIndex = Cat(bGroupReg, mGroupReg)
val cFragIndex = (stepIndex << 1) + substepReg
val aReqAddress = byteAddress(addrAReg, aFragIndex)
val bReqAddress = byteAddress(addrBReg, bFragIndex)
val cReqAddress = byteAddress(addrCReg, cFragIndex)
val tmemABase = (addrAReg >> fragOffsetBits.U).asUInt
val tmemCBase = (addrCReg >> fragOffsetBits.U).asUInt
val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
reqA.valid := false.B
@@ -95,6 +147,13 @@ class TensorCoreBlackwell(
io.reqA <> reqA
io.reqB <> reqB
io.tmemC.wen := false.B
io.tmemC.ren := false.B
io.tmemC.waddr := 0.U
io.tmemC.raddr := 0.U
io.tmemC.wdata := 0.U
io.tmemC.mask := 0.U
val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
io.writeback.valid := wbValid
@@ -106,10 +165,40 @@ class TensorCoreBlackwell(
io.reqC.valid := false.B
io.reqC.bits := rdReg
io.respA.ready := false.B
// drain stale write-ack from TMEM so TLRAM doesn't stall on r_full
io.respA.ready := state === State.idle
io.respB.ready := false.B
io.initiate.ready := state === State.idle && !wbValid
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
val operandB = bBuf(substepReg)
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val dpuInValid = WireDefault(false.B)
val dpu = Module(new TensorDotProductUnit(
dim = 8,
half = true
))
private def halfWord(x: UInt, idx: Int): UInt = {
x((idx + 1) * 16 - 1, idx * 16)
}
val elemM = elemReg(1, 0)
val elemN = elemReg(2)
dpu.io.in.valid := dpuInValid
for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq(
0.U -> halfWord(operandA, k),
1.U -> halfWord(operandA, 8 + k),
2.U -> halfWord(operandA, 16 + k),
3.U -> halfWord(operandA, 24 + k)
))
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
}
dpu.io.in.bits.c := cWords(elemReg)
dpu.io.stall := false.B
val dpuValid = dpu.io.out.valid
when(io.writeback.fire) {
wbValid := false.B
}
@@ -120,118 +209,215 @@ class TensorCoreBlackwell(
rdReg := io.initiate.bits.rd
addrAReg := io.initiate.bits.addressA
addrBReg := io.initiate.bits.addressB
haveA := false.B
haveB := false.B
addrCReg := io.initiate.bits.addressC
setReg := 0.U
aIndexReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
mGroupReg := 0.U
substepReg := 0.U
elemReg := 0.U
switch(io.initiate.bits.op) {
is(Ops.bwgmma) { state := State.bwReq }
is(Ops.bwgmma) { state := State.bwLoadAReq }
is(Ops.tcgen05Cp) { state := State.cpRead }
is(Ops.tcgen05Ld) { state := State.ldReq }
is(Ops.tcgen05St) { state := State.stReq }
is(Ops.bwgmmaWait) { state := State.idle }
is(Ops.tcgen05CpWait) { state := State.idle }
is(Ops.tcgen05Cb) { state := State.cbRead }
}
}
when(state === State.bwReq) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
when(state === State.bwLoadAReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := tmemABase + aFragIndex
state := State.bwLoadAResp
}
when(state === State.bwLoadAResp) {
aBuf(aIndexReg) := io.tmemC.rdata
when(aIndexReg === (numAFragsPerSet - 1).U) {
bGroupReg := 0.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.otherwise {
aIndexReg := aIndexReg + 1.U
state := State.bwLoadAReq
}
}
when(state === State.bwLoadBReq) {
reqB.valid := true.B
reqB.bits.rw := false.B
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqB.bits.address := addrBReg
reqB.bits.address := bReqAddress
reqB.bits.source := sourceCounter
io.reqC.valid := true.B
when(reqA.fire && reqB.fire) {
when(reqB.fire) {
bumpSource()
state := State.bwResp
state := State.bwLoadBResp
}
}
when(state === State.bwResp) {
io.respA.ready := true.B
when(state === State.bwLoadBResp) {
io.respB.ready := true.B
when(io.respA.fire) {
aDataReg := io.respA.bits.data
haveA := true.B
}
when(io.respB.fire) {
bDataReg := io.respB.bits.data
haveB := true.B
}
when(haveA && haveB) {
val cWords = io.respC.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val aWords = aDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val bWords = bDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
for (i <- 0 until numLanes) {
wbData(i) := aWords(i) + bWords(i) + cWords(i)
bBuf(bIndexReg) := io.respB.bits.data
when(bIndexReg === (numBFragsPerGroup - 1).U) {
mGroupReg := 0.U
substepReg := 0.U
state := State.bwReadCReq
}.otherwise {
bIndexReg := bIndexReg + 1.U
state := State.bwLoadBReq
}
wbValid := true.B
state := State.idle
}
}
when(state === State.bwReadCReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := tmemCBase + cFragIndex
state := State.bwReadCResp
}
when(state === State.bwReadCResp) {
cDataReg := io.tmemC.rdata
elemReg := 0.U
state := State.bwCompute
}
when(state === State.bwCompute) {
dpuInValid := true.B
state := State.bwDpuResp
}
when(state === State.bwDpuResp) {
when(dpuValid) {
mmaDataReg(elemReg) := dpu.io.out.bits.data
when(elemReg === (numLanes - 1).U) {
state := State.bwWriteCReq
}.otherwise {
elemReg := elemReg + 1.U
state := State.bwCompute
}
}
}
when(state === State.bwWriteCReq) {
io.tmemC.wen := true.B
io.tmemC.waddr := tmemCBase + cFragIndex
io.tmemC.wdata := mmaDataReg.asUInt
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
when(substepReg === 0.U) {
substepReg := 1.U
state := State.bwReadCReq
}.elsewhen(mGroupReg =/= (numMGroups - 1).U) {
substepReg := 0.U
mGroupReg := mGroupReg + 1.U
state := State.bwReadCReq
}.elsewhen(bGroupReg =/= (numBGroups - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := bGroupReg + 1.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.elsewhen(setReg =/= (numSets - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
setReg := setReg + 1.U
aIndexReg := 0.U
state := State.bwLoadAReq
}.otherwise {
waitCounter := 7.U
state := State.bwWriteCWait
}
}
when(state === State.bwWriteCWait) {
when(waitCounter === 0.U) {
state := State.bwDone
}.otherwise {
waitCounter := waitCounter - 1.U
}
}
when(state === State.bwDone) {
wbData := mmaDataReg
wbValid := true.B
state := State.idle
}
when(state === State.cpRead) {
reqB.valid := true.B
reqB.bits.rw := false.B
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqB.bits.address := addrBReg
reqB.bits.source := sourceCounter
when(reqB.fire) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter
when(reqA.fire) {
bumpSource()
state := State.cpWrite
}
}
when(state === State.cpWrite) {
io.respB.ready := reqA.ready
reqA.valid := io.respB.valid
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.respB.bits.data
when(reqA.fire) {
bumpSource()
io.respA.ready := true.B
when(io.respA.fire) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respA.bits.data
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle
}
}
when(state === State.ldReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt
state := State.waitWb
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B
state := State.idle
}
when(state === State.stReq) {
io.reqC.valid := true.B
state := State.stWrite
}
when(state === State.stWrite) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respC
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle
}
when(state === State.cbRead) {
io.tmemC.ren := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt
state := State.cbWrite
}
when(state === State.cbWrite) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.tmemC.rdata
when(reqA.fire) {
bumpSource()
state := State.waitWb
}
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
io.respA.ready := !wbValid
when(state === State.waitWb && opReg === Ops.tcgen05Cb) {
io.respA.ready := true.B
when(io.respA.fire) {
wbData := io.respA.bits.data.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B
state := State.idle
}
}
when(state === State.stReq) {
io.reqC.valid := true.B
reqA.valid := true.B
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.respC
when(reqA.fire) {
bumpSource()
state := State.idle
}
}

View File

@@ -201,8 +201,10 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// pipeline and connect outputs to the next stage
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
assert(inputs.valid === inC.valid,
"adder inputs valid and C pipe valid went out-of-sync")
when (inputs.valid =/= inC.valid) {
printf("WARN: DotProductPipe input/C valid mismatch: inputs=%d c=%d\n",
inputs.valid, inC.valid)
}
(outputs, outC)
}

View File

@@ -51,6 +51,7 @@ class WithRadianceCores(
tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean,
startupAddress: BigInt,
useVxCache: Boolean
) extends Config((site, _, up) => {
case TilesLocated(`location`) => {
@@ -61,7 +62,8 @@ class WithRadianceCores(
core = VortexCoreParams(
tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell
tensorCoreBlackwell = tensorCoreBlackwell,
startupAddress = startupAddress
),
btb = None,
useVxCache = useVxCache,
@@ -99,6 +101,7 @@ class WithRadianceCores(
def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false,
startupAddress: BigInt = BigInt("10100", 16),
useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -107,7 +110,7 @@ class WithRadianceCores(
case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId)
}
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, useVxCache)
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache)
}
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {

View File

@@ -102,6 +102,7 @@ case class VortexCoreParams(
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
@@ -292,50 +293,30 @@ class RadianceTile private (
masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
),
requestFifo = true
))
)))
}
val tmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreBlackwell) 2 else 0) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2(
name = s"rad_tmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
putPartial = TransferSizes(1, tcSmemSize),
),
requestFifo = true
))
)))
}
val tmemNode = if (radianceParams.core.tensorCoreBlackwell) {
Some(LazyModule(new TLRAM(
address = AddressSet(0x0, 0x3fff),
beatBytes = tcSmemSize
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
// tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq(
TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}",
sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
),
requestFifo = true
)))
} else {
None
}
val tmemXbar = if (radianceParams.core.tensorCoreBlackwell) {
Some(LazyModule(new TLXbar))
} else {
None
}
(tmemNode, tmemXbar) match {
case (Some(tmem), Some(xbar)) =>
tmem.node :=* xbar.node
tmemNodes.foreach { node => xbar.node :=* node }
case _ =>
}
))) else None
// combine outgoing per-lane dmemNode into 1 idenity node
//
@@ -425,6 +406,7 @@ class RadianceTile private (
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
}
/* below are copied from rocket */
@@ -828,12 +810,12 @@ class RadianceTileModuleImp(outer: RadianceTile)
adapter.io.outResp <> client._1.d
adapter
}
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 2)
require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 2)
core.io.tc_a_ready := Cat(0.U(1.W), adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(0.U(1.W), adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(0.U((32 * 8).W), adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(0.U(outer.tensorTagWidth.W), adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 3)
require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 3)
} else {
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
@@ -844,66 +826,82 @@ class RadianceTileModuleImp(outer: RadianceTile)
def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tmemNodes.nonEmpty)
require(outer.tcSmemNodes.nonEmpty)
val bundles = Seq(
(outer.tmemNodes.head, new {
val addr = core.io.tc_a_bits_address(31, 0)
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
val write = core.io.tc_a_bits_write(0)
val mask = core.io.tc_a_bits_mask(31, 0)
val data = core.io.tc_a_bits_data(255, 0)
val aValid = core.io.tc_a_valid(0)
val dReady = core.io.tc_d_ready(0)
}),
(outer.tcSmemNodes.head, new {
val addr = core.io.tc_a_bits_address(63, 32)
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
val write = core.io.tc_a_bits_write(1)
val mask = core.io.tc_a_bits_mask(63, 32)
val data = core.io.tc_a_bits_data(511, 256)
val aValid = core.io.tc_a_valid(1)
val dReady = core.io.tc_d_ready(1)
})
)
// TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
tmem.io.ren0 := core.io.tc_tmem_C_ren
tmem.io.raddr0 := core.io.tc_tmem_C_raddr
core.io.tc_tmem_C_rdata := tmem.io.rdata0
tmem.io.ren1 := false.B
tmem.io.raddr1 := 0.U
tmem.io.wen := core.io.tc_tmem_C_wen
tmem.io.waddr := core.io.tc_tmem_C_waddr
tmem.io.wdata := core.io.tc_tmem_C_wdata
tmem.io.mask := core.io.tc_tmem_C_mask
val adapters = bundles.map { case (node, bundle) =>
val client = node.out.head
val adapter = Module(
new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
client
)
)
require(adapter.io.inReq.bits.source.widthOption.get == bundle.tag.widthOption.get)
require(adapter.io.inReq.bits.address.widthOption.get == bundle.addr.widthOption.get)
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := bundle.aValid
adapter.io.inReq.bits.address := bundle.addr
adapter.io.inReq.bits.source := bundle.tag
adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := bundle.mask
adapter.io.inReq.bits.data := bundle.data
adapter.io.inResp.ready := bundle.dReady
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
adapter
// smem_B (port 2): Global Memory via TileLink
val smemBBundle = new {
val addr = core.io.tc_a_bits_address(95, 64)
val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8)
val write = core.io.tc_a_bits_write(2)
val mask = core.io.tc_a_bits_mask(95, 64)
val data = core.io.tc_a_bits_data(767, 512)
val aValid = core.io.tc_a_valid(2)
val dReady = core.io.tc_d_ready(2)
}
val client = outer.tcSmemNodes.head.out.head
val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
client
))
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := smemBBundle.aValid
adapter.io.inReq.bits.address := smemBBundle.addr
adapter.io.inReq.bits.source := smemBBundle.tag
adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := smemBBundle.mask
adapter.io.inReq.bits.data := smemBBundle.data
adapter.io.inResp.ready := smemBBundle.dReady
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
// port 0: global memory (cp/cb)
val gmemClient = outer.tcGmemNode.get.out.head
val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
gmemClient
))
gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0)
gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0)
gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
gmemAdapter.io.inReq.bits.size := 5.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0)
gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0)
gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d
core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready)
core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source)
} else {
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare
core.io.tc_d_bits_tag := DontCare
core.io.tc_tmem_C_rdata := DontCare
}
}
@@ -993,6 +991,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
dontTouch(tensor.io)
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
val tensor = Module(new radiance.core.TensorCoreBlackwell(
@@ -1007,6 +1006,8 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
tensor.io.tmemC.rdata := DontCare
dontTouch(tensor.io)
} else {
if (outer.radianceParams.core.tensorCoreFP16) {
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))

View File

@@ -90,17 +90,28 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val tc_a_valid = Output(UInt(2.W))
val tc_a_bits_write = Output(UInt(2.W))
val tc_a_bits_address = Output(UInt((2 * 32).W))
val tc_a_bits_tag = Output(UInt((2 * 4).W))
val tc_a_bits_mask = Output(UInt((2 * 32).W))
val tc_a_bits_data = Output(UInt((2 * 32 * 8).W))
val tc_a_ready = Input(UInt(2.W))
val tc_d_valid = Input(UInt(2.W))
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
val tc_d_bits_tag = Input(UInt((2 * 4).W))
val tc_d_ready = Output(UInt(2.W))
val tcPortCount = 3
val tc_a_valid = Output(UInt(tcPortCount.W))
val tc_a_bits_write = Output(UInt(tcPortCount.W))
val tc_a_bits_address = Output(UInt((tcPortCount * 32).W))
val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W))
val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W))
val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W))
val tc_a_ready = Input(UInt(tcPortCount.W))
val tc_d_valid = Input(UInt(tcPortCount.W))
val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W))
val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W))
val tc_d_ready = Output(UInt(tcPortCount.W))
// Direct SRAM port for TMEM C (bypasses TileLink)
val numLanes = tile.numLsuLanes
val tc_tmem_C_wen = Output(Bool())
val tc_tmem_C_ren = Output(Bool())
val tc_tmem_C_waddr = Output(UInt(9.W))
val tc_tmem_C_raddr = Output(UInt(9.W))
val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numLanes * 4).W))
val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W))
// FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -135,8 +146,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
Map(
"CORE_ID" -> tile.radianceParams.coreId,
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
// TODO: can we get this as a parameter?
"BOOTROM_HANG100" -> 0x10100,
"STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
"NUM_THREADS" -> tile.numLsuLanes
)
)
@@ -449,7 +459,9 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
if (!tile.radianceParams.core.tensorCoreBlackwell) {
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
}
if (tile.radianceParams.useVxCache) {
addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv")

View File

@@ -0,0 +1,338 @@
package radiance.core
import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell Extended Tests"
private val numWarps = 4
private val numLanes = 8
private val fragBytes = 32
private def idleIO(c: TensorCoreBlackwell): Unit = {
c.io.initiate.valid.poke(false.B)
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.io.respA.bits.source.poke(0.U)
c.io.respB.bits.source.poke(0.U)
c.io.respA.bits.data.poke(0.U)
c.io.respB.bits.data.poke(0.U)
c.io.reqA.ready.poke(false.B)
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U)
}
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U)
}
if (c.io.tmemC.wen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue
}
}
it should "verify bwgmma address offset with non-zero base addresses" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
idleIO(c)
val tmem = makeTmem()
// Use non-zero base addresses to verify offset calculation
val aBase = BigInt(0x200) // row 16, A tile rows 16~47
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
val bBase = BigInt(0x800)
val fp16One = BigInt(0x3c00)
val fp32Zero = BigInt(0)
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
val fp32SixtyFour = BigInt(0x42800000L)
// Populate TMEM A at offset aBase (all 32 frags)
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag
}
// SMEM B with fp16 2.0
val fp16Two = BigInt(0x4000)
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
c.io.reqB.ready.poke(true.B)
c.io.writeback.ready.poke(true.B)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(aBase.U)
c.io.initiate.bits.addressB.poke(bBase.U)
c.io.initiate.bits.addressC.poke(cBase.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
var pendingB = Option.empty[(BigInt, BigInt)]
var sawWriteback = false
for (_ <- 0 until 50000 if !sawWriteback) {
stepTmem(c, tmem)
pendingB.foreach { case (src, data) =>
c.io.respB.valid.poke(true.B)
c.io.respB.bits.source.poke(src.U)
c.io.respB.bits.data.poke(data.U)
}
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
if (c.io.writeback.valid.peek().litToBoolean) {
sawWriteback = true
} else {
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
val addr = c.io.reqB.bits.address.peek().litValue
val src = c.io.reqB.bits.source.peek().litValue
Some((src, bMem(addr)))
} else None
c.clock.step()
pendingB = nextB
}
}
assert(sawWriteback, "BWGMMA did not complete")
val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32)
for (i <- 0 until 8) {
val row = cBase / fragBytes + i
assert(tmem(row) == expectedC,
s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}")
}
for (i <- 0 until 8) {
assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged")
}
}
}
it should "cp then ld round-trip: data written via cp is readable via ld" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val tmemAddr = BigInt(0x100)
val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32)
// Issue cp: global mem -> tmem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h10000000".U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cpRead: reqA issued
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(false.B)
c.clock.step()
// cpWrite: respA fires, tmemC written
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U)
stepTmem(c, tmem)
c.clock.step()
c.io.respA.valid.poke(false.B)
// Now issue ld from same tmem address
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.rd.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// ldReq: ren asserted, serve from tmem model
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// writeback should carry cpData
c.io.writeback.valid.expect(true.B)
for (i <- 0 until numLanes) {
c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U)
}
}
}
it should "st then cb round-trip: data written via st is readable via cb" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val tmemAddr = BigInt(0x140)
val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32)
// Issue st: respC -> tmem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.rd.poke(4.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.respC.poke(stData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// stReq: reqC valid
c.io.reqC.valid.expect(true.B)
c.clock.step()
// stWrite: tmemC written
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.wdata.expect(stData.U)
stepTmem(c, tmem)
c.clock.step()
// Issue cb: tmem -> global mem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(6.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h20000000".U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cbRead: ren asserted
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// cbWrite: reqA write with stData
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect("h20000000".U)
c.io.reqA.bits.data.expect(stData.U)
}
}
it should "wait ops are no-ops and do not stall pipeline" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
// bwgmmaWait: should accept immediately and stay idle
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(1.U) // bwgmmaWait
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.writeback.valid.expect(false.B)
c.io.reqA.valid.expect(false.B)
c.io.reqB.valid.expect(false.B)
// tcgen05CpWait: same
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(3.U) // tcgen05CpWait
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.writeback.valid.expect(false.B)
c.io.reqA.valid.expect(false.B)
}
}
it should "not accept a second tensor op until the first one completes" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val firstAddr = BigInt(0x180)
val secondAddr = BigInt(0x1a0)
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xCAFE0000L + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.addressA.poke(firstAddr.U)
c.io.respC.poke(storeData.U)
c.io.initiate.ready.expect(true.B)
c.clock.step()
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.addressA.poke(secondAddr.U)
c.io.initiate.bits.rd.poke(2.U)
c.io.initiate.ready.expect(false.B)
c.clock.step()
c.io.initiate.ready.expect(false.B)
c.io.tmemC.wen.expect(true.B)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
// warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0
val warp0TmemA = BigInt(0x000)
val warp0Data = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32)
// warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800
val warp3TmemA = BigInt(0x1800)
val warp3Data = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32)
// Write warp 0 data via st
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.addressA.poke(warp0TmemA.U)
c.io.respC.poke(warp0Data.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()
// Write warp 3 data via st
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(3.U)
c.io.initiate.bits.addressA.poke(warp3TmemA.U)
c.io.respC.poke(warp3Data.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()
// Verify no aliasing: warp 0 row != warp 3 row
assert(warp0TmemA / fragBytes != warp3TmemA / fragBytes)
assert(tmem(warp0TmemA / fragBytes) == warp0Data)
assert(tmem(warp3TmemA / fragBytes) == warp3Data)
}
}
}

View File

@@ -2,11 +2,17 @@ package radiance.core
import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell"
private val numWarps = 4
private val numLanes = 8
private def idleIO(c: TensorCoreBlackwell): Unit = {
c.io.initiate.valid.poke(false.B)
c.io.respA.valid.poke(false.B)
@@ -15,111 +21,261 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.respB.bits.source.poke(0.U)
c.io.respA.bits.data.poke(0.U)
c.io.respB.bits.data.poke(0.U)
c.io.reqA.ready.poke(false.B)
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U)
}
it should "run a minimal BWGMMA path" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U)
c.io.initiate.bits.wid.poke(1.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(0x40.U)
c.io.initiate.bits.addressB.poke(0x80.U)
c.io.reqA.ready.poke(true.B)
c.io.reqB.ready.poke(true.B)
c.io.respC.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqA.valid.expect(true.B)
c.io.reqB.valid.expect(true.B)
c.clock.step()
c.io.respA.valid.poke(true.B)
c.io.respB.valid.poke(true.B)
c.io.respA.bits.data.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
c.io.respB.bits.data.poke("h000000100000000f0000000e0000000d0000000c0000000b0000000a00000009".U)
c.clock.step()
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.clock.step()
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(3.U)
c.io.writeback.bits.wid.expect(1.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
val mask = (BigInt(1) << width) - 1
words.zipWithIndex.foldLeft(BigInt(0)) {
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
}
}
it should "copy from SMEM to TMEM on TCGEN05_CP" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
// Simple TMEM model: address → 256-bit row
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
// Drive tmemC read response from model, handle write
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U)
}
if (c.io.tmemC.wen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue
}
}
it should "tcgen05_ld: read from TMEM to writeback" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val fragBytes = 32
val tmemAddr = BigInt(0x40) // row 2 (0x40 / 32 = 2)
val testData = packWords(Seq.tabulate(numLanes)(i => BigInt(0x1000 + i)), 32)
tmem(tmemAddr / fragBytes) = testData
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U)
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(0x100.U)
c.io.initiate.bits.addressB.poke(0x200.U)
c.io.reqB.ready.poke(true.B)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.io.tmemC.rdata.poke(testData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// ldReq: tmemC.ren asserted; rdata must be valid before next step
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.rdata.poke(testData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqB.valid.expect(true.B)
c.io.respB.valid.poke(true.B)
c.io.respB.bits.data.poke("hdeadbeef".U)
c.io.reqA.ready.poke(true.B)
// waitWb: wbValid gets set this cycle, step to let it register
c.io.tmemC.rdata.poke(testData.U)
c.clock.step()
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect(0x100.U)
// idle: writeback.valid now true
c.io.writeback.valid.expect(true.B)
c.io.initiate.ready.expect(false.B)
c.io.writeback.bits.rd.expect(3.U)
c.io.writeback.bits.wid.expect(0.U)
for (i <- 0 until numLanes) {
c.io.writeback.bits.data(i).expect((0x1000 + i).U)
}
}
}
it should "load and store fragments through TMEM" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
it should "tcgen05_st: write from respC to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0x60)
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xAB00 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.wid.poke(2.U)
c.io.initiate.bits.rd.poke(5.U)
c.io.initiate.bits.addressA.poke(0x300.U)
c.io.initiate.bits.addressB.poke(0.U)
c.io.reqA.ready.poke(true.B)
c.io.initiate.bits.op.poke(5.U) // tcgen05St
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(7.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.respC.poke(storeData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.clock.step()
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke("h1234".U)
c.clock.step()
c.io.respA.valid.poke(false.B)
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(5.U)
c.io.writeback.ready.poke(true.B)
c.io.initiate.ready.expect(false.B)
// stReq: reqC.valid asserted
c.io.reqC.valid.expect(true.B)
c.io.reqC.bits.expect(7.U)
c.clock.step()
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(2.U)
c.io.initiate.bits.rd.poke(6.U)
c.io.initiate.bits.addressA.poke(0x340.U)
c.io.initiate.bits.addressB.poke(0.U)
c.io.reqA.ready.poke(true.B)
c.io.respC.poke("habcd".U)
// stWrite: tmemC.wen asserted with storeData
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(storeData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "tcgen05_cp: read from global mem (reqA) and write to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0x80)
val gmemAddr = "ha0001000"
val cpData = packWords(Seq.fill(numLanes)(BigInt(0xdeadbeefL)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U) // tcgen05Cp
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// cpRead: reqA issued to global mem
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(false.B)
c.io.reqA.bits.address.expect(gmemAddr.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
// cpWrite: respA fires → tmemC.wen in same cycle
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
// tmemC write happens combinatorially when respA fires
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "tcgen05_cb: read from TMEM and write to global mem (reqA)" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 32
val tmemAddr = BigInt(0xa0)
val gmemAddr = "ha2000000"
val cbData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xC000 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(6.U) // tcgen05Cb
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(cbData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// cbRead: tmemC.ren asserted
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
// cbWrite: reqA write to global mem
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect(0x340.U)
c.io.reqA.bits.address.expect(gmemAddr.U)
c.io.reqA.bits.data.expect(cbData.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(0.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
}
it should "run bwgmma: TMEM_C = TMEM_A * SMEM_B + TMEM_C" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
idleIO(c)
val fragBytes = 32
val aBase = BigInt(0x100)
val bBase = BigInt(0x800)
val cBase = BigInt(0x1000)
// A: all fp16 1.0 (0x3c00), 16 halves per frag
val fp16One = BigInt(0x3c00)
val fp16Two = BigInt(0x4000)
val fp32One = BigInt(0x3f800000)
val fp32SixtyFive = BigInt(0x42820000)
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
// Populate TMEM with A and C tiles
val tmem = makeTmem()
for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag
}
val bMem = mutable.Map[BigInt, BigInt]()
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
c.io.reqB.ready.poke(true.B)
c.io.writeback.ready.poke(true.B)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U) // bwgmma
c.io.initiate.bits.wid.poke(1.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(aBase.U)
c.io.initiate.bits.addressB.poke(bBase.U)
c.io.initiate.bits.addressC.poke(cBase.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
var pendingB = Option.empty[(BigInt, BigInt)]
var sawWriteback = false
for (_ <- 0 until 20000 if !sawWriteback) {
// Drive TMEM reads/writes
stepTmem(c, tmem)
// Drive SMEM B responses
pendingB.foreach { case (src, data) =>
c.io.respB.valid.poke(true.B)
c.io.respB.bits.source.poke(src.U)
c.io.respB.bits.data.poke(data.U)
}
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
if (c.io.writeback.valid.peek().litToBoolean) {
sawWriteback = true
} else {
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
val addr = c.io.reqB.bits.address.peek().litValue
val src = c.io.reqB.bits.source.peek().litValue
Some((src, bMem(addr)))
} else None
c.clock.step()
pendingB = nextB
}
}
assert(sawWriteback, "BWGMMA did not complete")
c.io.writeback.bits.wid.expect(1.U)
// Verify all 32 C frags in TMEM
for (i <- 0 until 32) {
val row = cBase / fragBytes + i
assert(tmem(row) == expectedCFrag,
s"C frag $i mismatch: got 0x${tmem(row).toString(16)}, expected 0x${expectedCFrag.toString(16)}")
}
}
}
}