From 47d65858966dfd494fffca57da1f0001d788db56 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Wed, 24 Jun 2026 06:25:10 +0000 Subject: [PATCH] Wire scalar TMEM through Radiance tile --- src/main/resources/vsrc/vortex | 2 +- .../scala/radiance/tile/RadianceTile.scala | 183 ++++++++++++++---- src/main/scala/radiance/tile/VortexCore.scala | 9 + 3 files changed, 158 insertions(+), 36 deletions(-) diff --git a/src/main/resources/vsrc/vortex b/src/main/resources/vsrc/vortex index abee301..97a1eff 160000 --- a/src/main/resources/vsrc/vortex +++ b/src/main/resources/vsrc/vortex @@ -1 +1 @@ -Subproject commit abee301b6e2eb15d9d41a7c241e95875fc185c18 +Subproject commit 97a1eff701ca71a1b93be2b4bd64ec697202cef7 diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index cd7a55e..0b9d21c 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -851,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile) core.io.tc_tmem_C_rready := DontCare core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_wready := DontCare + core.io.sc_tmem_rready := DontCare + core.io.sc_tmem_rdata := DontCare + core.io.sc_tmem_wready := DontCare } def connectTensorBlackwell = { @@ -885,59 +888,166 @@ class RadianceTileModuleImp(outer: RadianceTile) tcDData.foreach(_ := 0.U) tcDTag.foreach(_ := 0.U) - // TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C. + // TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar + // reads can proceed together when bank placement avoids conflicts. // Each warp owns 2KB: A tile and C tile are 1KB each. The row count // scales with the physical fragment width (16B for 4 lanes, 32B for 8). val tmemBytesPerWarp = 2048 val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize) - val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( - tmemDepth, UInt((outer.tcSmemSize * 8).W))) + val tmemBanks = 4 + val tmemBankBits = log2Ceil(tmemBanks) + val tmemBankDepth = tmemDepth / tmemBanks + require(isPow2(tmemBanks)) + require(tmemDepth % tmemBanks == 0) + val tmem = Seq.fill(tmemBanks) { + Module(new radiance.memory.TwoReadOneWriteSyncMem( + tmemBankDepth, UInt((outer.tcSmemSize * 8).W))) + } - val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) - val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) + class TmemReadReq extends Bundle { + val addr = UInt(tmemAddrBits.W) + val src = UInt(2.W) + val tc = UInt(log2Ceil(nTC max 2).W) + } class TmemWriteReq extends Bundle { val addr = UInt(tmemAddrBits.W) val data = UInt(tmemDataBits.W) val mask = UInt(tmemMaskBits.W) - } - val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC)) - - (0 until nTC).foreach { tc => - aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc) - aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc) - cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc) - cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc) - cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc) - cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc) - cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc) - cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc) + val src = UInt(1.W) + val tc = UInt(log2Ceil(nTC max 2).W) } - aReadArb.io.out.ready := true.B - cReadArb.io.out.ready := true.B - cWriteArb.io.out.ready := true.B + def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0) + def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits) - tmem.io.ren0 := aReadArb.io.out.fire - tmem.io.raddr0 := aReadArb.io.out.bits - tmem.io.ren1 := cReadArb.io.out.fire - tmem.io.raddr1 := cReadArb.io.out.bits - tmem.io.wen := cWriteArb.io.out.fire - tmem.io.waddr := cWriteArb.io.out.bits.addr - tmem.io.wdata := cWriteArb.io.out.bits.data - tmem.io.mask := cWriteArb.io.out.bits.mask + val aReady = Wire(Vec(nTC, Bool())) + val cReady = Wire(Vec(nTC, Bool())) + val wReady = Wire(Vec(nTC, Bool())) + val scReadReady = Wire(Bool()) + val scWriteReady = Wire(Bool()) + aReady.foreach(_ := false.B) + cReady.foreach(_ := false.B) + wReady.foreach(_ := false.B) + scReadReady := false.B + scWriteReady := false.B - val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W))) - val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W))) - core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt - core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt - core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt + val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq)) + val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq)) + val read0Valid = Wire(Vec(tmemBanks, Bool())) + val read1Valid = Wire(Vec(tmemBanks, Bool())) + val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq)) + val writeValid = Wire(Vec(tmemBanks, Bool())) + read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq)) + read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq)) + read0Valid.foreach(_ := false.B) + read1Valid.foreach(_ := false.B) + writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq)) + writeValid.foreach(_ := false.B) + + (0 until tmemBanks).foreach { b => + val requests = (0 until nTC).flatMap { tc => + val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc) + val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc) + Seq( + (core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U), + (core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U) + ) + } ++ Seq( + (core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U, + core.io.sc_tmem_raddr, 2.U(2.W), 0.U) + ) + + var used0 = false.B + var used1 = false.B + requests.foreach { case (valid, addr, src, tc) => + val grant0 = valid && !used0 + val grant1 = valid && used0 && !used1 + when(grant0) { + read0Grant(b).addr := addr + read0Grant(b).src := src + read0Grant(b).tc := tc + } + when(grant1) { + read1Grant(b).addr := addr + read1Grant(b).src := src + read1Grant(b).tc := tc + } + used0 = used0 || grant0 + used1 = used1 || grant1 + when(grant0 || grant1) { + when(src === 0.U) { aReady(tc) := true.B } + when(src === 1.U) { cReady(tc) := true.B } + when(src === 2.U) { scReadReady := true.B } + } + } + read0Valid(b) := used0 + read1Valid(b) := used1 + + var writeUsed = false.B + (0 until nTC).foreach { tc => + val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc) + val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U + val grant = valid && !writeUsed + when(grant) { + writeValid(b) := true.B + writeGrant(b).addr := addr + writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc) + writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc) + writeGrant(b).src := 0.U + writeGrant(b).tc := tc.U + wReady(tc) := true.B + } + writeUsed = writeUsed || grant + } + + val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U + val scWGrant = scWValid && !writeUsed + when(scWGrant) { + writeValid(b) := true.B + writeGrant(b).addr := core.io.sc_tmem_waddr + writeGrant(b).data := core.io.sc_tmem_wdata + writeGrant(b).mask := core.io.sc_tmem_mask + writeGrant(b).src := 1.U + writeGrant(b).tc := 0.U + scWriteReady := true.B + } + + tmem(b).io.ren0 := read0Valid(b) + tmem(b).io.raddr0 := row(read0Grant(b).addr) + tmem(b).io.ren1 := read1Valid(b) + tmem(b).io.raddr1 := row(read1Grant(b).addr) + tmem(b).io.wen := writeValid(b) + tmem(b).io.waddr := row(writeGrant(b).addr) + tmem(b).io.wdata := writeGrant(b).data + tmem(b).io.mask := writeGrant(b).mask + } + + val read0GrantReg = RegNext(read0Grant) + val read1GrantReg = RegNext(read1Grant) + val read0ValidReg = RegNext(read0Valid) + val read1ValidReg = RegNext(read1Valid) + core.io.tc_tmem_A_rready := aReady.asUInt + core.io.tc_tmem_C_rready := cReady.asUInt + core.io.tc_tmem_C_wready := wReady.asUInt + core.io.sc_tmem_rready := scReadReady.asUInt + core.io.sc_tmem_wready := scWriteReady.asUInt core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc => - Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W)) + VecInit((0 until tmemBanks).map { b => + Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0, + Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W))) + }).reduce(_ | _) }).asUInt core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc => - Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W)) + VecInit((0 until tmemBanks).map { b => + Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0, + Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W))) + }).reduce(_ | _) }).asUInt + core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b => + Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0, + Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W))) + }).reduce(_ | _) // port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them. (0 until nTC).foreach { tc => @@ -1025,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile) core.io.tc_tmem_C_rready := DontCare core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_wready := DontCare + core.io.sc_tmem_rready := DontCare + core.io.sc_tmem_rdata := DontCare + core.io.sc_tmem_wready := DontCare } } diff --git a/src/main/scala/radiance/tile/VortexCore.scala b/src/main/scala/radiance/tile/VortexCore.scala index 803322c..87518fc 100644 --- a/src/main/scala/radiance/tile/VortexCore.scala +++ b/src/main/scala/radiance/tile/VortexCore.scala @@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W)) val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W)) val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W)) + val sc_tmem_ren = Output(UInt(1.W)) + val sc_tmem_rready = Input(UInt(1.W)) + val sc_tmem_raddr = Output(UInt(9.W)) + val sc_tmem_rdata = Input(UInt((numLanes * 32).W)) + val sc_tmem_wen = Output(UInt(1.W)) + val sc_tmem_wready = Input(UInt(1.W)) + val sc_tmem_waddr = Output(UInt(9.W)) + val sc_tmem_wdata = Output(UInt((numLanes * 32).W)) + val sc_tmem_mask = Output(UInt((numLanes * 4).W)) // FIXME: hardcoded val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits