// Commit summary (code-viewer page residue, kept as a comment so the file parses):
//   - Implement TensorCoreBlackwell.scala with BWGMMA and TCGEN05 instructions
//   - Update TensorDPU, RadianceTile, and VortexCore for Blackwell integration
//   - Add TensorCoreBlackwellExtendedTest for comprehensive testing
//   - Update vortex submodule with Blackwell ISA support
// Viewer metadata: 339 lines, 11 KiB, Scala
package radiance.core

import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec

import scala.collection.mutable

/** Extended chiseltest scenarios for `TensorCoreBlackwell`.
  *
  * Every case drives the DUT's `io.initiate` port with one of the opcodes in `Op` and
  * checks the resulting traffic on the memory (`reqA`/`reqB`/`reqC`), TMEM (`tmemC`) and
  * `writeback` ports. The TMEM behind `io.tmemC` and the B-operand memory are modelled in
  * software with plain maps.
  */
class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester {
  behavior of "TensorCoreBlackwell Extended Tests"

  private val numWarps = 4
  private val numLanes = 8
  // Bytes per TMEM row: byte addresses are divided by this to obtain the row index
  // that the tests compare against `io.tmemC.waddr`.
  private val fragBytes = 32

  /** Opcode values poked into `io.initiate.bits.op`, named after how the tests use them.
    *
    * NOTE(review): the numeric values are the literals the tests have always used;
    * confirm them against the DUT's opcode encoding if that encoding changes.
    */
  private object Op {
    val Bwgmma = 0
    val BwgmmaWait = 1
    val Cp = 2 // global mem -> tmem
    val CpWait = 3 // tcgen05CpWait
    val Ld = 4 // tmem -> writeback
    val St = 5 // respC -> tmem
    val Cb = 6 // tmem -> global mem
  }

  /** Builds the DUT configuration shared by every test case. */
  private def newDut(): TensorCoreBlackwell =
    new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)

  /** Drives every DUT input used by this suite to an inactive / zero value.
    *
    * The `initiate.bits` fields are zeroed too, so that tests which only poke a subset
    * of them (for example no `wid`) start from a known state instead of relying on the
    * simulator's initial input values.
    */
  private def idleIO(c: TensorCoreBlackwell): Unit = {
    c.io.initiate.valid.poke(false.B)
    c.io.initiate.bits.op.poke(0.U)
    c.io.initiate.bits.wid.poke(0.U)
    c.io.initiate.bits.rd.poke(0.U)
    c.io.initiate.bits.addressA.poke(0.U)
    c.io.initiate.bits.addressB.poke(0.U)
    c.io.initiate.bits.addressC.poke(0.U)
    c.io.respA.valid.poke(false.B)
    c.io.respB.valid.poke(false.B)
    c.io.respA.bits.source.poke(0.U)
    c.io.respB.bits.source.poke(0.U)
    c.io.respA.bits.data.poke(0.U)
    c.io.respB.bits.data.poke(0.U)
    c.io.reqA.ready.poke(false.B)
    c.io.reqB.ready.poke(false.B)
    c.io.respC.poke(0.U)
    c.io.writeback.ready.poke(false.B)
    c.io.tmemC.rdata.poke(0.U)
  }

  /** Packs `words` into a single integer, element 0 in the least-significant bits.
    *
    * @param words values to pack; each one is truncated to `width` bits
    * @param width bit width of one packed element
    * @return the concatenation, with element `i` occupying bits `[i*width, (i+1)*width)`
    */
  private def packWords(words: Seq[BigInt], width: Int): BigInt = {
    val mask = (BigInt(1) << width) - 1
    words.zipWithIndex.foldLeft(BigInt(0)) { case (acc, (word, i)) =>
      acc | ((word & mask) << (i * width))
    }
  }

  /** Creates an empty software TMEM model (row index -> row contents); unwritten rows read as 0. */
  private def makeTmem(): mutable.Map[BigInt, BigInt] =
    mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))

  /** Services the DUT's TMEM port for the current cycle against the software model.
    *
    * If `ren` is asserted, the row at `raddr` is poked onto `rdata`. If `wen` is asserted,
    * `wdata` is stored at `waddr`. The read is served before the write is recorded, so a
    * same-cycle read of the row being written returns the old contents.
    */
  private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
    if (c.io.tmemC.ren.peek().litToBoolean) {
      val addr = c.io.tmemC.raddr.peek().litValue
      c.io.tmemC.rdata.poke(tmem(addr).U)
    }
    if (c.io.tmemC.wen.peek().litToBoolean) {
      val addr = c.io.tmemC.waddr.peek().litValue
      tmem(addr) = c.io.tmemC.wdata.peek().litValue
    }
  }

  // Runs one BWGMMA with A = 1.0 (fp16) and B = 2.0 (fp16) accumulating into C = 0.0,
  // using non-zero A/B/C base addresses, then checks the C rows written back to the TMEM
  // model and that the A rows were left untouched.
  it should "verify bwgmma address offset with non-zero base addresses" in {
    test(newDut())
      .withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
        idleIO(c)
        val tmem = makeTmem()
        // Upper bound on simulated cycles before the test gives up waiting for writeback.
        val maxCycles = 50000

        // Use non-zero base addresses to verify offset calculation
        val aBase = BigInt(0x200) // row 16, A tile rows 16~47
        val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
        val bBase = BigInt(0x800)

        val fp16One = BigInt(0x3c00)
        val fp16Two = BigInt(0x4000)
        val fp32Zero = BigInt(0)
        // 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
        val fp32SixtyFour = BigInt(0x42800000L)

        // Populate TMEM A at offset aBase (all 32 frags) and clear the C tile.
        val aFrag = packWords(Seq.fill(16)(fp16One), 16)
        val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
        for (i <- 0 until 32) {
          tmem(aBase / fragBytes + i) = aFrag
          tmem(cBase / fragBytes + i) = cFrag
        }

        // SMEM B with fp16 2.0. Every address reads back the same fragment (the map's
        // default), so no per-address entries are needed. Note that this also means the
        // B request addresses themselves are not checked by this test.
        val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
        val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)

        c.io.reqB.ready.poke(true.B)
        c.io.writeback.ready.poke(true.B)

        c.io.initiate.valid.poke(true.B)
        c.io.initiate.bits.op.poke(Op.Bwgmma.U)
        c.io.initiate.bits.wid.poke(0.U)
        c.io.initiate.bits.rd.poke(0.U)
        c.io.initiate.bits.addressA.poke(aBase.U)
        c.io.initiate.bits.addressB.poke(bBase.U)
        c.io.initiate.bits.addressC.poke(cBase.U)
        c.clock.step()
        c.io.initiate.valid.poke(false.B)

        // B response to present in the current cycle: (source id, data) of the request
        // that was observed in the previous cycle, i.e. a one-cycle-latency memory.
        var pendingB = Option.empty[(BigInt, BigInt)]
        var sawWriteback = false
        var cycle = 0

        while (!sawWriteback && cycle < maxCycles) {
          // Order matters: the TMEM port is serviced before the B response is driven.
          stepTmem(c, tmem)
          pendingB match {
            case Some((src, data)) =>
              c.io.respB.valid.poke(true.B)
              c.io.respB.bits.source.poke(src.U)
              c.io.respB.bits.data.poke(data.U)
            case None =>
              c.io.respB.valid.poke(false.B)
          }

          if (c.io.writeback.valid.peek().litToBoolean) {
            // Stop without stepping so the TMEM model reflects the state at completion.
            sawWriteback = true
          } else {
            val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
              val addr = c.io.reqB.bits.address.peek().litValue
              val src = c.io.reqB.bits.source.peek().litValue
              Some((src, bMem(addr)))
            } else None
            c.clock.step()
            pendingB = nextB
          }
          cycle += 1
        }

        assert(sawWriteback, "BWGMMA did not complete")
        val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32)
        for (i <- 0 until 8) {
          val row = cBase / fragBytes + i
          assert(tmem(row) == expectedC,
            s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}")
        }
        for (i <- 0 until 8) {
          assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged")
        }
      }
  }

  // Writes one TMEM row with cp (data arrives on respA), captures it in the TMEM model,
  // then issues ld for the same address and checks the writeback lanes carry that data.
  it should "cp then ld round-trip: data written via cp is readable via ld" in {
    test(newDut()) { c =>
      idleIO(c)
      val tmem = makeTmem()
      val tmemAddr = BigInt(0x100)
      val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32)

      // Issue cp: global mem -> tmem
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.Cp.U)
      c.io.initiate.bits.addressA.poke(tmemAddr.U)
      c.io.initiate.bits.addressB.poke("h10000000".U)
      c.io.reqA.ready.poke(true.B)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)

      // cpRead: reqA issued as a read (rw low)
      c.io.reqA.valid.expect(true.B)
      c.io.reqA.bits.rw.expect(false.B)
      c.clock.step()

      // cpWrite: respA fires, tmemC written
      c.io.respA.valid.poke(true.B)
      c.io.respA.bits.data.poke(cpData.U)
      c.io.tmemC.wen.expect(true.B)
      c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
      c.io.tmemC.wdata.expect(cpData.U)
      stepTmem(c, tmem)
      c.clock.step()
      c.io.respA.valid.poke(false.B)

      // Now issue ld from same tmem address
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.Ld.U)
      c.io.initiate.bits.rd.poke(2.U)
      c.io.initiate.bits.addressA.poke(tmemAddr.U)
      c.io.writeback.ready.poke(true.B)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)

      // ldReq: ren asserted, serve from tmem model. The read data is held for two
      // cycles before the writeback is checked.
      c.io.tmemC.ren.expect(true.B)
      c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
      c.clock.step()
      c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
      c.clock.step()

      // writeback should carry cpData, one 32-bit word per lane
      c.io.writeback.valid.expect(true.B)
      for (i <- 0 until numLanes) {
        c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U)
      }
    }
  }

  // Writes one TMEM row with st (data supplied on respC), captures it in the TMEM model,
  // then issues cb and checks the DUT emits a reqA write carrying that data.
  it should "st then cb round-trip: data written via st is readable via cb" in {
    test(newDut()) { c =>
      idleIO(c)
      val tmem = makeTmem()
      val tmemAddr = BigInt(0x140)
      val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32)

      // Issue st: respC -> tmem
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.St.U)
      c.io.initiate.bits.rd.poke(4.U)
      c.io.initiate.bits.addressA.poke(tmemAddr.U)
      c.io.respC.poke(stData.U)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)

      // stReq: reqC valid
      c.io.reqC.valid.expect(true.B)
      c.clock.step()

      // stWrite: tmemC written
      c.io.tmemC.wen.expect(true.B)
      c.io.tmemC.wdata.expect(stData.U)
      stepTmem(c, tmem)
      c.clock.step()

      // Issue cb: tmem -> global mem
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.Cb.U)
      c.io.initiate.bits.addressA.poke(tmemAddr.U)
      c.io.initiate.bits.addressB.poke("h20000000".U)
      c.io.reqA.ready.poke(true.B)
      c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)

      // cbRead: ren asserted
      c.io.tmemC.ren.expect(true.B)
      c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
      c.clock.step()

      // cbWrite: reqA write with stData
      c.io.reqA.valid.expect(true.B)
      c.io.reqA.bits.rw.expect(true.B)
      c.io.reqA.bits.address.expect("h20000000".U)
      c.io.reqA.bits.data.expect(stData.U)
    }
  }

  // Issues each wait opcode and checks it is accepted immediately and produces no
  // writeback or memory request in the following cycle.
  it should "wait ops are no-ops and do not stall pipeline" in {
    test(newDut()) { c =>
      idleIO(c)

      // bwgmmaWait: should accept immediately and stay idle
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.BwgmmaWait.U)
      c.io.initiate.ready.expect(true.B)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)
      c.io.writeback.valid.expect(false.B)
      c.io.reqA.valid.expect(false.B)
      c.io.reqB.valid.expect(false.B)

      // tcgen05CpWait: same
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.CpWait.U)
      c.io.initiate.ready.expect(true.B)
      c.clock.step()
      c.io.initiate.valid.poke(false.B)
      c.io.writeback.valid.expect(false.B)
      c.io.reqA.valid.expect(false.B)
    }
  }

  // Keeps `initiate.valid` high across an st followed by an ld and checks that `ready`
  // stays low while the st is in flight and rises again the cycle after its TMEM write.
  it should "not accept a second tensor op until the first one completes" in {
    test(newDut()) { c =>
      idleIO(c)
      val firstAddr = BigInt(0x180)
      val secondAddr = BigInt(0x1a0)
      val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xCAFE0000L + i)), 32)

      // First op (st) is accepted while the DUT is idle.
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(Op.St.U)
      c.io.initiate.bits.addressA.poke(firstAddr.U)
      c.io.respC.poke(storeData.U)
      c.io.initiate.ready.expect(true.B)
      c.clock.step()

      // Second op (ld) is offered immediately; valid is intentionally left high.
      c.io.initiate.bits.op.poke(Op.Ld.U)
      c.io.initiate.bits.addressA.poke(secondAddr.U)
      c.io.initiate.bits.rd.poke(2.U)
      c.io.initiate.ready.expect(false.B)
      c.clock.step()
      c.io.initiate.ready.expect(false.B)

      // The st's TMEM write happens in this cycle; ready returns on the next one.
      c.io.tmemC.wen.expect(true.B)
      c.clock.step()
      c.io.initiate.ready.expect(true.B)
    }
  }

  // Stores distinct data from warp 0 and warp 3 to their own TMEM addresses and checks
  // that the two writes land on different rows of the TMEM model with the right data.
  it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in {
    test(newDut()) { c =>
      idleIO(c)
      val tmem = makeTmem()

      // warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0
      val warp0TmemA = BigInt(0x000)
      val warp0Data = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32)

      // warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800
      val warp3TmemA = BigInt(0x1800)
      val warp3Data = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32)

      // Issues one st for `wid`, checks the request/write handshake cycle by cycle and
      // records the TMEM write in the software model.
      def storeViaSt(wid: Int, addr: BigInt, data: BigInt): Unit = {
        c.io.initiate.valid.poke(true.B)
        c.io.initiate.bits.op.poke(Op.St.U)
        c.io.initiate.bits.wid.poke(wid.U)
        c.io.initiate.bits.addressA.poke(addr.U)
        c.io.respC.poke(data.U)
        c.clock.step()
        c.io.initiate.valid.poke(false.B)
        c.io.reqC.valid.expect(true.B)
        c.clock.step()
        c.io.tmemC.wen.expect(true.B)
        c.io.tmemC.waddr.expect((addr / fragBytes).U)
        stepTmem(c, tmem)
        c.clock.step()
      }

      // Write warp 0 data via st
      storeViaSt(0, warp0TmemA, warp0Data)

      // Write warp 3 data via st
      storeViaSt(3, warp3TmemA, warp3Data)

      // Verify no aliasing: warp 0 row != warp 3 row
      assert(warp0TmemA / fragBytes != warp3TmemA / fragBytes)
      assert(tmem(warp0TmemA / fragBytes) == warp0Data)
      assert(tmem(warp3TmemA / fragBytes) == warp3Data)
    }
  }
}