Compare commits

...

4 Commits

Author SHA1 Message Date
Zhongdi LUO
1e78574113 Fix Blackwell SMEM fragment alignment 2026-05-27 08:43:36 +00:00
Zhongdi LUO
c6c30ec0dc Add 4-lane Blackwell tensor core support 2026-05-27 05:54:39 +00:00
126523c5d2 Track vortex masked barrier fix 2026-05-27 09:08:03 +08:00
87a4bbc757 Integrate WU architecture in Radiance 2026-05-25 19:25:59 +08:00
9 changed files with 476 additions and 239 deletions

2
.gitmodules vendored
View File

@@ -1,3 +1,3 @@
[submodule "src/main/resources/vsrc/vortex"]
path = src/main/resources/vsrc/vortex
url = https://github.com/hansungk/vortex.git
url = https://git.nudt.space/wu-arch/vortex.git

View File

@@ -24,7 +24,7 @@ ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
endif
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL
endif
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4

View File

@@ -14,7 +14,8 @@ class TensorCoreBlackwell(
val numFPRegs: Int = 32
) extends Module {
require(half, "Blackwell MMA currently supports FP16 inputs only")
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes")
require(numLanes == 4 || numLanes == 8,
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
val numWarpBits = log2Ceil(numWarps)
val sourceWidth = log2Ceil(numSourceIds)
@@ -26,11 +27,16 @@ class TensorCoreBlackwell(
val fragOffsetBits = log2Ceil(memWidth / 8)
val numSets = 4
val numAFragsPerSet = 8
val numBGroups = 4
val numBFragsPerGroup = 2
val numMGroups = 4
val numCFrags = 32
val numSubsteps = 2
val mElemsPerFrag = if (numLanes == 4) 2 else 4
val numMGroups = 16 / mElemsPerFrag
val numAFragsPerMGroup = 2
val numAFragsPerSet = numMGroups * numAFragsPerMGroup
val numBFragsPerSubstep = if (numLanes == 4) 2 else 1
val numBFragsPerGroup = numSubsteps * numBFragsPerSubstep
val numBFragsPerSet = numBGroups * numBFragsPerGroup
val numCFrags = numBGroups * numMGroups * numSubsteps
object Ops {
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
@@ -57,13 +63,21 @@ class TensorCoreBlackwell(
// Direct SRAM port for TMEM (no TileLink overhead)
class TmemSramPort extends Bundle {
val wen = Output(Bool())
val ren = Output(Bool())
val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val wdata = Output(UInt(memWidth.W))
val mask = Output(UInt(maskWidth.W))
val rdata = Input(UInt(memWidth.W))
val aRen = Output(Bool())
val aRready = Input(Bool())
val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val aRdata = Input(UInt(memWidth.W))
val cRen = Output(Bool())
val cRready = Input(Bool())
val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cRdata = Input(UInt(memWidth.W))
val cWen = Output(Bool())
val cWready = Input(Bool())
val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cWdata = Output(UInt(memWidth.W))
val cMask = Output(UInt(maskWidth.W))
}
val io = IO(new Bundle {
@@ -94,7 +108,7 @@ class TensorCoreBlackwell(
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
cbRead, cbWrite = Value
cbRead, cbCapture, cbWrite = Value
}
val state = RegInit(State.idle)
@@ -128,10 +142,11 @@ class TensorCoreBlackwell(
base + (fragIndex << fragOffsetBits).asUInt
}
val aFragIndex = (setReg << 3) + aIndexReg
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg
val stepIndex = Cat(bGroupReg, mGroupReg)
val cFragIndex = (stepIndex << 1) + substepReg
val aFragIndex = (setReg * numAFragsPerSet.U) + aIndexReg
val bFragIndex =
(setReg * numBFragsPerSet.U) + (bGroupReg * numBFragsPerGroup.U) + bIndexReg
val cFragIndex =
(((bGroupReg * numMGroups.U) + mGroupReg) * numSubsteps.U) + substepReg
val aReqAddress = byteAddress(addrAReg, aFragIndex)
val bReqAddress = byteAddress(addrBReg, bFragIndex)
val cReqAddress = byteAddress(addrCReg, cFragIndex)
@@ -147,12 +162,14 @@ class TensorCoreBlackwell(
io.reqA <> reqA
io.reqB <> reqB
io.tmemC.wen := false.B
io.tmemC.ren := false.B
io.tmemC.waddr := 0.U
io.tmemC.raddr := 0.U
io.tmemC.wdata := 0.U
io.tmemC.mask := 0.U
io.tmemC.aRen := false.B
io.tmemC.aRaddr := 0.U
io.tmemC.cRen := false.B
io.tmemC.cRaddr := 0.U
io.tmemC.cWen := false.B
io.tmemC.cWaddr := 0.U
io.tmemC.cWdata := 0.U
io.tmemC.cMask := 0.U
val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
@@ -171,7 +188,12 @@ class TensorCoreBlackwell(
io.initiate.ready := state === State.idle && !wbValid
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
val operandB = bBuf(substepReg)
val operandB =
if (numLanes == 4) {
Cat(bBuf((substepReg << 1) + 1.U), bBuf(substepReg << 1))
} else {
bBuf(substepReg)
}
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val dpuInValid = WireDefault(false.B)
val dpu = Module(new TensorDotProductUnit(
@@ -183,16 +205,22 @@ class TensorCoreBlackwell(
x((idx + 1) * 16 - 1, idx * 16)
}
val elemM = elemReg(1, 0)
val elemN = elemReg(2)
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
dpu.io.in.valid := dpuInValid
for (k <- 0 until 8) {
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq(
0.U -> halfWord(operandA, k),
1.U -> halfWord(operandA, 8 + k),
2.U -> halfWord(operandA, 16 + k),
3.U -> halfWord(operandA, 24 + k)
))
dpu.io.in.bits.a(k) := (
if (numLanes == 4) {
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
} else {
MuxLookup(elemM, halfWord(operandA, k))(Seq(
0.U -> halfWord(operandA, k),
1.U -> halfWord(operandA, 8 + k),
2.U -> halfWord(operandA, 16 + k),
3.U -> halfWord(operandA, 24 + k)
))
}
)
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
}
dpu.io.in.bits.c := cWords(elemReg)
@@ -229,13 +257,15 @@ class TensorCoreBlackwell(
}
when(state === State.bwLoadAReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := tmemABase + aFragIndex
state := State.bwLoadAResp
io.tmemC.aRen := true.B
io.tmemC.aRaddr := tmemABase + aFragIndex
when(io.tmemC.aRready) {
state := State.bwLoadAResp
}
}
when(state === State.bwLoadAResp) {
aBuf(aIndexReg) := io.tmemC.rdata
aBuf(aIndexReg) := io.tmemC.aRdata
when(aIndexReg === (numAFragsPerSet - 1).U) {
bGroupReg := 0.U
bIndexReg := 0.U
@@ -274,13 +304,15 @@ class TensorCoreBlackwell(
}
when(state === State.bwReadCReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := tmemCBase + cFragIndex
state := State.bwReadCResp
io.tmemC.cRen := true.B
io.tmemC.cRaddr := tmemCBase + cFragIndex
when(io.tmemC.cRready) {
state := State.bwReadCResp
}
}
when(state === State.bwReadCResp) {
cDataReg := io.tmemC.rdata
cDataReg := io.tmemC.cRdata
elemReg := 0.U
state := State.bwCompute
}
@@ -303,34 +335,36 @@ class TensorCoreBlackwell(
}
when(state === State.bwWriteCReq) {
io.tmemC.wen := true.B
io.tmemC.waddr := tmemCBase + cFragIndex
io.tmemC.wdata := mmaDataReg.asUInt
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
when(substepReg === 0.U) {
substepReg := 1.U
state := State.bwReadCReq
}.elsewhen(mGroupReg =/= (numMGroups - 1).U) {
substepReg := 0.U
mGroupReg := mGroupReg + 1.U
state := State.bwReadCReq
}.elsewhen(bGroupReg =/= (numBGroups - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := bGroupReg + 1.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.elsewhen(setReg =/= (numSets - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
setReg := setReg + 1.U
aIndexReg := 0.U
state := State.bwLoadAReq
}.otherwise {
waitCounter := 7.U
state := State.bwWriteCWait
io.tmemC.cWen := true.B
io.tmemC.cWaddr := tmemCBase + cFragIndex
io.tmemC.cWdata := mmaDataReg.asUInt
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
when(substepReg === 0.U) {
substepReg := 1.U
state := State.bwReadCReq
}.elsewhen(mGroupReg =/= (numMGroups - 1).U) {
substepReg := 0.U
mGroupReg := mGroupReg + 1.U
state := State.bwReadCReq
}.elsewhen(bGroupReg =/= (numBGroups - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := bGroupReg + 1.U
bIndexReg := 0.U
state := State.bwLoadBReq
}.elsewhen(setReg =/= (numSets - 1).U) {
substepReg := 0.U
mGroupReg := 0.U
bGroupReg := 0.U
bIndexReg := 0.U
setReg := setReg + 1.U
aIndexReg := 0.U
state := State.bwLoadAReq
}.otherwise {
waitCounter := 7.U
state := State.bwWriteCWait
}
}
}
@@ -361,24 +395,26 @@ class TensorCoreBlackwell(
}
when(state === State.cpWrite) {
io.respA.ready := true.B
io.respA.ready := io.tmemC.cWready
io.tmemC.cWen := io.respA.valid
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respA.bits.data
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.respA.fire) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respA.bits.data
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle
}
}
when(state === State.ldReq) {
io.tmemC.ren := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt
state := State.waitWb
io.tmemC.cRen := true.B
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.waitWb
}
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B
state := State.idle
}
@@ -389,16 +425,25 @@ class TensorCoreBlackwell(
}
when(state === State.stWrite) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respC
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle
io.tmemC.cWen := true.B
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respC
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
state := State.idle
}
}
when(state === State.cbRead) {
io.tmemC.ren := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cRen := true.B
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.cbCapture
}
}
when(state === State.cbCapture) {
cDataReg := io.tmemC.cRdata
state := State.cbWrite
}
@@ -408,7 +453,7 @@ class TensorCoreBlackwell(
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.tmemC.rdata
reqA.bits.data := cDataReg
when(reqA.fire) {
bumpSource()
state := State.waitWb

View File

@@ -51,6 +51,7 @@ class WithRadianceCores(
tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean,
numTensorWarps: Int,
startupAddress: BigInt,
useVxCache: Boolean
) extends Config((site, _, up) => {
@@ -63,6 +64,7 @@ class WithRadianceCores(
tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell,
numTensorWarps = numTensorWarps,
startupAddress = startupAddress
),
btb = None,
@@ -101,6 +103,7 @@ class WithRadianceCores(
def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false,
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16),
useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams(
@@ -110,7 +113,7 @@ class WithRadianceCores(
case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId)
}
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache)
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache)
}
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {

View File

@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
import org.chipsalliance.cde.config._
import radiance.core._
import radiance.memory._
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs}
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
/** For determining radiance core id. This may be different from
* RadianceTileParams.tileId, when a cluster contains non-core tiles */
@@ -102,6 +102,7 @@ case class VortexCoreParams(
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction
@@ -210,7 +211,9 @@ class RadianceTile private (
case Some(false) => 1
case None => 1
}
val imemTagWidth = UUID_WIDTH + NW_WIDTH
// Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid.
val imemDomainWidth = 1
val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH
require(numWarps >= numLsuLanes,
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
@@ -285,18 +288,37 @@ class RadianceTile private (
)
}
val tcSmemSize = 32
val tcSmemSize = numLsuLanes * 4
val tcSmemLineSize = p(RadianceSharedMemKey)
.map(k => k.numWords * k.wordSize)
.getOrElse(tcSmemSize)
val tcSmemClientMaxSize =
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
val numTensorWarps = radianceParams.core.numTensorWarps
val numScalarWarps = numWarps - numTensorWarps
require(numTensorWarps > 0 && numTensorWarps < numWarps,
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
if (radianceParams.core.tensorCoreBlackwell) {
require(numCoreLanes == numLsuLanes,
s"Wu Blackwell binding requires matching core lanes (${numCoreLanes}) and memory lanes (${numLsuLanes})")
require(numLsuLanes == 4 || numLsuLanes == 8,
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
}
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
probe = TransferSizes(1, tcSmemClientMaxSize),
get = TransferSizes(1, tcSmemClientMaxSize),
putFull = TransferSizes(1, tcSmemClientMaxSize),
),
requestFifo = true
))
@@ -304,19 +326,21 @@ class RadianceTile private (
}
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
// tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq(
TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}",
sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
),
requestFifo = true
)))
))) else None
// tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) {
Seq.tabulate(numTensorCores) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
),
requestFifo = true
)))))
}
} else Seq.empty
// combine outgoing per-lane dmemNode into 1 idenity node
//
@@ -406,7 +430,7 @@ class RadianceTile private (
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
}
/* below are copied from rocket */
@@ -822,86 +846,185 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare
}
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
}
def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tcSmemNodes.nonEmpty)
require(outer.tcSmemNodes.length == outer.numTensorCores)
require(outer.tcGmemNodes.length == outer.numTensorCores)
// TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
val nTC = outer.numTensorCores
val tcPorts = 3
val tcCoreDataBits = 32 * 8
val tcDataBits = outer.tcSmemSize * 8
val tcSmemLineBits = outer.tcSmemLineSize * 8
val tmemAddrBits = 9
val tmemDataBits = tcDataBits
val tmemMaskBits = outer.tcSmemSize
val tcTlSize = log2Ceil(outer.tcSmemSize)
val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
def port(tc: Int, p: Int): Int = tc * tcPorts + p
def padToCoreData(u: UInt): UInt = {
if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
}
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
tcAReady.foreach(_ := false.B)
tcDValid.foreach(_ := false.B)
tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U)
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
val tmemBytesPerWarp = 2048
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
tmem.io.ren0 := core.io.tc_tmem_C_ren
tmem.io.raddr0 := core.io.tc_tmem_C_raddr
core.io.tc_tmem_C_rdata := tmem.io.rdata0
tmem.io.ren1 := false.B
tmem.io.raddr1 := 0.U
tmem.io.wen := core.io.tc_tmem_C_wen
tmem.io.waddr := core.io.tc_tmem_C_waddr
tmem.io.wdata := core.io.tc_tmem_C_wdata
tmem.io.mask := core.io.tc_tmem_C_mask
// smem_B (port 2): Global Memory via TileLink
val smemBBundle = new {
val addr = core.io.tc_a_bits_address(95, 64)
val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8)
val write = core.io.tc_a_bits_write(2)
val mask = core.io.tc_a_bits_mask(95, 64)
val data = core.io.tc_a_bits_data(767, 512)
val aValid = core.io.tc_a_valid(2)
val dReady = core.io.tc_d_ready(2)
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
class TmemWriteReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val data = UInt(tmemDataBits.W)
val mask = UInt(tmemMaskBits.W)
}
val client = outer.tcSmemNodes.head.out.head
val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
client
))
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := smemBBundle.aValid
adapter.io.inReq.bits.address := smemBBundle.addr
adapter.io.inReq.bits.source := smemBBundle.tag
adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := smemBBundle.mask
adapter.io.inReq.bits.data := smemBBundle.data
adapter.io.inResp.ready := smemBBundle.dReady
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
// port 0: global memory (cp/cb)
val gmemClient = outer.tcGmemNode.get.out.head
val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
gmemClient
))
gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0)
gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0)
gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
gmemAdapter.io.inReq.bits.size := 5.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0)
gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0)
gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d
(0 until nTC).foreach { tc =>
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
}
core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready)
core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source)
aReadArb.io.out.ready := true.B
cReadArb.io.out.ready := true.B
cWriteArb.io.out.ready := true.B
tmem.io.ren0 := aReadArb.io.out.fire
tmem.io.raddr0 := aReadArb.io.out.bits
tmem.io.ren1 := cReadArb.io.out.fire
tmem.io.raddr1 := cReadArb.io.out.bits
tmem.io.wen := cWriteArb.io.out.fire
tmem.io.waddr := cWriteArb.io.out.bits.addr
tmem.io.wdata := cWriteArb.io.out.bits.data
tmem.io.mask := cWriteArb.io.out.bits.mask
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
}).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
}).asUInt
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc =>
val p2 = port(tc, 2)
val client = outer.tcSmemNodes(tc).out.head
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
client
))
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
adapter.io.inReq.bits.address := lineAddress
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
adapter.io.inReq.bits.size := tcSmemLineTlSize.U
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
val lineData = adapter.io.inResp.bits.data
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
lineData
} else {
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
log2Ceil(outer.tcSmemSize)).asUInt
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
when(adapter.io.inReq.fire) {
fragmentIndex := requestFragmentIndex
}
lineFragments(fragmentIndex)
}
tcAReady(p2) := adapter.io.inReq.ready
tcDValid(p2) := adapter.io.inResp.valid
tcDData(p2) := padToCoreData(fragmentData)
tcDTag(p2) := adapter.io.inResp.bits.source
}
// port 0: global memory (cp/cb), one TL client per tensor core.
(0 until nTC).foreach { tc =>
val p0 = port(tc, 0)
val gmemClient = outer.tcGmemNodes(tc).out.head
val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
gmemClient
))
gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
gmemAdapter.io.inReq.bits.size := tcTlSize.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d
tcAReady(p0) := gmemAdapter.io.inReq.ready
tcDValid(p0) := gmemAdapter.io.inResp.valid
tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
}
core.io.tc_a_ready := tcAReady.asUInt
core.io.tc_d_valid := tcDValid.asUInt
core.io.tc_d_bits_data := tcDData.asUInt
core.io.tc_d_bits_tag := tcDTag.asUInt
} else {
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
}
}
@@ -995,7 +1118,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
val tensor = Module(new radiance.core.TensorCoreBlackwell(
8, 8, half = true, tensorNumSourceIds))
outer.numWarps, outer.numLsuLanes, half = true, tensorNumSourceIds))
tensor.io.initiate.valid := false.B
tensor.io.initiate.bits := DontCare
tensor.io.respA.valid := false.B
@@ -1006,7 +1129,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
tensor.io.tmemC.rdata := DontCare
tensor.io.tmemC.aRready := false.B
tensor.io.tmemC.aRdata := DontCare
tensor.io.tmemC.cRready := false.B
tensor.io.tmemC.cRdata := DontCare
tensor.io.tmemC.cWready := false.B
dontTouch(tensor.io)
} else {
if (outer.radianceParams.core.tensorCoreFP16) {

View File

@@ -90,28 +90,36 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1
val tcPortCount = 3
val tc_a_valid = Output(UInt(tcPortCount.W))
val tc_a_bits_write = Output(UInt(tcPortCount.W))
val tc_a_bits_address = Output(UInt((tcPortCount * 32).W))
val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W))
val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W))
val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W))
val tc_a_ready = Input(UInt(tcPortCount.W))
val tc_d_valid = Input(UInt(tcPortCount.W))
val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W))
val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W))
val tc_d_ready = Output(UInt(tcPortCount.W))
val tcFlatPortCount = tcPortCount * numTensorCores
val tc_a_valid = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_write = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W))
val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W))
val tc_a_ready = Input(UInt(tcFlatPortCount.W))
val tc_d_valid = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W))
val tc_d_ready = Output(UInt(tcFlatPortCount.W))
// Direct SRAM port for TMEM C (bypasses TileLink)
// Direct SRAM ports for shared TMEM (bypasses TileLink)
val numLanes = tile.numLsuLanes
val tc_tmem_C_wen = Output(Bool())
val tc_tmem_C_ren = Output(Bool())
val tc_tmem_C_waddr = Output(UInt(9.W))
val tc_tmem_C_raddr = Output(UInt(9.W))
val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numLanes * 4).W))
val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W))
val tc_tmem_A_ren = Output(UInt(numTensorCores.W))
val tc_tmem_A_rready = Input(UInt(numTensorCores.W))
val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wen = Output(UInt(numTensorCores.W))
val tc_tmem_C_wready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
// FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -147,7 +155,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
"CORE_ID" -> tile.radianceParams.coreId,
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
"STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
"NUM_THREADS" -> tile.numLsuLanes
"NUM_THREADS" -> tile.numLsuLanes,
"NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1)
)
)
with HasBlackBoxResource with HasBlackBoxPath {
@@ -211,6 +220,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")

View File

@@ -26,7 +26,11 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U)
c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
}
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
@@ -39,13 +43,17 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U)
if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.aRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.wen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue
if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.cRaddr.peek().litValue
c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
}
}
@@ -154,9 +162,9 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
// cpWrite: respA fires, tmemC written
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(cpData.U)
stepTmem(c, tmem)
c.clock.step()
c.io.respA.valid.poke(false.B)
@@ -171,10 +179,10 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B)
// ldReq: ren asserted, serve from tmem model
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// writeback should carry cpData
@@ -206,8 +214,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.clock.step()
// stWrite: tmemC written
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.wdata.expect(stData.U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWdata.expect(stData.U)
stepTmem(c, tmem)
c.clock.step()
@@ -217,13 +225,15 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h20000000".U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cbRead: ren asserted
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// cbWrite: reqA write with stData
@@ -280,7 +290,7 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.clock.step()
c.io.initiate.ready.expect(false.B)
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.cWen.expect(true.B)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
@@ -309,8 +319,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((warp0TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()
@@ -324,8 +334,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
c.io.initiate.valid.poke(false.B)
c.io.reqC.valid.expect(true.B)
c.clock.step()
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((warp3TmemA / fragBytes).U)
stepTmem(c, tmem)
c.clock.step()

View File

@@ -25,7 +25,11 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.reqB.ready.poke(false.B)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
c.io.tmemC.rdata.poke(0.U)
c.io.tmemC.aRready.poke(true.B)
c.io.tmemC.aRdata.poke(0.U)
c.io.tmemC.cRready.poke(true.B)
c.io.tmemC.cRdata.poke(0.U)
c.io.tmemC.cWready.poke(true.B)
}
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
@@ -38,15 +42,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
// Simple TMEM model: address → 256-bit row
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
// Drive tmemC read response from model, handle write
// Drive TMEM read responses from model, handle C-port writes.
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
if (c.io.tmemC.ren.peek().litToBoolean) {
val addr = c.io.tmemC.raddr.peek().litValue
c.io.tmemC.rdata.poke(tmem(addr).U)
if (c.io.tmemC.aRen.peek().litToBoolean) {
val addr = c.io.tmemC.aRaddr.peek().litValue
c.io.tmemC.aRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.wen.peek().litToBoolean) {
val addr = c.io.tmemC.waddr.peek().litValue
tmem(addr) = c.io.tmemC.wdata.peek().litValue
if (c.io.tmemC.cRen.peek().litToBoolean) {
val addr = c.io.tmemC.cRaddr.peek().litValue
c.io.tmemC.cRdata.poke(tmem(addr).U)
}
if (c.io.tmemC.cWen.peek().litToBoolean) {
val addr = c.io.tmemC.cWaddr.peek().litValue
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
}
}
@@ -65,19 +73,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.io.tmemC.rdata.poke(testData.U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// ldReq: tmemC.ren asserted; rdata must be valid before next step
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.rdata.poke(testData.U)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
// waitWb: wbValid gets set this cycle, step to let it register
c.io.tmemC.rdata.poke(testData.U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
// idle: writeback.valid now true
@@ -91,6 +99,38 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
}
}
it should "tcgen05_ld: support 4-lane 16-byte fragments" in {
val lanes = 4
test(new TensorCoreBlackwell(numWarps, lanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val fragBytes = 16
val tmemAddr = BigInt(0x40)
val testData = packWords(Seq.tabulate(lanes)(i => BigInt(0x2000 + i)), 32)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.tmemC.cRdata.poke(testData.U)
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(3.U)
for (i <- 0 until lanes) {
c.io.writeback.bits.data(i).expect((0x2000 + i).U)
}
}
}
it should "tcgen05_st: write from respC to TMEM" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
@@ -114,9 +154,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.clock.step()
// stWrite: tmemC.wen asserted with storeData
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(storeData.U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(storeData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
@@ -151,9 +191,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.respA.bits.data.poke(cpData.U)
// tmemC write happens combinatorially when respA fires
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U)
c.io.tmemC.cWen.expect(true.B)
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cWdata.expect(cpData.U)
c.clock.step()
c.io.initiate.ready.expect(true.B)
}
@@ -172,14 +212,16 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke(gmemAddr.U)
c.io.reqA.ready.poke(true.B)
c.io.tmemC.rdata.poke(cbData.U)
c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.initiate.ready.expect(false.B)
// cbRead: tmemC.ren asserted
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.cRen.expect(true.B)
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
c.clock.step()
c.io.tmemC.cRdata.poke(cbData.U)
c.clock.step()
c.io.initiate.ready.expect(false.B)