2 Commits

Author SHA1 Message Date
Zhongdi LUO
007350fd5a feat: include vortex fexp RTL 2026-07-02 07:25:32 +00:00
Zhongdi LUO
47d6585896 Wire scalar TMEM through Radiance tile 2026-06-24 06:25:10 +00:00
3 changed files with 159 additions and 36 deletions

View File

@@ -851,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
}
def connectTensorBlackwell = {
@@ -885,59 +888,166 @@ class RadianceTileModuleImp(outer: RadianceTile)
tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U)
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
// reads can proceed together when bank placement avoids conflicts.
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
val tmemBytesPerWarp = 2048
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
val tmemBanks = 4
val tmemBankBits = log2Ceil(tmemBanks)
val tmemBankDepth = tmemDepth / tmemBanks
require(isPow2(tmemBanks))
require(tmemDepth % tmemBanks == 0)
val tmem = Seq.fill(tmemBanks) {
Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
}
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
class TmemReadReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val src = UInt(2.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
class TmemWriteReq extends Bundle {
val addr = UInt(tmemAddrBits.W)
val data = UInt(tmemDataBits.W)
val mask = UInt(tmemMaskBits.W)
}
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
(0 until nTC).foreach { tc =>
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
val src = UInt(1.W)
val tc = UInt(log2Ceil(nTC max 2).W)
}
aReadArb.io.out.ready := true.B
cReadArb.io.out.ready := true.B
cWriteArb.io.out.ready := true.B
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
tmem.io.ren0 := aReadArb.io.out.fire
tmem.io.raddr0 := aReadArb.io.out.bits
tmem.io.ren1 := cReadArb.io.out.fire
tmem.io.raddr1 := cReadArb.io.out.bits
tmem.io.wen := cWriteArb.io.out.fire
tmem.io.waddr := cWriteArb.io.out.bits.addr
tmem.io.wdata := cWriteArb.io.out.bits.data
tmem.io.mask := cWriteArb.io.out.bits.mask
val aReady = Wire(Vec(nTC, Bool()))
val cReady = Wire(Vec(nTC, Bool()))
val wReady = Wire(Vec(nTC, Bool()))
val scReadReady = Wire(Bool())
val scWriteReady = Wire(Bool())
aReady.foreach(_ := false.B)
cReady.foreach(_ := false.B)
wReady.foreach(_ := false.B)
scReadReady := false.B
scWriteReady := false.B
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
val read0Valid = Wire(Vec(tmemBanks, Bool()))
val read1Valid = Wire(Vec(tmemBanks, Bool()))
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
val writeValid = Wire(Vec(tmemBanks, Bool()))
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
read0Valid.foreach(_ := false.B)
read1Valid.foreach(_ := false.B)
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
writeValid.foreach(_ := false.B)
(0 until tmemBanks).foreach { b =>
val requests = (0 until nTC).flatMap { tc =>
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
Seq(
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
)
} ++ Seq(
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
)
var used0 = false.B
var used1 = false.B
requests.foreach { case (valid, addr, src, tc) =>
val grant0 = valid && !used0
val grant1 = valid && used0 && !used1
when(grant0) {
read0Grant(b).addr := addr
read0Grant(b).src := src
read0Grant(b).tc := tc
}
when(grant1) {
read1Grant(b).addr := addr
read1Grant(b).src := src
read1Grant(b).tc := tc
}
used0 = used0 || grant0
used1 = used1 || grant1
when(grant0 || grant1) {
when(src === 0.U) { aReady(tc) := true.B }
when(src === 1.U) { cReady(tc) := true.B }
when(src === 2.U) { scReadReady := true.B }
}
}
read0Valid(b) := used0
read1Valid(b) := used1
var writeUsed = false.B
(0 until nTC).foreach { tc =>
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
val grant = valid && !writeUsed
when(grant) {
writeValid(b) := true.B
writeGrant(b).addr := addr
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
writeGrant(b).src := 0.U
writeGrant(b).tc := tc.U
wReady(tc) := true.B
}
writeUsed = writeUsed || grant
}
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
val scWGrant = scWValid && !writeUsed
when(scWGrant) {
writeValid(b) := true.B
writeGrant(b).addr := core.io.sc_tmem_waddr
writeGrant(b).data := core.io.sc_tmem_wdata
writeGrant(b).mask := core.io.sc_tmem_mask
writeGrant(b).src := 1.U
writeGrant(b).tc := 0.U
scWriteReady := true.B
}
tmem(b).io.ren0 := read0Valid(b)
tmem(b).io.raddr0 := row(read0Grant(b).addr)
tmem(b).io.ren1 := read1Valid(b)
tmem(b).io.raddr1 := row(read1Grant(b).addr)
tmem(b).io.wen := writeValid(b)
tmem(b).io.waddr := row(writeGrant(b).addr)
tmem(b).io.wdata := writeGrant(b).data
tmem(b).io.mask := writeGrant(b).mask
}
val read0GrantReg = RegNext(read0Grant)
val read1GrantReg = RegNext(read1Grant)
val read0ValidReg = RegNext(read0Valid)
val read1ValidReg = RegNext(read1Valid)
core.io.tc_tmem_A_rready := aReady.asUInt
core.io.tc_tmem_C_rready := cReady.asUInt
core.io.tc_tmem_C_wready := wReady.asUInt
core.io.sc_tmem_rready := scReadReady.asUInt
core.io.sc_tmem_wready := scWriteReady.asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
}).asUInt
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
}).reduce(_ | _)
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc =>
@@ -1025,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
core.io.sc_tmem_rready := DontCare
core.io.sc_tmem_rdata := DontCare
core.io.sc_tmem_wready := DontCare
}
}

View File

@@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
val sc_tmem_ren = Output(UInt(1.W))
val sc_tmem_rready = Input(UInt(1.W))
val sc_tmem_raddr = Output(UInt(9.W))
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
val sc_tmem_wen = Output(UInt(1.W))
val sc_tmem_wready = Input(UInt(1.W))
val sc_tmem_waddr = Output(UInt(9.W))
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
// FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -351,6 +360,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_exp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")