Files
radiance/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala
abnerhexu 5112f3665a Add Blackwell tensor core implementation and tests
- Implement TensorCoreBlackwell.scala with BWGMMA and TCGEN05 instructions
- Update TensorDPU, RadianceTile, and VortexCore for Blackwell integration
- Add TensorCoreBlackwellExtendedTest for comprehensive testing
- Update vortex submodule with Blackwell ISA support
2026-05-06 14:51:11 +08:00

339 lines
11 KiB
Scala
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package radiance.core
import chisel3._
import chiseltest._
import chiseltest.simulator.VerilatorBackendAnnotation
import org.scalatest.flatspec.AnyFlatSpec
import scala.collection.mutable
class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell Extended Tests"
// DUT configuration shared by every test: both values are passed straight to
// the TensorCoreBlackwell constructor.
private val numWarps = 4
private val numLanes = 8
// Byte addresses are divided by this to obtain the row index used by the
// software TMEM model and compared against tmemC.waddr.
// NOTE(review): 32 presumably equals numLanes * 4 bytes (one fp32 per lane) — confirm against the DUT.
private val fragBytes = 32
// Puts every testbench-driven input of the DUT into a quiescent state:
// no op offered, no response presented, no back-end ready, all data zero.
// Called at the start of each test before any stimulus is applied.
private def idleIO(c: TensorCoreBlackwell): Unit = {
  val io = c.io

  // Nothing is being issued.
  io.initiate.valid.poke(false.B)

  // Request sinks and the writeback consumer are not ready.
  io.reqA.ready.poke(false.B)
  io.reqB.ready.poke(false.B)
  io.writeback.ready.poke(false.B)

  // Response channel A: invalid, zeroed payload.
  io.respA.valid.poke(false.B)
  io.respA.bits.source.poke(0.U)
  io.respA.bits.data.poke(0.U)

  // Response channel B: invalid, zeroed payload.
  io.respB.valid.poke(false.B)
  io.respB.bits.source.poke(0.U)
  io.respB.bits.data.poke(0.U)

  // Plain data inputs.
  io.respC.poke(0.U)
  io.tmemC.rdata.poke(0.U)
}
// Packs `words` into a single BigInt, little-endian by index: element i
// occupies bits [i*width, (i+1)*width). Each element is truncated to `width`
// bits before being placed, so oversized values cannot spill into neighbours.
// An empty sequence packs to 0.
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
  val laneMask = (BigInt(1) << width) - 1
  words.iterator.zipWithIndex
    .map { case (value, lane) => (value & laneMask) << (lane * width) }
    .foldLeft(BigInt(0))(_ | _)
}
// Creates an empty software model of TMEM keyed by row index. Rows that were
// never written read back as 0 (lookups of missing keys do not insert them).
private def makeTmem() = {
  val unwrittenRow = BigInt(0)
  mutable.Map.empty[BigInt, BigInt].withDefaultValue(unwrittenRow)
}
// Services the DUT's tmemC port for the current cycle against the software
// model `tmem`. The read is handled before the write, so a same-cycle read
// and write of one row returns the previously stored value.
// Does not advance the clock; callers step it themselves.
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
  val port = c.io.tmemC

  // Read: drive rdata with the modelled contents of the requested row.
  if (port.ren.peek().litToBoolean) {
    val readRow = port.raddr.peek().litValue
    port.rdata.poke(tmem(readRow).U)
  }

  // Write: capture what the DUT is storing this cycle.
  if (port.wen.peek().litToBoolean) {
    val writeRow = port.waddr.peek().litValue
    tmem.update(writeRow, port.wdata.peek().litValue)
  }
}
// Runs one op-0 operation (presumably BWGMMA, per the test name and the
// completion message below) with A, B and C all placed at non-zero base
// addresses, then checks the TMEM model: C rows hold the expected product and
// A rows are untouched. Uses the Verilator backend.
it should "verify bwgmma address offset with non-zero base addresses" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
idleIO(c)
val tmem = makeTmem()
// Use non-zero base addresses to verify offset calculation
val aBase = BigInt(0x200) // row 16, A tile rows 16~47
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
val bBase = BigInt(0x800)
// IEEE-754 half-precision 1.0 and single-precision 0.0 bit patterns.
val fp16One = BigInt(0x3c00)
val fp32Zero = BigInt(0)
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
val fp32SixtyFour = BigInt(0x42800000L)
// Populate the TMEM model with 32 A frags at aBase and 32 zeroed C frags at cBase.
// An A frag is 16 fp16 words (256 bits); a C frag is numLanes fp32 words.
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
for (i <- 0 until 32) {
tmem(aBase / fragBytes + i) = aFrag
tmem(cBase / fragBytes + i) = cFrag
}
// SMEM B with fp16 2.0
val fp16Two = BigInt(0x4000)
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
// Every B address answers with bFrag: the map default already covers any
// address, and the loop additionally fills the 32 frags starting at bBase.
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
// B requests and the writeback are always accepted; the op is offered for one cycle.
c.io.reqB.ready.poke(true.B)
c.io.writeback.ready.poke(true.B)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(aBase.U)
c.io.initiate.bits.addressB.poke(bBase.U)
c.io.initiate.bits.addressC.poke(cBase.U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// B response to present in the next iteration: (source id, data). This models
// a memory that answers each accepted reqB one cycle later.
var pendingB = Option.empty[(BigInt, BigInt)]
var sawWriteback = false
// Bounded simulation loop: at most 50000 iterations, ending early once writeback.valid is seen.
for (_ <- 0 until 50000 if !sawWriteback) {
// Serve this cycle's TMEM read/write before sampling anything else.
stepTmem(c, tmem)
// Drive the B response captured in the previous iteration, if any.
pendingB.foreach { case (src, data) =>
c.io.respB.valid.poke(true.B)
c.io.respB.bits.source.poke(src.U)
c.io.respB.bits.data.poke(data.U)
}
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
if (c.io.writeback.valid.peek().litToBoolean) {
sawWriteback = true
} else {
// Sample the B request before the clock edge; reqB.ready is held high, so a
// valid request is accepted this cycle and answered in the next iteration.
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
val addr = c.io.reqB.bits.address.peek().litValue
val src = c.io.reqB.bits.source.peek().litValue
Some((src, bMem(addr)))
} else None
c.clock.step()
pendingB = nextB
}
}
assert(sawWriteback, "BWGMMA did not complete")
// Each checked C row must hold 64.0f in every lane.
// NOTE(review): only the first 8 of the 32 populated C rows are checked — confirm that matches the output tile size.
val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32)
for (i <- 0 until 8) {
val row = cBase / fragBytes + i
assert(tmem(row) == expectedC,
s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}")
}
// The operation must not overwrite its A operand (first 8 rows checked).
for (i <- 0 until 8) {
assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged")
}
}
}
// Round trip through the TMEM model: an op-2 (cp) copies a fragment returned
// on respA into TMEM, then an op-4 (ld) reads the same address back and must
// deliver the identical per-lane words on the writeback port.
// Runs on the default chiseltest backend (no Verilator annotation).
it should "cp then ld round-trip: data written via cp is readable via ld" in {
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
idleIO(c)
val tmem = makeTmem()
val tmemAddr = BigInt(0x100)
// Lane i carries 0xABCD0000 + i, so a lane mix-up is detectable.
val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32)
// Issue cp: global mem -> tmem
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.initiate.bits.addressB.poke("h10000000".U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// cpRead: reqA issued
// The request must be a read (rw low).
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(false.B)
c.clock.step()
// cpWrite: respA fires, tmemC written
// The write port is expected to respond in the same cycle the response is presented.
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke(cpData.U)
c.io.tmemC.wen.expect(true.B)
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
c.io.tmemC.wdata.expect(cpData.U)
// Record the write in the software model before the clock edge.
stepTmem(c, tmem)
c.clock.step()
c.io.respA.valid.poke(false.B)
// Now issue ld from same tmem address
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.rd.poke(2.U)
c.io.initiate.bits.addressA.poke(tmemAddr.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
// ldReq: ren asserted, serve from tmem model
c.io.tmemC.ren.expect(true.B)
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// Keep rdata driven with the same row for one more cycle.
// NOTE(review): presumably the DUT samples rdata a cycle after ren — confirm.
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
c.clock.step()
// writeback should carry cpData
c.io.writeback.valid.expect(true.B)
for (i <- 0 until numLanes) {
c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U)
}
}
}
// Round trip through the TMEM model: an op-5 (st) stores the fragment present
// on respC into TMEM, then an op-6 (cb) reads the same TMEM address and must
// emit it as a write request on reqA to the given global address.
it should "st then cb round-trip: data written via st is readable via cb" in {
  test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
    idleIO(c)
    val tmem = makeTmem()
    val tmemAddr = BigInt(0x140)
    // Lane i carries 0xDEAD0000 + i, so a lane mix-up is detectable.
    val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32)

    // Issue st: respC -> tmem
    c.io.initiate.valid.poke(true.B)
    c.io.initiate.bits.op.poke(5.U)
    c.io.initiate.bits.rd.poke(4.U)
    c.io.initiate.bits.addressA.poke(tmemAddr.U)
    c.io.respC.poke(stData.U)
    c.clock.step()
    c.io.initiate.valid.poke(false.B)

    // stReq: reqC valid
    c.io.reqC.valid.expect(true.B)
    c.clock.step()

    // stWrite: tmemC written
    c.io.tmemC.wen.expect(true.B)
    // Check the row address here, as the cp and multi-warp tests do. Without it a
    // wrong store address would only surface later as a data mismatch on reqA,
    // because the model would return its default 0 for the row cb reads.
    c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
    c.io.tmemC.wdata.expect(stData.U)
    // Record the write in the software model before the clock edge.
    stepTmem(c, tmem)
    c.clock.step()

    // Issue cb: tmem -> global mem
    c.io.initiate.valid.poke(true.B)
    c.io.initiate.bits.op.poke(6.U)
    c.io.initiate.bits.addressA.poke(tmemAddr.U)
    c.io.initiate.bits.addressB.poke("h20000000".U)
    c.io.reqA.ready.poke(true.B)
    c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
    c.clock.step()
    c.io.initiate.valid.poke(false.B)

    // cbRead: ren asserted
    c.io.tmemC.ren.expect(true.B)
    c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
    c.clock.step()

    // cbWrite: reqA write with stData
    c.io.reqA.valid.expect(true.B)
    c.io.reqA.bits.rw.expect(true.B)
    c.io.reqA.bits.address.expect("h20000000".U)
    c.io.reqA.bits.data.expect(stData.U)
  }
}
// The two wait opcodes must be accepted immediately and leave the core idle:
// no writeback and no memory request in the cycle after they are issued.
it should "wait ops are no-ops and do not stall pipeline" in {
  test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { dut =>
    idleIO(dut)

    // Offers `opcode` for one cycle, requiring it to be accepted at once, then
    // checks that neither a writeback nor an A-side request appears afterwards.
    def issueAndExpectIdle(opcode: Int): Unit = {
      dut.io.initiate.valid.poke(true.B)
      dut.io.initiate.bits.op.poke(opcode.U)
      dut.io.initiate.ready.expect(true.B)
      dut.clock.step()
      dut.io.initiate.valid.poke(false.B)
      dut.io.writeback.valid.expect(false.B)
      dut.io.reqA.valid.expect(false.B)
    }

    // bwgmmaWait: should accept immediately and stay idle (B side is checked too)
    issueAndExpectIdle(1)
    dut.io.reqB.valid.expect(false.B)

    // tcgen05CpWait: same
    issueAndExpectIdle(3)
  }
}
// Back-pressure check: while an op-5 store is in flight, a second op (op 4)
// that is kept valid on the initiate port must see ready held low, and ready
// must return once the store's TMEM write cycle has passed.
it should "not accept a second tensor op until the first one completes" in {
  test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { dut =>
    idleIO(dut)
    val request = dut.io.initiate
    val storeAddr = BigInt(0x180)
    val loadAddr = BigInt(0x1a0)
    val payload = packWords(Seq.tabulate(numLanes)(lane => BigInt(0xCAFE0000L + lane)), 32)

    // First op: the store is offered and must be accepted straight away.
    request.valid.poke(true.B)
    request.bits.op.poke(5.U)
    request.bits.addressA.poke(storeAddr.U)
    dut.io.respC.poke(payload.U)
    request.ready.expect(true.B)
    dut.clock.step()

    // Second op: valid stays asserted with new fields; the core must refuse it.
    request.bits.op.poke(4.U)
    request.bits.addressA.poke(loadAddr.U)
    request.bits.rd.poke(2.U)
    request.ready.expect(false.B)
    dut.clock.step()

    // Still busy during the cycle in which the store writes TMEM.
    request.ready.expect(false.B)
    dut.io.tmemC.wen.expect(true.B)
    dut.clock.step()

    // Store finished: the core accepts ops again.
    request.ready.expect(true.B)
  }
}
// Stores one fragment for warp 0 and one for warp 3 through op 5 (st) and
// checks that they land in distinct TMEM rows, each holding its own data.
it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in {
  test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { dut =>
    idleIO(dut)
    val tmem = makeTmem()

    // warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0
    val addrWarp0 = BigInt(0x000)
    val dataWarp0 = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32)
    // warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800
    val addrWarp3 = BigInt(0x1800)
    val dataWarp3 = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32)

    // Issues one st for `wid`, checks the request and write cycles, and records
    // the resulting TMEM write in the software model.
    def storeFragment(wid: Int, byteAddr: BigInt, payload: BigInt): Unit = {
      dut.io.initiate.valid.poke(true.B)
      dut.io.initiate.bits.op.poke(5.U)
      dut.io.initiate.bits.wid.poke(wid.U)
      dut.io.initiate.bits.addressA.poke(byteAddr.U)
      dut.io.respC.poke(payload.U)
      dut.clock.step()
      dut.io.initiate.valid.poke(false.B)
      dut.io.reqC.valid.expect(true.B)
      dut.clock.step()
      dut.io.tmemC.wen.expect(true.B)
      dut.io.tmemC.waddr.expect((byteAddr / fragBytes).U)
      stepTmem(dut, tmem)
      dut.clock.step()
    }

    // Write warp 0 data via st
    storeFragment(0, addrWarp0, dataWarp0)
    // Write warp 3 data via st
    storeFragment(3, addrWarp3, dataWarp3)

    // Verify no aliasing: warp 0 row != warp 3 row
    assert(addrWarp0 / fragBytes != addrWarp3 / fragBytes)
    assert(tmem(addrWarp0 / fragBytes) == dataWarp0)
    assert(tmem(addrWarp3 / fragBytes) == dataWarp3)
  }
}
}