Wire scalar TMEM through Radiance tile

This commit is contained in:
Zhongdi LUO
2026-06-24 06:25:10 +00:00
parent f88085331e
commit 47d6585896
3 changed files with 158 additions and 36 deletions

View File

@@ -851,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_tmem_C_rready := DontCare core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
} }
def connectTensorBlackwell = { def connectTensorBlackwell = {
@@ -885,59 +888,166 @@ class RadianceTileModuleImp(outer: RadianceTile)
tcDData.foreach(_ := 0.U) tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U) tcDTag.foreach(_ := 0.U)
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C. // TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
// reads can proceed together when bank placement avoids conflicts.
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count // Each warp owns 2KB: A tile and C tile are 1KB each. The row count
// scales with the physical fragment width (16B for 4 lanes, 32B for 8). // scales with the physical fragment width (16B for 4 lanes, 32B for 8).
val tmemBytesPerWarp = 2048 val tmemBytesPerWarp = 2048
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize) val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( val tmemBanks = 4
tmemDepth, UInt((outer.tcSmemSize * 8).W))) val tmemBankBits = log2Ceil(tmemBanks)
val tmemBankDepth = tmemDepth / tmemBanks
require(isPow2(tmemBanks))
require(tmemDepth % tmemBanks == 0)
val tmem = Seq.fill(tmemBanks) {
Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
}
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) class TmemReadReq extends Bundle {
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) val addr = UInt(tmemAddrBits.W)
val src = UInt(2.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
class TmemWriteReq extends Bundle { class TmemWriteReq extends Bundle {
val addr = UInt(tmemAddrBits.W) val addr = UInt(tmemAddrBits.W)
val data = UInt(tmemDataBits.W) val data = UInt(tmemDataBits.W)
val mask = UInt(tmemMaskBits.W) val mask = UInt(tmemMaskBits.W)
val src = UInt(1.W)
val tc = UInt(log2Ceil(nTC max 2).W)
} }
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
val aReady = Wire(Vec(nTC, Bool()))
val cReady = Wire(Vec(nTC, Bool()))
val wReady = Wire(Vec(nTC, Bool()))
val scReadReady = Wire(Bool())
val scWriteReady = Wire(Bool())
aReady.foreach(_ := false.B)
cReady.foreach(_ := false.B)
wReady.foreach(_ := false.B)
scReadReady := false.B
scWriteReady := false.B
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read0Valid = Wire(Vec(tmemBanks, Bool()))
val read1Valid = Wire(Vec(tmemBanks, Bool()))
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
val writeValid = Wire(Vec(tmemBanks, Bool()))
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read0Valid.foreach(_ := false.B)
read1Valid.foreach(_ := false.B)
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
writeValid.foreach(_ := false.B)
(0 until tmemBanks).foreach { b =>
val requests = (0 until nTC).flatMap { tc =>
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
Seq(
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
)
} ++ Seq(
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
)
var used0 = false.B
var used1 = false.B
requests.foreach { case (valid, addr, src, tc) =>
val grant0 = valid && !used0
val grant1 = valid && used0 && !used1
when(grant0) {
read0Grant(b).addr := addr
read0Grant(b).src := src
read0Grant(b).tc := tc
}
when(grant1) {
read1Grant(b).addr := addr
read1Grant(b).src := src
read1Grant(b).tc := tc
}
used0 = used0 || grant0
used1 = used1 || grant1
when(grant0 || grant1) {
when(src === 0.U) { aReady(tc) := true.B }
when(src === 1.U) { cReady(tc) := true.B }
when(src === 2.U) { scReadReady := true.B }
}
}
read0Valid(b) := used0
read1Valid(b) := used1
var writeUsed = false.B
(0 until nTC).foreach { tc => (0 until nTC).foreach { tc =>
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc) val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc) val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc) val grant = valid && !writeUsed
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc) when(grant) {
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc) writeValid(b) := true.B
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc) writeGrant(b).addr := addr
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc) writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc) writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
writeGrant(b).src := 0.U
writeGrant(b).tc := tc.U
wReady(tc) := true.B
}
writeUsed = writeUsed || grant
} }
aReadArb.io.out.ready := true.B val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
cReadArb.io.out.ready := true.B val scWGrant = scWValid && !writeUsed
cWriteArb.io.out.ready := true.B when(scWGrant) {
writeValid(b) := true.B
writeGrant(b).addr := core.io.sc_tmem_waddr
writeGrant(b).data := core.io.sc_tmem_wdata
writeGrant(b).mask := core.io.sc_tmem_mask
writeGrant(b).src := 1.U
writeGrant(b).tc := 0.U
scWriteReady := true.B
}
tmem.io.ren0 := aReadArb.io.out.fire tmem(b).io.ren0 := read0Valid(b)
tmem.io.raddr0 := aReadArb.io.out.bits tmem(b).io.raddr0 := row(read0Grant(b).addr)
tmem.io.ren1 := cReadArb.io.out.fire tmem(b).io.ren1 := read1Valid(b)
tmem.io.raddr1 := cReadArb.io.out.bits tmem(b).io.raddr1 := row(read1Grant(b).addr)
tmem.io.wen := cWriteArb.io.out.fire tmem(b).io.wen := writeValid(b)
tmem.io.waddr := cWriteArb.io.out.bits.addr tmem(b).io.waddr := row(writeGrant(b).addr)
tmem.io.wdata := cWriteArb.io.out.bits.data tmem(b).io.wdata := writeGrant(b).data
tmem.io.mask := cWriteArb.io.out.bits.mask tmem(b).io.mask := writeGrant(b).mask
}
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W))) val read0GrantReg = RegNext(read0Grant)
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W))) val read1GrantReg = RegNext(read1Grant)
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt val read0ValidReg = RegNext(read0Valid)
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt val read1ValidReg = RegNext(read1Valid)
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt core.io.tc_tmem_A_rready := aReady.asUInt
core.io.tc_tmem_C_rready := cReady.asUInt
core.io.tc_tmem_C_wready := wReady.asUInt
core.io.sc_tmem_rready := scReadReady.asUInt
core.io.sc_tmem_wready := scWriteReady.asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc => core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W)) VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt }).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc => core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W)) VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt }).asUInt
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them. // port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc => (0 until nTC).foreach { tc =>
@@ -1025,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_tmem_C_rready := DontCare core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
} }
} }

View File

@@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W)) val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W)) val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W)) val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
val sc_tmem_ren = Output(UInt(1.W))
val sc_tmem_rready = Input(UInt(1.W))
val sc_tmem_raddr = Output(UInt(9.W))
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
val sc_tmem_wen = Output(UInt(1.W))
val sc_tmem_wready = Input(UInt(1.W))
val sc_tmem_waddr = Output(UInt(9.W))
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
// FIXME: hardcoded // FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits