Wire scalar TMEM through Radiance tile
This commit is contained in:
Submodule src/main/resources/vsrc/vortex updated: abee301b6e...97a1eff701
@@ -851,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
core.io.tc_tmem_C_rready := DontCare
|
||||
core.io.tc_tmem_C_rdata := DontCare
|
||||
core.io.tc_tmem_C_wready := DontCare
|
||||
core.io.sc_tmem_rready := DontCare
|
||||
core.io.sc_tmem_rdata := DontCare
|
||||
core.io.sc_tmem_wready := DontCare
|
||||
}
|
||||
|
||||
def connectTensorBlackwell = {
|
||||
@@ -885,59 +888,166 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
tcDData.foreach(_ := 0.U)
|
||||
tcDTag.foreach(_ := 0.U)
|
||||
|
||||
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
|
||||
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
|
||||
// reads can proceed together when bank placement avoids conflicts.
|
||||
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
|
||||
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
|
||||
val tmemBytesPerWarp = 2048
|
||||
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
|
||||
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||
val tmemBanks = 4
|
||||
val tmemBankBits = log2Ceil(tmemBanks)
|
||||
val tmemBankDepth = tmemDepth / tmemBanks
|
||||
require(isPow2(tmemBanks))
|
||||
require(tmemDepth % tmemBanks == 0)
|
||||
val tmem = Seq.fill(tmemBanks) {
|
||||
Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||
}
|
||||
|
||||
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
||||
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
||||
// Bookkeeping for one granted TMEM read. One entry exists per bank and per
// read port (read0Grant / read1Grant); a RegNext copy of each entry is used to
// route that bank's read data back to the requester that won the port.
class TmemReadReq extends Bundle {
|
||||
// Full TMEM address; bank() and row() split it into bank select and row index.
val addr = UInt(tmemAddrBits.W)
|
||||
// Requester kind: 0 = tensor operand A read, 1 = tensor C read, 2 = scalar read.
val src = UInt(2.W)
|
||||
// Index of the requesting tensor core (0 for scalar reads). The `max 2` keeps
// this field at least one bit wide when nTC is 1.
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||
}
|
||||
|
||||
class TmemWriteReq extends Bundle {
|
||||
val addr = UInt(tmemAddrBits.W)
|
||||
val data = UInt(tmemDataBits.W)
|
||||
val mask = UInt(tmemMaskBits.W)
|
||||
}
|
||||
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
|
||||
|
||||
(0 until nTC).foreach { tc =>
|
||||
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
|
||||
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
||||
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
|
||||
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
||||
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
|
||||
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
||||
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
||||
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
||||
val src = UInt(1.W)
|
||||
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||
}
|
||||
|
||||
aReadArb.io.out.ready := true.B
|
||||
cReadArb.io.out.ready := true.B
|
||||
cWriteArb.io.out.ready := true.B
|
||||
// Bank select for a TMEM address: the low tmemBankBits bits of the address.
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
|
||||
// Row index within a bank: the address bits above the bank-select field.
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
|
||||
|
||||
tmem.io.ren0 := aReadArb.io.out.fire
|
||||
tmem.io.raddr0 := aReadArb.io.out.bits
|
||||
tmem.io.ren1 := cReadArb.io.out.fire
|
||||
tmem.io.raddr1 := cReadArb.io.out.bits
|
||||
tmem.io.wen := cWriteArb.io.out.fire
|
||||
tmem.io.waddr := cWriteArb.io.out.bits.addr
|
||||
tmem.io.wdata := cWriteArb.io.out.bits.data
|
||||
tmem.io.mask := cWriteArb.io.out.bits.mask
|
||||
val aReady = Wire(Vec(nTC, Bool()))
|
||||
val cReady = Wire(Vec(nTC, Bool()))
|
||||
val wReady = Wire(Vec(nTC, Bool()))
|
||||
val scReadReady = Wire(Bool())
|
||||
val scWriteReady = Wire(Bool())
|
||||
aReady.foreach(_ := false.B)
|
||||
cReady.foreach(_ := false.B)
|
||||
wReady.foreach(_ := false.B)
|
||||
scReadReady := false.B
|
||||
scWriteReady := false.B
|
||||
|
||||
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
||||
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
||||
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
|
||||
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
|
||||
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
|
||||
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||
val read0Valid = Wire(Vec(tmemBanks, Bool()))
|
||||
val read1Valid = Wire(Vec(tmemBanks, Bool()))
|
||||
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
|
||||
val writeValid = Wire(Vec(tmemBanks, Bool()))
|
||||
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||
read0Valid.foreach(_ := false.B)
|
||||
read1Valid.foreach(_ := false.B)
|
||||
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
|
||||
writeValid.foreach(_ := false.B)
|
||||
|
||||
(0 until tmemBanks).foreach { b =>
|
||||
val requests = (0 until nTC).flatMap { tc =>
|
||||
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
||||
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
||||
Seq(
|
||||
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
|
||||
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
|
||||
)
|
||||
} ++ Seq(
|
||||
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
|
||||
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
|
||||
)
|
||||
|
||||
var used0 = false.B
|
||||
var used1 = false.B
|
||||
requests.foreach { case (valid, addr, src, tc) =>
|
||||
val grant0 = valid && !used0
|
||||
val grant1 = valid && used0 && !used1
|
||||
when(grant0) {
|
||||
read0Grant(b).addr := addr
|
||||
read0Grant(b).src := src
|
||||
read0Grant(b).tc := tc
|
||||
}
|
||||
when(grant1) {
|
||||
read1Grant(b).addr := addr
|
||||
read1Grant(b).src := src
|
||||
read1Grant(b).tc := tc
|
||||
}
|
||||
used0 = used0 || grant0
|
||||
used1 = used1 || grant1
|
||||
when(grant0 || grant1) {
|
||||
when(src === 0.U) { aReady(tc) := true.B }
|
||||
when(src === 1.U) { cReady(tc) := true.B }
|
||||
when(src === 2.U) { scReadReady := true.B }
|
||||
}
|
||||
}
|
||||
read0Valid(b) := used0
|
||||
read1Valid(b) := used1
|
||||
|
||||
var writeUsed = false.B
|
||||
(0 until nTC).foreach { tc =>
|
||||
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
||||
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
|
||||
val grant = valid && !writeUsed
|
||||
when(grant) {
|
||||
writeValid(b) := true.B
|
||||
writeGrant(b).addr := addr
|
||||
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
||||
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
||||
writeGrant(b).src := 0.U
|
||||
writeGrant(b).tc := tc.U
|
||||
wReady(tc) := true.B
|
||||
}
|
||||
writeUsed = writeUsed || grant
|
||||
}
|
||||
|
||||
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
|
||||
val scWGrant = scWValid && !writeUsed
|
||||
when(scWGrant) {
|
||||
writeValid(b) := true.B
|
||||
writeGrant(b).addr := core.io.sc_tmem_waddr
|
||||
writeGrant(b).data := core.io.sc_tmem_wdata
|
||||
writeGrant(b).mask := core.io.sc_tmem_mask
|
||||
writeGrant(b).src := 1.U
|
||||
writeGrant(b).tc := 0.U
|
||||
scWriteReady := true.B
|
||||
}
|
||||
|
||||
tmem(b).io.ren0 := read0Valid(b)
|
||||
tmem(b).io.raddr0 := row(read0Grant(b).addr)
|
||||
tmem(b).io.ren1 := read1Valid(b)
|
||||
tmem(b).io.raddr1 := row(read1Grant(b).addr)
|
||||
tmem(b).io.wen := writeValid(b)
|
||||
tmem(b).io.waddr := row(writeGrant(b).addr)
|
||||
tmem(b).io.wdata := writeGrant(b).data
|
||||
tmem(b).io.mask := writeGrant(b).mask
|
||||
}
|
||||
|
||||
val read0GrantReg = RegNext(read0Grant)
|
||||
val read1GrantReg = RegNext(read1Grant)
|
||||
val read0ValidReg = RegNext(read0Valid)
|
||||
val read1ValidReg = RegNext(read1Valid)
|
||||
core.io.tc_tmem_A_rready := aReady.asUInt
|
||||
core.io.tc_tmem_C_rready := cReady.asUInt
|
||||
core.io.tc_tmem_C_wready := wReady.asUInt
|
||||
core.io.sc_tmem_rready := scReadReady.asUInt
|
||||
core.io.sc_tmem_wready := scWriteReady.asUInt
|
||||
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
|
||||
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
|
||||
VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
}).asUInt
|
||||
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
|
||||
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
|
||||
VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
}).asUInt
|
||||
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
|
||||
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
|
||||
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||
}).reduce(_ | _)
|
||||
|
||||
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
|
||||
(0 until nTC).foreach { tc =>
|
||||
@@ -1025,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
core.io.tc_tmem_C_rready := DontCare
|
||||
core.io.tc_tmem_C_rdata := DontCare
|
||||
core.io.tc_tmem_C_wready := DontCare
|
||||
core.io.sc_tmem_rready := DontCare
|
||||
core.io.sc_tmem_rdata := DontCare
|
||||
core.io.sc_tmem_wready := DontCare
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
|
||||
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
|
||||
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
|
||||
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
|
||||
val sc_tmem_ren = Output(UInt(1.W))
|
||||
val sc_tmem_rready = Input(UInt(1.W))
|
||||
val sc_tmem_raddr = Output(UInt(9.W))
|
||||
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
|
||||
val sc_tmem_wen = Output(UInt(1.W))
|
||||
val sc_tmem_wready = Input(UInt(1.W))
|
||||
val sc_tmem_waddr = Output(UInt(9.W))
|
||||
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
|
||||
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
|
||||
|
||||
// FIXME: hardcoded
|
||||
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
||||
|
||||
Reference in New Issue
Block a user