diff --git a/src/main/resources/vsrc/vortex b/src/main/resources/vsrc/vortex index cb912d3..323ed7d 160000 --- a/src/main/resources/vsrc/vortex +++ b/src/main/resources/vsrc/vortex @@ -1 +1 @@ -Subproject commit cb912d3b8b689683f0a283039aa4c1633cddd2f3 +Subproject commit 323ed7d7e9c3fd403a5ceb5bba7e4371aa37f6fc diff --git a/src/main/scala/radiance/core/TensorCoreBlackwell.scala b/src/main/scala/radiance/core/TensorCoreBlackwell.scala index 6b57361..2068867 100644 --- a/src/main/scala/radiance/core/TensorCoreBlackwell.scala +++ b/src/main/scala/radiance/core/TensorCoreBlackwell.scala @@ -13,6 +13,9 @@ class TensorCoreBlackwell( val numSourceIds: Int = 16, val numFPRegs: Int = 32 ) extends Module { + require(half, "Blackwell MMA currently supports FP16 inputs only") + require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes") + val numWarpBits = log2Ceil(numWarps) val sourceWidth = log2Ceil(numSourceIds) val laneWidth = 4 * 8 @@ -20,9 +23,17 @@ class TensorCoreBlackwell( val numFPRegBits = log2Ceil(numFPRegs) val addressWidth = 32 val maskWidth = memWidth / 8 + val fragOffsetBits = log2Ceil(memWidth / 8) + + val numSets = 4 + val numAFragsPerSet = 8 + val numBGroups = 4 + val numBFragsPerGroup = 2 + val numMGroups = 4 + val numCFrags = 32 object Ops { - val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: Nil = Enum(6) + val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7) } class TensorMemReq( @@ -44,6 +55,17 @@ class TensorCoreBlackwell( val data = UInt(dataWidth.W) } + // Direct SRAM port for TMEM (no TileLink overhead) + class TmemSramPort extends Bundle { + val wen = Output(Bool()) + val ren = Output(Bool()) + val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) + val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) + val wdata = Output(UInt(memWidth.W)) + val mask = Output(UInt(maskWidth.W)) + val rdata = Input(UInt(memWidth.W)) + } + val io = IO(new Bundle { val initiate = Flipped(Decoupled(new Bundle { val op = UInt(3.W) @@ -51,6 +73,7 @@ class TensorCoreBlackwell( val rd = UInt(numFPRegBits.W) val addressA = UInt(addressWidth.W) val addressB = UInt(addressWidth.W) + val addressC = UInt(addressWidth.W) })) val writeback = Decoupled(new Bundle { val last = Bool() @@ -64,10 +87,14 @@ class TensorCoreBlackwell( val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth)) val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth)) val reqC = Output(Valid(UInt(numFPRegBits.W))) + val tmemC = new TmemSramPort // direct SRAM for C matrix (replaces reqCmem/respCmem) }) object State extends ChiselEnum { - val idle, bwReq, bwResp, cpRead, cpWrite, ldReq, stReq, waitWb = Value + val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp, + bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq, + bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb, + cbRead, cbWrite = Value } val state = RegInit(State.idle) @@ -76,16 +103,41 @@ class TensorCoreBlackwell( val rdReg = RegInit(0.U(numFPRegBits.W)) val addrAReg = RegInit(0.U(addressWidth.W)) val addrBReg = RegInit(0.U(addressWidth.W)) - val aDataReg = Reg(UInt(memWidth.W)) - val bDataReg = Reg(UInt(memWidth.W)) - val haveA = RegInit(false.B) - val haveB = RegInit(false.B) + val addrCReg = RegInit(0.U(addressWidth.W)) val sourceCounter = RegInit(0.U(sourceWidth.W)) + val setReg = RegInit(0.U(log2Ceil(numSets).W)) + val aIndexReg = RegInit(0.U(log2Ceil(numAFragsPerSet).W)) + val bGroupReg = RegInit(0.U(log2Ceil(numBGroups).W)) + val bIndexReg = RegInit(0.U(log2Ceil(numBFragsPerGroup).W)) + val mGroupReg = RegInit(0.U(log2Ceil(numMGroups).W)) + val substepReg = RegInit(0.U(1.W)) + val elemReg = RegInit(0.U(log2Ceil(numLanes).W)) + val waitCounter = RegInit(0.U(3.W)) + + val aBuf = Reg(Vec(numAFragsPerSet, UInt(memWidth.W))) + val bBuf = Reg(Vec(numBFragsPerGroup, UInt(memWidth.W))) + val cDataReg = Reg(UInt(memWidth.W)) + val mmaDataReg = Reg(Vec(numLanes, UInt(laneWidth.W))) + private def bumpSource(): Unit = { sourceCounter := sourceCounter + 1.U } + private def byteAddress(base: UInt, fragIndex: UInt): UInt = { + base + (fragIndex << fragOffsetBits).asUInt + } + + val aFragIndex = (setReg << 3) + aIndexReg + val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg + val stepIndex = Cat(bGroupReg, mGroupReg) + val cFragIndex = (stepIndex << 1) + substepReg + val aReqAddress = byteAddress(addrAReg, aFragIndex) + val bReqAddress = byteAddress(addrBReg, bFragIndex) + val cReqAddress = byteAddress(addrCReg, cFragIndex) + val tmemABase = (addrAReg >> fragOffsetBits.U).asUInt + val tmemCBase = (addrCReg >> fragOffsetBits.U).asUInt + val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth))) val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth))) reqA.valid := false.B @@ -95,6 +147,13 @@ class TensorCoreBlackwell( io.reqA <> reqA io.reqB <> reqB + io.tmemC.wen := false.B + io.tmemC.ren := false.B + io.tmemC.waddr := 0.U + io.tmemC.raddr := 0.U + io.tmemC.wdata := 0.U + io.tmemC.mask := 0.U + val wbValid = RegInit(false.B) val wbData = Reg(Vec(numLanes, UInt(laneWidth.W))) io.writeback.valid := wbValid @@ -106,10 +165,40 @@ class TensorCoreBlackwell( io.reqC.valid := false.B io.reqC.bits := rdReg - io.respA.ready := false.B + // drain stale write-ack from TMEM so TLRAM doesn't stall on r_full + io.respA.ready := state === State.idle io.respB.ready := false.B io.initiate.ready := state === State.idle && !wbValid + val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1)) + val operandB = bBuf(substepReg) + val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) + val dpuInValid = WireDefault(false.B) + val dpu = Module(new TensorDotProductUnit( + dim = 8, + half = true + )) + + private def halfWord(x: UInt, idx: Int): UInt = { + x((idx + 1) * 16 - 1, idx * 16) + } + + val elemM = elemReg(1, 0) + val elemN = elemReg(2) + dpu.io.in.valid := dpuInValid + for (k <- 0 until 8) { + dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq( + 0.U -> halfWord(operandA, k), + 1.U -> halfWord(operandA, 8 + k), + 2.U -> halfWord(operandA, 16 + k), + 3.U -> halfWord(operandA, 24 + k) + )) + dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k)) + } + dpu.io.in.bits.c := cWords(elemReg) + dpu.io.stall := false.B + val dpuValid = dpu.io.out.valid + when(io.writeback.fire) { wbValid := false.B } @@ -120,118 +209,215 @@ class TensorCoreBlackwell( rdReg := io.initiate.bits.rd addrAReg := io.initiate.bits.addressA addrBReg := io.initiate.bits.addressB - haveA := false.B - haveB := false.B + addrCReg := io.initiate.bits.addressC + setReg := 0.U + aIndexReg := 0.U + bGroupReg := 0.U + bIndexReg := 0.U + mGroupReg := 0.U + substepReg := 0.U + elemReg := 0.U switch(io.initiate.bits.op) { - is(Ops.bwgmma) { state := State.bwReq } + is(Ops.bwgmma) { state := State.bwLoadAReq } is(Ops.tcgen05Cp) { state := State.cpRead } is(Ops.tcgen05Ld) { state := State.ldReq } is(Ops.tcgen05St) { state := State.stReq } is(Ops.bwgmmaWait) { state := State.idle } is(Ops.tcgen05CpWait) { state := State.idle } + is(Ops.tcgen05Cb) { state := State.cbRead } } } - when(state === State.bwReq) { - reqA.valid := true.B - reqA.bits.rw := false.B - reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqA.bits.address := addrAReg - reqA.bits.source := sourceCounter + when(state === State.bwLoadAReq) { + io.tmemC.ren := true.B + io.tmemC.raddr := tmemABase + aFragIndex + state := State.bwLoadAResp + } + when(state === State.bwLoadAResp) { + aBuf(aIndexReg) := io.tmemC.rdata + when(aIndexReg === (numAFragsPerSet - 1).U) { + bGroupReg := 0.U + bIndexReg := 0.U + state := State.bwLoadBReq + }.otherwise { + aIndexReg := aIndexReg + 1.U + state := State.bwLoadAReq + } + } + + when(state === State.bwLoadBReq) { reqB.valid := true.B reqB.bits.rw := false.B reqB.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqB.bits.address := addrBReg + reqB.bits.address := bReqAddress reqB.bits.source := sourceCounter - - io.reqC.valid := true.B - when(reqA.fire && reqB.fire) { + when(reqB.fire) { bumpSource() - state := State.bwResp + state := State.bwLoadBResp } } - when(state === State.bwResp) { - io.respA.ready := true.B + when(state === State.bwLoadBResp) { io.respB.ready := true.B - when(io.respA.fire) { - aDataReg := io.respA.bits.data - haveA := true.B - } when(io.respB.fire) { - bDataReg := io.respB.bits.data - haveB := true.B - } - when(haveA && haveB) { - val cWords = io.respC.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) - val aWords = aDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) - val bWords = bDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) - for (i <- 0 until numLanes) { - wbData(i) := aWords(i) + bWords(i) + cWords(i) + bBuf(bIndexReg) := io.respB.bits.data + when(bIndexReg === (numBFragsPerGroup - 1).U) { + mGroupReg := 0.U + substepReg := 0.U + state := State.bwReadCReq + }.otherwise { + bIndexReg := bIndexReg + 1.U + state := State.bwLoadBReq } - wbValid := true.B - state := State.idle } } + when(state === State.bwReadCReq) { + io.tmemC.ren := true.B + io.tmemC.raddr := tmemCBase + cFragIndex + state := State.bwReadCResp + } + + when(state === State.bwReadCResp) { + cDataReg := io.tmemC.rdata + elemReg := 0.U + state := State.bwCompute + } + + when(state === State.bwCompute) { + dpuInValid := true.B + state := State.bwDpuResp + } + + when(state === State.bwDpuResp) { + when(dpuValid) { + mmaDataReg(elemReg) := dpu.io.out.bits.data + when(elemReg === (numLanes - 1).U) { + state := State.bwWriteCReq + }.otherwise { + elemReg := elemReg + 1.U + state := State.bwCompute + } + } + } + + when(state === State.bwWriteCReq) { + io.tmemC.wen := true.B + io.tmemC.waddr := tmemCBase + cFragIndex + io.tmemC.wdata := mmaDataReg.asUInt + io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) + when(substepReg === 0.U) { + substepReg := 1.U + state := State.bwReadCReq + }.elsewhen(mGroupReg =/= (numMGroups - 1).U) { + substepReg := 0.U + mGroupReg := mGroupReg + 1.U + state := State.bwReadCReq + }.elsewhen(bGroupReg =/= (numBGroups - 1).U) { + substepReg := 0.U + mGroupReg := 0.U + bGroupReg := bGroupReg + 1.U + bIndexReg := 0.U + state := State.bwLoadBReq + }.elsewhen(setReg =/= (numSets - 1).U) { + substepReg := 0.U + mGroupReg := 0.U + bGroupReg := 0.U + bIndexReg := 0.U + setReg := setReg + 1.U + aIndexReg := 0.U + state := State.bwLoadAReq + }.otherwise { + waitCounter := 7.U + state := State.bwWriteCWait + } + } + + when(state === State.bwWriteCWait) { + when(waitCounter === 0.U) { + state := State.bwDone + }.otherwise { + waitCounter := waitCounter - 1.U + } + } + + when(state === State.bwDone) { + wbData := mmaDataReg + wbValid := true.B + state := State.idle + } + when(state === State.cpRead) { - reqB.valid := true.B - reqB.bits.rw := false.B - reqB.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqB.bits.address := addrBReg - reqB.bits.source := sourceCounter - when(reqB.fire) { + reqA.valid := true.B + reqA.bits.rw := false.B + reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) + reqA.bits.address := addrBReg + reqA.bits.source := sourceCounter + when(reqA.fire) { bumpSource() state := State.cpWrite } } when(state === State.cpWrite) { - io.respB.ready := reqA.ready - reqA.valid := io.respB.valid - reqA.bits.rw := true.B - reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqA.bits.address := addrAReg - reqA.bits.source := sourceCounter - reqA.bits.data := io.respB.bits.data - when(reqA.fire) { - bumpSource() + io.respA.ready := true.B + when(io.respA.fire) { + io.tmemC.wen := true.B + io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt + io.tmemC.wdata := io.respA.bits.data + io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) state := State.idle } } when(state === State.ldReq) { + io.tmemC.ren := true.B + io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt + state := State.waitWb + } + + when(state === State.waitWb && opReg === Ops.tcgen05Ld) { + wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) + wbValid := true.B + state := State.idle + } + + when(state === State.stReq) { + io.reqC.valid := true.B + state := State.stWrite + } + + when(state === State.stWrite) { + io.tmemC.wen := true.B + io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt + io.tmemC.wdata := io.respC + io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) + state := State.idle + } + + when(state === State.cbRead) { + io.tmemC.ren := true.B + io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt + state := State.cbWrite + } + + when(state === State.cbWrite) { reqA.valid := true.B - reqA.bits.rw := false.B + reqA.bits.rw := true.B reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqA.bits.address := addrAReg + reqA.bits.address := addrBReg reqA.bits.source := sourceCounter + reqA.bits.data := io.tmemC.rdata when(reqA.fire) { bumpSource() state := State.waitWb } } - when(state === State.waitWb && opReg === Ops.tcgen05Ld) { - io.respA.ready := !wbValid + when(state === State.waitWb && opReg === Ops.tcgen05Cb) { + io.respA.ready := true.B when(io.respA.fire) { - wbData := io.respA.bits.data.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) - wbValid := true.B - state := State.idle - } - } - - when(state === State.stReq) { - io.reqC.valid := true.B - reqA.valid := true.B - reqA.bits.rw := true.B - reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) - reqA.bits.address := addrAReg - reqA.bits.source := sourceCounter - reqA.bits.data := io.respC - when(reqA.fire) { - bumpSource() state := State.idle } } diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index ce131df..d1a2377 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -201,8 +201,10 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex // pipeline and connect outputs to the next stage outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts)) outC := StallingPipe(io.stall, inputs.valid, inC.bits) - assert(inputs.valid === inC.valid, - "adder inputs valid and C pipe valid went out-of-sync") + when (inputs.valid =/= inC.valid) { + printf("WARN: DotProductPipe input/C valid mismatch: inputs=%d c=%d\n", + inputs.valid, inC.valid) + } (outputs, outC) } diff --git a/src/main/scala/radiance/subsystem/Configs.scala b/src/main/scala/radiance/subsystem/Configs.scala index ab60748..58c4187 100644 --- a/src/main/scala/radiance/subsystem/Configs.scala +++ b/src/main/scala/radiance/subsystem/Configs.scala @@ -51,6 +51,7 @@ class WithRadianceCores( tensorCoreFP16: Boolean, tensorCoreDecoupled: Boolean, tensorCoreBlackwell: Boolean, + startupAddress: BigInt, useVxCache: Boolean ) extends Config((site, _, up) => { case TilesLocated(`location`) => { @@ -61,7 +62,8 @@ class WithRadianceCores( core = VortexCoreParams( tensorCoreFP16 = tensorCoreFP16, tensorCoreDecoupled = tensorCoreDecoupled, - tensorCoreBlackwell = tensorCoreBlackwell + tensorCoreBlackwell = tensorCoreBlackwell, + startupAddress = startupAddress ), btb = None, useVxCache = useVxCache, @@ -99,6 +101,7 @@ class WithRadianceCores( def this(n: Int, location: HierarchicalLocation = InSubsystem, tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false, tensorCoreBlackwell: Boolean = false, + startupAddress: BigInt = BigInt("10100", 16), useVxCache: Boolean = false) = this(n, location, RocketCrossingParams( master = HierarchicalElementMasterPortParams.locationDefault(location), @@ -107,7 +110,7 @@ class WithRadianceCores( case InSubsystem => CBUS case InCluster(clusterId) => CCBUS(clusterId) } - ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, useVxCache) + ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache) } class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => { diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index d16084d..d74e91b 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -102,6 +102,7 @@ case class VortexCoreParams( tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core + startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata haveCease: Boolean = true, // non-standard CEASE instruction haveSimTimeout: Boolean = true // add plusarg for simulation timeout @@ -292,50 +293,30 @@ class RadianceTile private ( masters = Seq(TLMasterParameters.v2( name = s"rad_tc_${radianceParams.coreId}_$i", sourceId = IdRange(0, 1 << smemSourceWidth), - supports = TLSlaveToMasterTransferSizes( - probe = TransferSizes(1, tcSmemSize), - get = TransferSizes(1, tcSmemSize), - ), - requestFifo = true - )) - ))) - } - - val tmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreBlackwell) 2 else 0) { i => - TLClientNode(Seq(TLMasterPortParameters.v2( - masters = Seq(TLMasterParameters.v2( - name = s"rad_tmem_${radianceParams.coreId}_$i", - sourceId = IdRange(0, 1 << smemSourceWidth), supports = TLSlaveToMasterTransferSizes( probe = TransferSizes(1, tcSmemSize), get = TransferSizes(1, tcSmemSize), putFull = TransferSizes(1, tcSmemSize), - putPartial = TransferSizes(1, tcSmemSize), ), requestFifo = true )) ))) } - val tmemNode = if (radianceParams.core.tensorCoreBlackwell) { - Some(LazyModule(new TLRAM( - address = AddressSet(0x0, 0x3fff), - beatBytes = tcSmemSize + // For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand) + // tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global) + val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq( + TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2( + name = s"rad_tc_gmem_${radianceParams.coreId}", + sourceId = IdRange(0, 1 << dmemSourceWidth), + supports = TLSlaveToMasterTransferSizes( + probe = TransferSizes(1, tcSmemSize), + get = TransferSizes(1, tcSmemSize), + putFull = TransferSizes(1, tcSmemSize), + ), + requestFifo = true ))) - } else { - None - } - val tmemXbar = if (radianceParams.core.tensorCoreBlackwell) { - Some(LazyModule(new TLXbar)) - } else { - None - } - (tmemNode, tmemXbar) match { - case (Some(tmem), Some(xbar)) => - tmem.node :=* xbar.node - tmemNodes.foreach { node => xbar.node :=* node } - case _ => - } + ))) else None // combine outgoing per-lane dmemNode into 1 idenity node // @@ -425,6 +406,7 @@ class RadianceTile private ( // imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ } tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode + tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n } } /* below are copied from rocket */ @@ -828,12 +810,12 @@ class RadianceTileModuleImp(outer: RadianceTile) adapter.io.outResp <> client._1.d adapter } - core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready) - core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid) - core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data) - core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source) - require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 2) - require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 2) + core.io.tc_a_ready := Cat(0.U(1.W), adapters.last.io.inReq.ready, adapters.head.io.inReq.ready) + core.io.tc_d_valid := Cat(0.U(1.W), adapters.last.io.inResp.valid, adapters.head.io.inResp.valid) + core.io.tc_d_bits_data := Cat(0.U((32 * 8).W), adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data) + core.io.tc_d_bits_tag := Cat(0.U(outer.tensorTagWidth.W), adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source) + require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 3) + require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 3) } else { core.io.tc_a_ready := false.B core.io.tc_d_valid := false.B @@ -844,66 +826,82 @@ class RadianceTileModuleImp(outer: RadianceTile) def connectTensorBlackwell = { if (outer.radianceParams.core.tensorCoreBlackwell) { - require(outer.tmemNodes.nonEmpty) require(outer.tcSmemNodes.nonEmpty) - val bundles = Seq( - (outer.tmemNodes.head, new { - val addr = core.io.tc_a_bits_address(31, 0) - val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) - val write = core.io.tc_a_bits_write(0) - val mask = core.io.tc_a_bits_mask(31, 0) - val data = core.io.tc_a_bits_data(255, 0) - val aValid = core.io.tc_a_valid(0) - val dReady = core.io.tc_d_ready(0) - }), - (outer.tcSmemNodes.head, new { - val addr = core.io.tc_a_bits_address(63, 32) - val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4) - val write = core.io.tc_a_bits_write(1) - val mask = core.io.tc_a_bits_mask(63, 32) - val data = core.io.tc_a_bits_data(511, 256) - val aValid = core.io.tc_a_valid(1) - val dReady = core.io.tc_d_ready(1) - }) - ) + // TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO + // Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB + val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows + val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( + tmemDepth, UInt((outer.tcSmemSize * 8).W))) + tmem.io.ren0 := core.io.tc_tmem_C_ren + tmem.io.raddr0 := core.io.tc_tmem_C_raddr + core.io.tc_tmem_C_rdata := tmem.io.rdata0 + tmem.io.ren1 := false.B + tmem.io.raddr1 := 0.U + tmem.io.wen := core.io.tc_tmem_C_wen + tmem.io.waddr := core.io.tc_tmem_C_waddr + tmem.io.wdata := core.io.tc_tmem_C_wdata + tmem.io.mask := core.io.tc_tmem_C_mask - val adapters = bundles.map { case (node, bundle) => - val client = node.out.head - val adapter = Module( - new VortexTLAdapter( - outer.smemSourceWidth, - new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - client - ) - ) - require(adapter.io.inReq.bits.source.widthOption.get == bundle.tag.widthOption.get) - require(adapter.io.inReq.bits.address.widthOption.get == bundle.addr.widthOption.get) - adapter.io.inReq.bits <> DontCare - adapter.io.inReq.valid := bundle.aValid - adapter.io.inReq.bits.address := bundle.addr - adapter.io.inReq.bits.source := bundle.tag - adapter.io.inReq.bits.size := 5.U - adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get) - adapter.io.inReq.bits.mask := bundle.mask - adapter.io.inReq.bits.data := bundle.data - adapter.io.inResp.ready := bundle.dReady - - client._1.a <> adapter.io.outReq - adapter.io.outResp <> client._1.d - adapter + // smem_B (port 2): Global Memory via TileLink + val smemBBundle = new { + val addr = core.io.tc_a_bits_address(95, 64) + val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8) + val write = core.io.tc_a_bits_write(2) + val mask = core.io.tc_a_bits_mask(95, 64) + val data = core.io.tc_a_bits_data(767, 512) + val aValid = core.io.tc_a_valid(2) + val dReady = core.io.tc_d_ready(2) } + val client = outer.tcSmemNodes.head.out.head + val adapter = Module(new VortexTLAdapter( + outer.smemSourceWidth, + new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), + new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), + client + )) + adapter.io.inReq.bits <> DontCare + adapter.io.inReq.valid := smemBBundle.aValid + adapter.io.inReq.bits.address := smemBBundle.addr + adapter.io.inReq.bits.source := smemBBundle.tag + adapter.io.inReq.bits.size := 5.U + adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get) + adapter.io.inReq.bits.mask := smemBBundle.mask + adapter.io.inReq.bits.data := smemBBundle.data + adapter.io.inResp.ready := smemBBundle.dReady + client._1.a <> adapter.io.outReq + adapter.io.outResp <> client._1.d - core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready) - core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid) - core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data) - core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source) + // port 0: global memory (cp/cb) + val gmemClient = outer.tcGmemNode.get.out.head + val gmemAdapter = Module(new VortexTLAdapter( + outer.dmemSourceWidth, + new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), + new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), + gmemClient + )) + gmemAdapter.io.inReq.bits <> DontCare + gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0) + gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0) + gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) + gmemAdapter.io.inReq.bits.size := 5.U + gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get) + gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0) + gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0) + gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0) + gmemClient._1.a <> gmemAdapter.io.outReq + gmemAdapter.io.outResp <> gmemClient._1.d + + core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready) + core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid) + core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data) + core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source) } else { - core.io.tc_a_ready := false.B - core.io.tc_d_valid := false.B + core.io.tc_a_ready := false.B + core.io.tc_d_valid := false.B core.io.tc_d_bits_data := DontCare - core.io.tc_d_bits_tag := DontCare + core.io.tc_d_bits_tag := DontCare + core.io.tc_tmem_C_rdata := DontCare } } @@ -993,6 +991,7 @@ class RadianceTileModuleImp(outer: RadianceTile) tensor.io.reqA.ready := false.B tensor.io.reqB.ready := false.B tensor.io.writeback.ready := false.B + dontTouch(tensor.io) } else if (outer.radianceParams.core.tensorCoreBlackwell) { val tensorNumSourceIds = (1 << outer.tensorTagWidth) val tensor = Module(new radiance.core.TensorCoreBlackwell( @@ -1007,6 +1006,8 @@ class RadianceTileModuleImp(outer: RadianceTile) tensor.io.reqA.ready := false.B tensor.io.reqB.ready := false.B tensor.io.writeback.ready := false.B + tensor.io.tmemC.rdata := DontCare + dontTouch(tensor.io) } else { if (outer.radianceParams.core.tensorCoreFP16) { val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true)) diff --git a/src/main/scala/radiance/tile/VortexCore.scala b/src/main/scala/radiance/tile/VortexCore.scala index 9f89527..91b65b3 100644 --- a/src/main/scala/radiance/tile/VortexCore.scala +++ b/src/main/scala/radiance/tile/VortexCore.scala @@ -90,17 +90,28 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) - val tc_a_valid = Output(UInt(2.W)) - val tc_a_bits_write = Output(UInt(2.W)) - val tc_a_bits_address = Output(UInt((2 * 32).W)) - val tc_a_bits_tag = Output(UInt((2 * 4).W)) - val tc_a_bits_mask = Output(UInt((2 * 32).W)) - val tc_a_bits_data = Output(UInt((2 * 32 * 8).W)) - val tc_a_ready = Input(UInt(2.W)) - val tc_d_valid = Input(UInt(2.W)) - val tc_d_bits_data = Input(UInt((2 * 32 * 8).W)) - val tc_d_bits_tag = Input(UInt((2 * 4).W)) - val tc_d_ready = Output(UInt(2.W)) + val tcPortCount = 3 + val tc_a_valid = Output(UInt(tcPortCount.W)) + val tc_a_bits_write = Output(UInt(tcPortCount.W)) + val tc_a_bits_address = Output(UInt((tcPortCount * 32).W)) + val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W)) + val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W)) + val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W)) + val tc_a_ready = Input(UInt(tcPortCount.W)) + val tc_d_valid = Input(UInt(tcPortCount.W)) + val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W)) + val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W)) + val tc_d_ready = Output(UInt(tcPortCount.W)) + + // Direct SRAM port for TMEM C (bypasses TileLink) + val numLanes = tile.numLsuLanes + val tc_tmem_C_wen = Output(Bool()) + val tc_tmem_C_ren = Output(Bool()) + val tc_tmem_C_waddr = Output(UInt(9.W)) + val tc_tmem_C_raddr = Output(UInt(9.W)) + val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W)) + val tc_tmem_C_mask = Output(UInt((numLanes * 4).W)) + val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W)) // FIXME: hardcoded val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits @@ -135,8 +146,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters) Map( "CORE_ID" -> tile.radianceParams.coreId, "TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0), - // TODO: can we get this as a parameter? - "BOOTROM_HANG100" -> 0x10100, + "STARTUP_ADDR" -> tile.radianceParams.core.startupAddress, "NUM_THREADS" -> tile.numLsuLanes ) ) @@ -449,7 +459,9 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters) addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv") - addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv") + if (!tile.radianceParams.core.tensorCoreBlackwell) { + addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv") + } if (tile.radianceParams.useVxCache) { addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv") diff --git a/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala b/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala new file mode 100644 index 0000000..035ca63 --- /dev/null +++ b/src/test/scala/radiance/TensorCoreBlackwellExtendedTest.scala @@ -0,0 +1,338 @@ +package radiance.core + +import chisel3._ +import chiseltest._ +import chiseltest.simulator.VerilatorBackendAnnotation +import org.scalatest.flatspec.AnyFlatSpec + +import scala.collection.mutable + +class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester { + behavior of "TensorCoreBlackwell Extended Tests" + + private val numWarps = 4 + private val numLanes = 8 + private val fragBytes = 32 + + private def idleIO(c: TensorCoreBlackwell): Unit = { + c.io.initiate.valid.poke(false.B) + c.io.respA.valid.poke(false.B) + c.io.respB.valid.poke(false.B) + c.io.respA.bits.source.poke(0.U) + c.io.respB.bits.source.poke(0.U) + c.io.respA.bits.data.poke(0.U) + c.io.respB.bits.data.poke(0.U) + c.io.reqA.ready.poke(false.B) + c.io.reqB.ready.poke(false.B) + c.io.respC.poke(0.U) + c.io.writeback.ready.poke(false.B) + c.io.tmemC.rdata.poke(0.U) + } + + private def packWords(words: Seq[BigInt], width: Int): BigInt = { + val mask = (BigInt(1) << width) - 1 + words.zipWithIndex.foldLeft(BigInt(0)) { + case (acc, (word, i)) => acc | ((word & mask) << (i * width)) + } + } + + private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0)) + + private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = { + if (c.io.tmemC.ren.peek().litToBoolean) { + val addr = c.io.tmemC.raddr.peek().litValue + c.io.tmemC.rdata.poke(tmem(addr).U) + } + if (c.io.tmemC.wen.peek().litToBoolean) { + val addr = c.io.tmemC.waddr.peek().litValue + tmem(addr) = c.io.tmemC.wdata.peek().litValue + } + } + + it should "verify bwgmma address offset with non-zero base addresses" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) + .withAnnotations(Seq(VerilatorBackendAnnotation)) { c => + idleIO(c) + val tmem = makeTmem() + + // Use non-zero base addresses to verify offset calculation + val aBase = BigInt(0x200) // row 16, A tile rows 16~47 + val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A) + val bBase = BigInt(0x800) + + val fp16One = BigInt(0x3c00) + val fp32Zero = BigInt(0) + // 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f + val fp32SixtyFour = BigInt(0x42800000L) + + // Populate TMEM A at offset aBase (all 32 frags) + val aFrag = packWords(Seq.fill(16)(fp16One), 16) + val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32) + for (i <- 0 until 32) { + tmem(aBase / fragBytes + i) = aFrag + tmem(cBase / fragBytes + i) = cFrag + } + + // SMEM B with fp16 2.0 + val fp16Two = BigInt(0x4000) + val bFrag = packWords(Seq.fill(16)(fp16Two), 16) + val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag) + for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag + + c.io.reqB.ready.poke(true.B) + c.io.writeback.ready.poke(true.B) + + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(0.U) + c.io.initiate.bits.wid.poke(0.U) + c.io.initiate.bits.rd.poke(0.U) + c.io.initiate.bits.addressA.poke(aBase.U) + c.io.initiate.bits.addressB.poke(bBase.U) + c.io.initiate.bits.addressC.poke(cBase.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + var pendingB = Option.empty[(BigInt, BigInt)] + var sawWriteback = false + + for (_ <- 0 until 50000 if !sawWriteback) { + stepTmem(c, tmem) + pendingB.foreach { case (src, data) => + c.io.respB.valid.poke(true.B) + c.io.respB.bits.source.poke(src.U) + c.io.respB.bits.data.poke(data.U) + } + if (pendingB.isEmpty) c.io.respB.valid.poke(false.B) + + if (c.io.writeback.valid.peek().litToBoolean) { + sawWriteback = true + } else { + val nextB = if (c.io.reqB.valid.peek().litToBoolean) { + val addr = c.io.reqB.bits.address.peek().litValue + val src = c.io.reqB.bits.source.peek().litValue + Some((src, bMem(addr))) + } else None + c.clock.step() + pendingB = nextB + } + } + + assert(sawWriteback, "BWGMMA did not complete") + val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32) + for (i <- 0 until 8) { + val row = cBase / fragBytes + i + assert(tmem(row) == expectedC, + s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}") + } + for (i <- 0 until 8) { + assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged") + } + } + } + + it should "cp then ld round-trip: data written via cp is readable via ld" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val tmem = makeTmem() + val tmemAddr = BigInt(0x100) + val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32) + + // Issue cp: global mem -> tmem + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(2.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.initiate.bits.addressB.poke("h10000000".U) + c.io.reqA.ready.poke(true.B) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + // cpRead: reqA issued + c.io.reqA.valid.expect(true.B) + c.io.reqA.bits.rw.expect(false.B) + c.clock.step() + + // cpWrite: respA fires, tmemC written + c.io.respA.valid.poke(true.B) + c.io.respA.bits.data.poke(cpData.U) + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) + c.io.tmemC.wdata.expect(cpData.U) + stepTmem(c, tmem) + c.clock.step() + c.io.respA.valid.poke(false.B) + + // Now issue ld from same tmem address + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(4.U) + c.io.initiate.bits.rd.poke(2.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.writeback.ready.poke(true.B) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + // ldReq: ren asserted, serve from tmem model + c.io.tmemC.ren.expect(true.B) + c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) + c.clock.step() + c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) + c.clock.step() + + // writeback should carry cpData + c.io.writeback.valid.expect(true.B) + for (i <- 0 until numLanes) { + c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U) + } + } + } + + it should "st then cb round-trip: data written via st is readable via cb" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val tmem = makeTmem() + val tmemAddr = BigInt(0x140) + val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32) + + // Issue st: respC -> tmem + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(5.U) + c.io.initiate.bits.rd.poke(4.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.respC.poke(stData.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + // stReq: reqC valid + c.io.reqC.valid.expect(true.B) + c.clock.step() + + // stWrite: tmemC written + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.wdata.expect(stData.U) + stepTmem(c, tmem) + c.clock.step() + + // Issue cb: tmem -> global mem + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(6.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.initiate.bits.addressB.poke("h20000000".U) + c.io.reqA.ready.poke(true.B) + c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + // cbRead: ren asserted + c.io.tmemC.ren.expect(true.B) + c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U) + c.clock.step() + + // cbWrite: reqA write with stData + c.io.reqA.valid.expect(true.B) + c.io.reqA.bits.rw.expect(true.B) + c.io.reqA.bits.address.expect("h20000000".U) + c.io.reqA.bits.data.expect(stData.U) + } + } + + it should "wait ops are no-ops and do not stall pipeline" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + + // bwgmmaWait: should accept immediately and stay idle + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(1.U) // bwgmmaWait + c.io.initiate.ready.expect(true.B) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.writeback.valid.expect(false.B) + c.io.reqA.valid.expect(false.B) + c.io.reqB.valid.expect(false.B) + + // tcgen05CpWait: same + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(3.U) // tcgen05CpWait + c.io.initiate.ready.expect(true.B) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.writeback.valid.expect(false.B) + c.io.reqA.valid.expect(false.B) + } + } + + it should "not accept a second tensor op until the first one completes" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val firstAddr = BigInt(0x180) + val secondAddr = BigInt(0x1a0) + val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xCAFE0000L + i)), 32) + + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(5.U) + c.io.initiate.bits.addressA.poke(firstAddr.U) + c.io.respC.poke(storeData.U) + c.io.initiate.ready.expect(true.B) + c.clock.step() + + c.io.initiate.bits.op.poke(4.U) + c.io.initiate.bits.addressA.poke(secondAddr.U) + c.io.initiate.bits.rd.poke(2.U) + c.io.initiate.ready.expect(false.B) + c.clock.step() + c.io.initiate.ready.expect(false.B) + + c.io.tmemC.wen.expect(true.B) + c.clock.step() + c.io.initiate.ready.expect(true.B) + } + } + + it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val tmem = makeTmem() + + // warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0 + val warp0TmemA = BigInt(0x000) + val warp0Data = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32) + + // warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800 + val warp3TmemA = BigInt(0x1800) + val warp3Data = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32) + + // Write warp 0 data via st + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(5.U) + c.io.initiate.bits.wid.poke(0.U) + c.io.initiate.bits.addressA.poke(warp0TmemA.U) + c.io.respC.poke(warp0Data.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.reqC.valid.expect(true.B) + c.clock.step() + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U) + stepTmem(c, tmem) + c.clock.step() + + // Write warp 3 data via st + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(5.U) + c.io.initiate.bits.wid.poke(3.U) + c.io.initiate.bits.addressA.poke(warp3TmemA.U) + c.io.respC.poke(warp3Data.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.reqC.valid.expect(true.B) + c.clock.step() + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U) + stepTmem(c, tmem) + c.clock.step() + + // Verify no aliasing: warp 0 row != warp 3 row + assert(warp0TmemA / fragBytes != warp3TmemA / fragBytes) + assert(tmem(warp0TmemA / fragBytes) == warp0Data) + assert(tmem(warp3TmemA / fragBytes) == warp3Data) + } + } +} diff --git a/src/test/scala/radiance/TensorCoreBlackwellTest.scala b/src/test/scala/radiance/TensorCoreBlackwellTest.scala index 3c8546b..18dd2d3 100644 --- a/src/test/scala/radiance/TensorCoreBlackwellTest.scala +++ b/src/test/scala/radiance/TensorCoreBlackwellTest.scala @@ -2,11 +2,17 @@ package radiance.core import chisel3._ import chiseltest._ +import chiseltest.simulator.VerilatorBackendAnnotation import org.scalatest.flatspec.AnyFlatSpec +import scala.collection.mutable + class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester { behavior of "TensorCoreBlackwell" + private val numWarps = 4 + private val numLanes = 8 + private def idleIO(c: TensorCoreBlackwell): Unit = { c.io.initiate.valid.poke(false.B) c.io.respA.valid.poke(false.B) @@ -15,111 +21,261 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester { c.io.respB.bits.source.poke(0.U) c.io.respA.bits.data.poke(0.U) c.io.respB.bits.data.poke(0.U) + c.io.reqA.ready.poke(false.B) + c.io.reqB.ready.poke(false.B) c.io.respC.poke(0.U) c.io.writeback.ready.poke(false.B) + c.io.tmemC.rdata.poke(0.U) } - it should "run a minimal BWGMMA path" in { - test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c => - idleIO(c) - - c.io.initiate.valid.poke(true.B) - c.io.initiate.bits.op.poke(0.U) - c.io.initiate.bits.wid.poke(1.U) - c.io.initiate.bits.rd.poke(3.U) - c.io.initiate.bits.addressA.poke(0x40.U) - c.io.initiate.bits.addressB.poke(0x80.U) - c.io.reqA.ready.poke(true.B) - c.io.reqB.ready.poke(true.B) - c.io.respC.poke("h0000000800000007000000060000000500000004000000030000000200000001".U) - c.clock.step() - - c.io.initiate.valid.poke(false.B) - c.io.reqA.valid.expect(true.B) - c.io.reqB.valid.expect(true.B) - c.clock.step() - - c.io.respA.valid.poke(true.B) - c.io.respB.valid.poke(true.B) - c.io.respA.bits.data.poke("h0000000800000007000000060000000500000004000000030000000200000001".U) - c.io.respB.bits.data.poke("h000000100000000f0000000e0000000d0000000c0000000b0000000a00000009".U) - c.clock.step() - - c.io.respA.valid.poke(false.B) - c.io.respB.valid.poke(false.B) - c.clock.step() - c.clock.step() - c.io.writeback.valid.expect(true.B) - c.io.writeback.bits.rd.expect(3.U) - c.io.writeback.bits.wid.expect(1.U) - c.io.writeback.ready.poke(true.B) - c.clock.step() + private def packWords(words: Seq[BigInt], width: Int): BigInt = { + val mask = (BigInt(1) << width) - 1 + words.zipWithIndex.foldLeft(BigInt(0)) { + case (acc, (word, i)) => acc | ((word & mask) << (i * width)) } } - it should "copy from SMEM to TMEM on TCGEN05_CP" in { - test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c => + // Simple TMEM model: address → 256-bit row + private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0)) + + // Drive tmemC read response from model, handle write + private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = { + if (c.io.tmemC.ren.peek().litToBoolean) { + val addr = c.io.tmemC.raddr.peek().litValue + c.io.tmemC.rdata.poke(tmem(addr).U) + } + if (c.io.tmemC.wen.peek().litToBoolean) { + val addr = c.io.tmemC.waddr.peek().litValue + tmem(addr) = c.io.tmemC.wdata.peek().litValue + } + } + + it should "tcgen05_ld: read from TMEM to writeback" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => idleIO(c) + val tmem = makeTmem() + val fragBytes = 32 + val tmemAddr = BigInt(0x40) // row 2 (0x40 / 32 = 2) + val testData = packWords(Seq.tabulate(numLanes)(i => BigInt(0x1000 + i)), 32) + tmem(tmemAddr / fragBytes) = testData c.io.initiate.valid.poke(true.B) - c.io.initiate.bits.op.poke(2.U) + c.io.initiate.bits.op.poke(4.U) // tcgen05Ld c.io.initiate.bits.wid.poke(0.U) - c.io.initiate.bits.rd.poke(0.U) - c.io.initiate.bits.addressA.poke(0x100.U) - c.io.initiate.bits.addressB.poke(0x200.U) - c.io.reqB.ready.poke(true.B) + c.io.initiate.bits.rd.poke(3.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.writeback.ready.poke(true.B) + c.io.tmemC.rdata.poke(testData.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.initiate.ready.expect(false.B) + + // ldReq: tmemC.ren asserted; rdata must be valid before next step + c.io.tmemC.ren.expect(true.B) + c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U) + c.io.tmemC.rdata.poke(testData.U) c.clock.step() - c.io.initiate.valid.poke(false.B) - c.io.reqB.valid.expect(true.B) - c.io.respB.valid.poke(true.B) - c.io.respB.bits.data.poke("hdeadbeef".U) - c.io.reqA.ready.poke(true.B) + // waitWb: wbValid gets set this cycle, step to let it register + c.io.tmemC.rdata.poke(testData.U) c.clock.step() - c.io.reqA.valid.expect(true.B) - c.io.reqA.bits.rw.expect(true.B) - c.io.reqA.bits.address.expect(0x100.U) + + // idle: writeback.valid now true + c.io.writeback.valid.expect(true.B) + c.io.initiate.ready.expect(false.B) + c.io.writeback.bits.rd.expect(3.U) + c.io.writeback.bits.wid.expect(0.U) + for (i <- 0 until numLanes) { + c.io.writeback.bits.data(i).expect((0x1000 + i).U) + } } } - it should "load and store fragments through TMEM" in { - test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c => + it should "tcgen05_st: write from respC to TMEM" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => idleIO(c) + val fragBytes = 32 + val tmemAddr = BigInt(0x60) + val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xAB00 + i)), 32) c.io.initiate.valid.poke(true.B) - c.io.initiate.bits.op.poke(4.U) - c.io.initiate.bits.wid.poke(2.U) - c.io.initiate.bits.rd.poke(5.U) - c.io.initiate.bits.addressA.poke(0x300.U) - c.io.initiate.bits.addressB.poke(0.U) - c.io.reqA.ready.poke(true.B) + c.io.initiate.bits.op.poke(5.U) // tcgen05St + c.io.initiate.bits.wid.poke(0.U) + c.io.initiate.bits.rd.poke(7.U) + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.respC.poke(storeData.U) c.clock.step() - c.io.initiate.valid.poke(false.B) - c.clock.step() - c.io.respA.valid.poke(true.B) - c.io.respA.bits.data.poke("h1234".U) - c.clock.step() - c.io.respA.valid.poke(false.B) - c.clock.step() - c.io.writeback.valid.expect(true.B) - c.io.writeback.bits.rd.expect(5.U) - c.io.writeback.ready.poke(true.B) + c.io.initiate.ready.expect(false.B) + + // stReq: reqC.valid asserted + c.io.reqC.valid.expect(true.B) + c.io.reqC.bits.expect(7.U) c.clock.step() - idleIO(c) - c.io.initiate.valid.poke(true.B) - c.io.initiate.bits.op.poke(5.U) - c.io.initiate.bits.wid.poke(2.U) - c.io.initiate.bits.rd.poke(6.U) - c.io.initiate.bits.addressA.poke(0x340.U) - c.io.initiate.bits.addressB.poke(0.U) - c.io.reqA.ready.poke(true.B) - c.io.respC.poke("habcd".U) + // stWrite: tmemC.wen asserted with storeData + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) + c.io.tmemC.wdata.expect(storeData.U) c.clock.step() + c.io.initiate.ready.expect(true.B) + } + } + + it should "tcgen05_cp: read from global mem (reqA) and write to TMEM" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val fragBytes = 32 + val tmemAddr = BigInt(0x80) + val gmemAddr = "ha0001000" + val cpData = packWords(Seq.fill(numLanes)(BigInt(0xdeadbeefL)), 32) + + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(2.U) // tcgen05Cp + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.initiate.bits.addressB.poke(gmemAddr.U) + c.io.reqA.ready.poke(true.B) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.initiate.ready.expect(false.B) + + // cpRead: reqA issued to global mem + c.io.reqA.valid.expect(true.B) + c.io.reqA.bits.rw.expect(false.B) + c.io.reqA.bits.address.expect(gmemAddr.U) + c.clock.step() + c.io.initiate.ready.expect(false.B) + + // cpWrite: respA fires → tmemC.wen in same cycle + c.io.respA.valid.poke(true.B) + c.io.respA.bits.data.poke(cpData.U) + + // tmemC write happens combinatorially when respA fires + c.io.tmemC.wen.expect(true.B) + c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U) + c.io.tmemC.wdata.expect(cpData.U) + c.clock.step() + c.io.initiate.ready.expect(true.B) + } + } + + it should "tcgen05_cb: read from TMEM and write to global mem (reqA)" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c => + idleIO(c) + val fragBytes = 32 + val tmemAddr = BigInt(0xa0) + val gmemAddr = "ha2000000" + val cbData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xC000 + i)), 32) + + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(6.U) // tcgen05Cb + c.io.initiate.bits.addressA.poke(tmemAddr.U) + c.io.initiate.bits.addressB.poke(gmemAddr.U) + c.io.reqA.ready.poke(true.B) + c.io.tmemC.rdata.poke(cbData.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + c.io.initiate.ready.expect(false.B) + + // cbRead: tmemC.ren asserted + c.io.tmemC.ren.expect(true.B) + c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U) + c.clock.step() + c.io.initiate.ready.expect(false.B) + + // cbWrite: reqA write to global mem c.io.reqA.valid.expect(true.B) c.io.reqA.bits.rw.expect(true.B) - c.io.reqA.bits.address.expect(0x340.U) + c.io.reqA.bits.address.expect(gmemAddr.U) + c.io.reqA.bits.data.expect(cbData.U) + c.clock.step() + c.io.initiate.ready.expect(false.B) + c.io.respA.valid.poke(true.B) + c.io.respA.bits.data.poke(0.U) + c.clock.step() + c.io.initiate.ready.expect(true.B) + } + } + + it should "run bwgmma: TMEM_C = TMEM_A * SMEM_B + TMEM_C" in { + test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) + .withAnnotations(Seq(VerilatorBackendAnnotation)) { c => + idleIO(c) + + val fragBytes = 32 + val aBase = BigInt(0x100) + val bBase = BigInt(0x800) + val cBase = BigInt(0x1000) + + // A: all fp16 1.0 (0x3c00), 16 halves per frag + val fp16One = BigInt(0x3c00) + val fp16Two = BigInt(0x4000) + val fp32One = BigInt(0x3f800000) + val fp32SixtyFive = BigInt(0x42820000) + val aFrag = packWords(Seq.fill(16)(fp16One), 16) + val bFrag = packWords(Seq.fill(16)(fp16Two), 16) + val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32) + val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32) + + // Populate TMEM with A and C tiles + val tmem = makeTmem() + for (i <- 0 until 32) { + tmem(aBase / fragBytes + i) = aFrag + tmem(cBase / fragBytes + i) = cFrag + } + val bMem = mutable.Map[BigInt, BigInt]() + for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag + + c.io.reqB.ready.poke(true.B) + c.io.writeback.ready.poke(true.B) + + c.io.initiate.valid.poke(true.B) + c.io.initiate.bits.op.poke(0.U) // bwgmma + c.io.initiate.bits.wid.poke(1.U) + c.io.initiate.bits.rd.poke(0.U) + c.io.initiate.bits.addressA.poke(aBase.U) + c.io.initiate.bits.addressB.poke(bBase.U) + c.io.initiate.bits.addressC.poke(cBase.U) + c.clock.step() + c.io.initiate.valid.poke(false.B) + + var pendingB = Option.empty[(BigInt, BigInt)] + var sawWriteback = false + + for (_ <- 0 until 20000 if !sawWriteback) { + // Drive TMEM reads/writes + stepTmem(c, tmem) + + // Drive SMEM B responses + pendingB.foreach { case (src, data) => + c.io.respB.valid.poke(true.B) + c.io.respB.bits.source.poke(src.U) + c.io.respB.bits.data.poke(data.U) + } + if (pendingB.isEmpty) c.io.respB.valid.poke(false.B) + + if (c.io.writeback.valid.peek().litToBoolean) { + sawWriteback = true + } else { + val nextB = if (c.io.reqB.valid.peek().litToBoolean) { + val addr = c.io.reqB.bits.address.peek().litValue + val src = c.io.reqB.bits.source.peek().litValue + Some((src, bMem(addr))) + } else None + + c.clock.step() + pendingB = nextB + } + } + + assert(sawWriteback, "BWGMMA did not complete") + c.io.writeback.bits.wid.expect(1.U) + // Verify all 32 C frags in TMEM + for (i <- 0 until 32) { + val row = cBase / fragBytes + i + assert(tmem(row) == expectedCFrag, + s"C frag $i mismatch: got 0x${tmem(row).toString(16)}, expected 0x${expectedCFrag.toString(16)}") + } } } }