Add 4-lane Blackwell tensor core support
This commit is contained in:
Submodule src/main/resources/vsrc/vortex updated: c87fea5c48...abee301b6e
@@ -14,7 +14,8 @@ class TensorCoreBlackwell(
|
||||
val numFPRegs: Int = 32
|
||||
) extends Module {
|
||||
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
||||
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes")
|
||||
require(numLanes == 4 || numLanes == 8,
|
||||
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
|
||||
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
@@ -26,11 +27,16 @@ class TensorCoreBlackwell(
|
||||
val fragOffsetBits = log2Ceil(memWidth / 8)
|
||||
|
||||
val numSets = 4
|
||||
val numAFragsPerSet = 8
|
||||
val numBGroups = 4
|
||||
val numBFragsPerGroup = 2
|
||||
val numMGroups = 4
|
||||
val numCFrags = 32
|
||||
val numSubsteps = 2
|
||||
val mElemsPerFrag = if (numLanes == 4) 2 else 4
|
||||
val numMGroups = 16 / mElemsPerFrag
|
||||
val numAFragsPerMGroup = 2
|
||||
val numAFragsPerSet = numMGroups * numAFragsPerMGroup
|
||||
val numBFragsPerSubstep = if (numLanes == 4) 2 else 1
|
||||
val numBFragsPerGroup = numSubsteps * numBFragsPerSubstep
|
||||
val numBFragsPerSet = numBGroups * numBFragsPerGroup
|
||||
val numCFrags = numBGroups * numMGroups * numSubsteps
|
||||
|
||||
object Ops {
|
||||
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
|
||||
@@ -136,10 +142,11 @@ class TensorCoreBlackwell(
|
||||
base + (fragIndex << fragOffsetBits).asUInt
|
||||
}
|
||||
|
||||
val aFragIndex = (setReg << 3) + aIndexReg
|
||||
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg
|
||||
val stepIndex = Cat(bGroupReg, mGroupReg)
|
||||
val cFragIndex = (stepIndex << 1) + substepReg
|
||||
val aFragIndex = (setReg * numAFragsPerSet.U) + aIndexReg
|
||||
val bFragIndex =
|
||||
(setReg * numBFragsPerSet.U) + (bGroupReg * numBFragsPerGroup.U) + bIndexReg
|
||||
val cFragIndex =
|
||||
(((bGroupReg * numMGroups.U) + mGroupReg) * numSubsteps.U) + substepReg
|
||||
val aReqAddress = byteAddress(addrAReg, aFragIndex)
|
||||
val bReqAddress = byteAddress(addrBReg, bFragIndex)
|
||||
val cReqAddress = byteAddress(addrCReg, cFragIndex)
|
||||
@@ -181,7 +188,12 @@ class TensorCoreBlackwell(
|
||||
io.initiate.ready := state === State.idle && !wbValid
|
||||
|
||||
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
|
||||
val operandB = bBuf(substepReg)
|
||||
val operandB =
|
||||
if (numLanes == 4) {
|
||||
Cat(bBuf((substepReg << 1) + 1.U), bBuf(substepReg << 1))
|
||||
} else {
|
||||
bBuf(substepReg)
|
||||
}
|
||||
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
|
||||
val dpuInValid = WireDefault(false.B)
|
||||
val dpu = Module(new TensorDotProductUnit(
|
||||
@@ -193,16 +205,22 @@ class TensorCoreBlackwell(
|
||||
x((idx + 1) * 16 - 1, idx * 16)
|
||||
}
|
||||
|
||||
val elemM = elemReg(1, 0)
|
||||
val elemN = elemReg(2)
|
||||
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
|
||||
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
|
||||
dpu.io.in.valid := dpuInValid
|
||||
for (k <- 0 until 8) {
|
||||
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
||||
0.U -> halfWord(operandA, k),
|
||||
1.U -> halfWord(operandA, 8 + k),
|
||||
2.U -> halfWord(operandA, 16 + k),
|
||||
3.U -> halfWord(operandA, 24 + k)
|
||||
))
|
||||
dpu.io.in.bits.a(k) := (
|
||||
if (numLanes == 4) {
|
||||
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
|
||||
} else {
|
||||
MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
||||
0.U -> halfWord(operandA, k),
|
||||
1.U -> halfWord(operandA, 8 + k),
|
||||
2.U -> halfWord(operandA, 16 + k),
|
||||
3.U -> halfWord(operandA, 24 + k)
|
||||
))
|
||||
}
|
||||
)
|
||||
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
||||
}
|
||||
dpu.io.in.bits.c := cWords(elemReg)
|
||||
|
||||
@@ -288,13 +288,17 @@ class RadianceTile private (
|
||||
)
|
||||
}
|
||||
|
||||
val tcSmemSize = 32
|
||||
val tcSmemSize = numLsuLanes * 4
|
||||
val numTensorWarps = radianceParams.core.numTensorWarps
|
||||
val numScalarWarps = numWarps - numTensorWarps
|
||||
require(numTensorWarps > 0 && numTensorWarps < numWarps,
|
||||
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
|
||||
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
|
||||
if (radianceParams.core.tensorCoreBlackwell) {
|
||||
require(numCoreLanes == numLsuLanes,
|
||||
s"Wu Blackwell binding requires matching core lanes (${numCoreLanes}) and memory lanes (${numLsuLanes})")
|
||||
require(numLsuLanes == 4 || numLsuLanes == 8,
|
||||
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
|
||||
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
|
||||
}
|
||||
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
||||
@@ -852,8 +856,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
val tcPorts = 3
|
||||
val tcDataBits = outer.tcSmemSize * 8
|
||||
val tmemAddrBits = 9
|
||||
val tmemDataBits = outer.numLsuLanes * 32
|
||||
val tmemMaskBits = outer.numLsuLanes * 4
|
||||
val tmemDataBits = tcDataBits
|
||||
val tmemMaskBits = outer.tcSmemSize
|
||||
val tcTlSize = log2Ceil(outer.tcSmemSize)
|
||||
|
||||
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
|
||||
def port(tc: Int, p: Int): Int = tc * tcPorts + p
|
||||
@@ -868,8 +873,10 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
tcDTag.foreach(_ := 0.U)
|
||||
|
||||
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
|
||||
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
|
||||
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
|
||||
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
|
||||
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
|
||||
val tmemBytesPerWarp = 2048
|
||||
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
|
||||
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||
|
||||
@@ -933,7 +940,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
|
||||
adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
|
||||
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
|
||||
adapter.io.inReq.bits.size := 5.U
|
||||
adapter.io.inReq.bits.size := tcTlSize.U
|
||||
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||
adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
|
||||
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
|
||||
@@ -961,7 +968,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
|
||||
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
|
||||
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
|
||||
gmemAdapter.io.inReq.bits.size := 5.U
|
||||
gmemAdapter.io.inReq.bits.size := tcTlSize.U
|
||||
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
|
||||
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
|
||||
@@ -1082,7 +1089,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
|
||||
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
||||
val tensor = Module(new radiance.core.TensorCoreBlackwell(
|
||||
8, 8, half = true, tensorNumSourceIds))
|
||||
outer.numWarps, outer.numLsuLanes, half = true, tensorNumSourceIds))
|
||||
tensor.io.initiate.valid := false.B
|
||||
tensor.io.initiate.bits := DontCare
|
||||
tensor.io.respA.valid := false.B
|
||||
|
||||
@@ -26,7 +26,11 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.io.reqB.ready.poke(false.B)
|
||||
c.io.respC.poke(0.U)
|
||||
c.io.writeback.ready.poke(false.B)
|
||||
c.io.tmemC.rdata.poke(0.U)
|
||||
c.io.tmemC.aRready.poke(true.B)
|
||||
c.io.tmemC.aRdata.poke(0.U)
|
||||
c.io.tmemC.cRready.poke(true.B)
|
||||
c.io.tmemC.cRdata.poke(0.U)
|
||||
c.io.tmemC.cWready.poke(true.B)
|
||||
}
|
||||
|
||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||
@@ -39,13 +43,17 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
||||
|
||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.raddr.peek().litValue
|
||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
||||
if (c.io.tmemC.aRen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.aRaddr.peek().litValue
|
||||
c.io.tmemC.aRdata.poke(tmem(addr).U)
|
||||
}
|
||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.waddr.peek().litValue
|
||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
||||
if (c.io.tmemC.cRen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.cRaddr.peek().litValue
|
||||
c.io.tmemC.cRdata.poke(tmem(addr).U)
|
||||
}
|
||||
if (c.io.tmemC.cWen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.cWaddr.peek().litValue
|
||||
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,9 +162,9 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
// cpWrite: respA fires, tmemC written
|
||||
c.io.respA.valid.poke(true.B)
|
||||
c.io.respA.bits.data.poke(cpData.U)
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.wdata.expect(cpData.U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cWdata.expect(cpData.U)
|
||||
stepTmem(c, tmem)
|
||||
c.clock.step()
|
||||
c.io.respA.valid.poke(false.B)
|
||||
@@ -171,10 +179,10 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
|
||||
// ldReq: ren asserted, serve from tmem model
|
||||
c.io.tmemC.ren.expect(true.B)
|
||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRen.expect(true.B)
|
||||
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
|
||||
// writeback should carry cpData
|
||||
@@ -206,8 +214,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.clock.step()
|
||||
|
||||
// stWrite: tmemC written
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.wdata.expect(stData.U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWdata.expect(stData.U)
|
||||
stepTmem(c, tmem)
|
||||
c.clock.step()
|
||||
|
||||
@@ -217,13 +225,15 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||
c.io.initiate.bits.addressB.poke("h20000000".U)
|
||||
c.io.reqA.ready.poke(true.B)
|
||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
|
||||
// cbRead: ren asserted
|
||||
c.io.tmemC.ren.expect(true.B)
|
||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRen.expect(true.B)
|
||||
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
|
||||
// cbWrite: reqA write with stData
|
||||
@@ -280,7 +290,7 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.clock.step()
|
||||
c.io.initiate.ready.expect(false.B)
|
||||
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.clock.step()
|
||||
c.io.initiate.ready.expect(true.B)
|
||||
}
|
||||
@@ -309,8 +319,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
c.io.reqC.valid.expect(true.B)
|
||||
c.clock.step()
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWaddr.expect((warp0TmemA / fragBytes).U)
|
||||
stepTmem(c, tmem)
|
||||
c.clock.step()
|
||||
|
||||
@@ -324,8 +334,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
c.io.reqC.valid.expect(true.B)
|
||||
c.clock.step()
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWaddr.expect((warp3TmemA / fragBytes).U)
|
||||
stepTmem(c, tmem)
|
||||
c.clock.step()
|
||||
|
||||
|
||||
@@ -25,7 +25,11 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.io.reqB.ready.poke(false.B)
|
||||
c.io.respC.poke(0.U)
|
||||
c.io.writeback.ready.poke(false.B)
|
||||
c.io.tmemC.rdata.poke(0.U)
|
||||
c.io.tmemC.aRready.poke(true.B)
|
||||
c.io.tmemC.aRdata.poke(0.U)
|
||||
c.io.tmemC.cRready.poke(true.B)
|
||||
c.io.tmemC.cRdata.poke(0.U)
|
||||
c.io.tmemC.cWready.poke(true.B)
|
||||
}
|
||||
|
||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||
@@ -38,15 +42,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
// Simple TMEM model: address → 256-bit row
|
||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
||||
|
||||
// Drive tmemC read response from model, handle write
|
||||
// Drive TMEM read responses from model, handle C-port writes.
|
||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.raddr.peek().litValue
|
||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
||||
if (c.io.tmemC.aRen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.aRaddr.peek().litValue
|
||||
c.io.tmemC.aRdata.poke(tmem(addr).U)
|
||||
}
|
||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.waddr.peek().litValue
|
||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
||||
if (c.io.tmemC.cRen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.cRaddr.peek().litValue
|
||||
c.io.tmemC.cRdata.poke(tmem(addr).U)
|
||||
}
|
||||
if (c.io.tmemC.cWen.peek().litToBoolean) {
|
||||
val addr = c.io.tmemC.cWaddr.peek().litValue
|
||||
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,19 +73,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.io.initiate.bits.rd.poke(3.U)
|
||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||
c.io.writeback.ready.poke(true.B)
|
||||
c.io.tmemC.rdata.poke(testData.U)
|
||||
c.io.tmemC.cRdata.poke(testData.U)
|
||||
c.clock.step()
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
c.io.initiate.ready.expect(false.B)
|
||||
|
||||
// ldReq: tmemC.ren asserted; rdata must be valid before next step
|
||||
c.io.tmemC.ren.expect(true.B)
|
||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.rdata.poke(testData.U)
|
||||
c.io.tmemC.cRen.expect(true.B)
|
||||
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRdata.poke(testData.U)
|
||||
c.clock.step()
|
||||
|
||||
// waitWb: wbValid gets set this cycle, step to let it register
|
||||
c.io.tmemC.rdata.poke(testData.U)
|
||||
c.io.tmemC.cRdata.poke(testData.U)
|
||||
c.clock.step()
|
||||
|
||||
// idle: writeback.valid now true
|
||||
@@ -91,6 +99,38 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
}
|
||||
}
|
||||
|
||||
it should "tcgen05_ld: support 4-lane 16-byte fragments" in {
|
||||
val lanes = 4
|
||||
test(new TensorCoreBlackwell(numWarps, lanes, half = true, numSourceIds = 4)) { c =>
|
||||
idleIO(c)
|
||||
val fragBytes = 16
|
||||
val tmemAddr = BigInt(0x40)
|
||||
val testData = packWords(Seq.tabulate(lanes)(i => BigInt(0x2000 + i)), 32)
|
||||
|
||||
c.io.initiate.valid.poke(true.B)
|
||||
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
|
||||
c.io.initiate.bits.wid.poke(0.U)
|
||||
c.io.initiate.bits.rd.poke(3.U)
|
||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||
c.io.writeback.ready.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
|
||||
c.io.tmemC.cRen.expect(true.B)
|
||||
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRdata.poke(testData.U)
|
||||
c.clock.step()
|
||||
c.io.tmemC.cRdata.poke(testData.U)
|
||||
c.clock.step()
|
||||
|
||||
c.io.writeback.valid.expect(true.B)
|
||||
c.io.writeback.bits.rd.expect(3.U)
|
||||
for (i <- 0 until lanes) {
|
||||
c.io.writeback.bits.data(i).expect((0x2000 + i).U)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it should "tcgen05_st: write from respC to TMEM" in {
|
||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
||||
idleIO(c)
|
||||
@@ -114,9 +154,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.clock.step()
|
||||
|
||||
// stWrite: tmemC.wen asserted with storeData
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.wdata.expect(storeData.U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cWdata.expect(storeData.U)
|
||||
c.clock.step()
|
||||
c.io.initiate.ready.expect(true.B)
|
||||
}
|
||||
@@ -151,9 +191,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.io.respA.bits.data.poke(cpData.U)
|
||||
|
||||
// tmemC write happens combinatorially when respA fires
|
||||
c.io.tmemC.wen.expect(true.B)
|
||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.wdata.expect(cpData.U)
|
||||
c.io.tmemC.cWen.expect(true.B)
|
||||
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cWdata.expect(cpData.U)
|
||||
c.clock.step()
|
||||
c.io.initiate.ready.expect(true.B)
|
||||
}
|
||||
@@ -172,14 +212,16 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||
c.io.initiate.bits.addressB.poke(gmemAddr.U)
|
||||
c.io.reqA.ready.poke(true.B)
|
||||
c.io.tmemC.rdata.poke(cbData.U)
|
||||
c.io.tmemC.cRdata.poke(cbData.U)
|
||||
c.clock.step()
|
||||
c.io.initiate.valid.poke(false.B)
|
||||
c.io.initiate.ready.expect(false.B)
|
||||
|
||||
// cbRead: tmemC.ren asserted
|
||||
c.io.tmemC.ren.expect(true.B)
|
||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
||||
c.io.tmemC.cRen.expect(true.B)
|
||||
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||
c.clock.step()
|
||||
c.io.tmemC.cRdata.poke(cbData.U)
|
||||
c.clock.step()
|
||||
c.io.initiate.ready.expect(false.B)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user