Compare commits
5 Commits
wu-archite
...
wu-tmem-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
007350fd5a | ||
|
|
47d6585896 | ||
|
|
f88085331e | ||
|
|
1e78574113 | ||
|
|
c6c30ec0dc |
Submodule src/main/resources/vsrc/vortex updated: c87fea5c48...9251ba0a24
@@ -14,7 +14,8 @@ class TensorCoreBlackwell(
|
|||||||
val numFPRegs: Int = 32
|
val numFPRegs: Int = 32
|
||||||
) extends Module {
|
) extends Module {
|
||||||
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
||||||
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes")
|
require(numLanes == 4 || numLanes == 8,
|
||||||
|
s"Blackwell MMA currently supports 4 or 8 lanes, got ${numLanes}")
|
||||||
|
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
val sourceWidth = log2Ceil(numSourceIds)
|
||||||
@@ -26,11 +27,16 @@ class TensorCoreBlackwell(
|
|||||||
val fragOffsetBits = log2Ceil(memWidth / 8)
|
val fragOffsetBits = log2Ceil(memWidth / 8)
|
||||||
|
|
||||||
val numSets = 4
|
val numSets = 4
|
||||||
val numAFragsPerSet = 8
|
|
||||||
val numBGroups = 4
|
val numBGroups = 4
|
||||||
val numBFragsPerGroup = 2
|
val numSubsteps = 2
|
||||||
val numMGroups = 4
|
val mElemsPerFrag = if (numLanes == 4) 2 else 4
|
||||||
val numCFrags = 32
|
val numMGroups = 16 / mElemsPerFrag
|
||||||
|
val numAFragsPerMGroup = 2
|
||||||
|
val numAFragsPerSet = numMGroups * numAFragsPerMGroup
|
||||||
|
val numBFragsPerSubstep = if (numLanes == 4) 2 else 1
|
||||||
|
val numBFragsPerGroup = numSubsteps * numBFragsPerSubstep
|
||||||
|
val numBFragsPerSet = numBGroups * numBFragsPerGroup
|
||||||
|
val numCFrags = numBGroups * numMGroups * numSubsteps
|
||||||
|
|
||||||
object Ops {
|
object Ops {
|
||||||
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
|
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
|
||||||
@@ -136,10 +142,11 @@ class TensorCoreBlackwell(
|
|||||||
base + (fragIndex << fragOffsetBits).asUInt
|
base + (fragIndex << fragOffsetBits).asUInt
|
||||||
}
|
}
|
||||||
|
|
||||||
val aFragIndex = (setReg << 3) + aIndexReg
|
val aFragIndex = (setReg * numAFragsPerSet.U) + aIndexReg
|
||||||
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg
|
val bFragIndex =
|
||||||
val stepIndex = Cat(bGroupReg, mGroupReg)
|
(setReg * numBFragsPerSet.U) + (bGroupReg * numBFragsPerGroup.U) + bIndexReg
|
||||||
val cFragIndex = (stepIndex << 1) + substepReg
|
val cFragIndex =
|
||||||
|
(((bGroupReg * numMGroups.U) + mGroupReg) * numSubsteps.U) + substepReg
|
||||||
val aReqAddress = byteAddress(addrAReg, aFragIndex)
|
val aReqAddress = byteAddress(addrAReg, aFragIndex)
|
||||||
val bReqAddress = byteAddress(addrBReg, bFragIndex)
|
val bReqAddress = byteAddress(addrBReg, bFragIndex)
|
||||||
val cReqAddress = byteAddress(addrCReg, cFragIndex)
|
val cReqAddress = byteAddress(addrCReg, cFragIndex)
|
||||||
@@ -181,7 +188,12 @@ class TensorCoreBlackwell(
|
|||||||
io.initiate.ready := state === State.idle && !wbValid
|
io.initiate.ready := state === State.idle && !wbValid
|
||||||
|
|
||||||
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
|
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
|
||||||
val operandB = bBuf(substepReg)
|
val operandB =
|
||||||
|
if (numLanes == 4) {
|
||||||
|
Cat(bBuf((substepReg << 1) + 1.U), bBuf(substepReg << 1))
|
||||||
|
} else {
|
||||||
|
bBuf(substepReg)
|
||||||
|
}
|
||||||
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
|
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
|
||||||
val dpuInValid = WireDefault(false.B)
|
val dpuInValid = WireDefault(false.B)
|
||||||
val dpu = Module(new TensorDotProductUnit(
|
val dpu = Module(new TensorDotProductUnit(
|
||||||
@@ -193,16 +205,22 @@ class TensorCoreBlackwell(
|
|||||||
x((idx + 1) * 16 - 1, idx * 16)
|
x((idx + 1) * 16 - 1, idx * 16)
|
||||||
}
|
}
|
||||||
|
|
||||||
val elemM = elemReg(1, 0)
|
val elemM = if (numLanes == 4) elemReg(0, 0) else elemReg(1, 0)
|
||||||
val elemN = elemReg(2)
|
val elemN = if (numLanes == 4) elemReg(1) else elemReg(2)
|
||||||
dpu.io.in.valid := dpuInValid
|
dpu.io.in.valid := dpuInValid
|
||||||
for (k <- 0 until 8) {
|
for (k <- 0 until 8) {
|
||||||
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
dpu.io.in.bits.a(k) := (
|
||||||
0.U -> halfWord(operandA, k),
|
if (numLanes == 4) {
|
||||||
1.U -> halfWord(operandA, 8 + k),
|
Mux(elemM.asBool, halfWord(operandA, 8 + k), halfWord(operandA, k))
|
||||||
2.U -> halfWord(operandA, 16 + k),
|
} else {
|
||||||
3.U -> halfWord(operandA, 24 + k)
|
MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
||||||
))
|
0.U -> halfWord(operandA, k),
|
||||||
|
1.U -> halfWord(operandA, 8 + k),
|
||||||
|
2.U -> halfWord(operandA, 16 + k),
|
||||||
|
3.U -> halfWord(operandA, 24 + k)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
)
|
||||||
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
||||||
}
|
}
|
||||||
dpu.io.in.bits.c := cWords(elemReg)
|
dpu.io.in.bits.c := cWords(elemReg)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
|
|||||||
import org.chipsalliance.cde.config._
|
import org.chipsalliance.cde.config._
|
||||||
import radiance.core._
|
import radiance.core._
|
||||||
import radiance.memory._
|
import radiance.memory._
|
||||||
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs}
|
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
|
||||||
|
|
||||||
/** For determining radiance core id. This may be different from
|
/** For determining radiance core id. This may be different from
|
||||||
* RadianceTileParams.tileId, when a cluster contains non-core tiles */
|
* RadianceTileParams.tileId, when a cluster contains non-core tiles */
|
||||||
@@ -288,14 +288,25 @@ class RadianceTile private (
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
val tcSmemSize = 32
|
val tcSmemSize = numLsuLanes * 4
|
||||||
|
val tcSmemLineSize = p(RadianceSharedMemKey)
|
||||||
|
.map(k => k.numWords * k.wordSize)
|
||||||
|
.getOrElse(tcSmemSize)
|
||||||
|
val tcSmemClientMaxSize =
|
||||||
|
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
|
||||||
val numTensorWarps = radianceParams.core.numTensorWarps
|
val numTensorWarps = radianceParams.core.numTensorWarps
|
||||||
val numScalarWarps = numWarps - numTensorWarps
|
val numScalarWarps = numWarps - numTensorWarps
|
||||||
require(numTensorWarps > 0 && numTensorWarps < numWarps,
|
require(numTensorWarps > 0 && numTensorWarps < numWarps,
|
||||||
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
|
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
|
||||||
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
|
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
|
||||||
if (radianceParams.core.tensorCoreBlackwell) {
|
if (radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
require(numCoreLanes == numLsuLanes,
|
||||||
|
s"Wu Blackwell binding requires matching core lanes (${numCoreLanes}) and memory lanes (${numLsuLanes})")
|
||||||
|
require(numLsuLanes == 4 || numLsuLanes == 8,
|
||||||
|
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
|
||||||
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
|
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
|
||||||
|
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
|
||||||
|
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
|
||||||
}
|
}
|
||||||
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
||||||
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
|
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
|
||||||
@@ -305,9 +316,9 @@ class RadianceTile private (
|
|||||||
name = s"rad_tc_${radianceParams.coreId}_$i",
|
name = s"rad_tc_${radianceParams.coreId}_$i",
|
||||||
sourceId = IdRange(0, 1 << smemSourceWidth),
|
sourceId = IdRange(0, 1 << smemSourceWidth),
|
||||||
supports = TLSlaveToMasterTransferSizes(
|
supports = TLSlaveToMasterTransferSizes(
|
||||||
probe = TransferSizes(1, tcSmemSize),
|
probe = TransferSizes(1, tcSmemClientMaxSize),
|
||||||
get = TransferSizes(1, tcSmemSize),
|
get = TransferSizes(1, tcSmemClientMaxSize),
|
||||||
putFull = TransferSizes(1, tcSmemSize),
|
putFull = TransferSizes(1, tcSmemClientMaxSize),
|
||||||
),
|
),
|
||||||
requestFifo = true
|
requestFifo = true
|
||||||
))
|
))
|
||||||
@@ -840,6 +851,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
core.io.tc_tmem_C_rready := DontCare
|
core.io.tc_tmem_C_rready := DontCare
|
||||||
core.io.tc_tmem_C_rdata := DontCare
|
core.io.tc_tmem_C_rdata := DontCare
|
||||||
core.io.tc_tmem_C_wready := DontCare
|
core.io.tc_tmem_C_wready := DontCare
|
||||||
|
core.io.sc_tmem_rready := DontCare
|
||||||
|
core.io.sc_tmem_rdata := DontCare
|
||||||
|
core.io.sc_tmem_wready := DontCare
|
||||||
}
|
}
|
||||||
|
|
||||||
def connectTensorBlackwell = {
|
def connectTensorBlackwell = {
|
||||||
@@ -850,100 +864,232 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
|
|
||||||
val nTC = outer.numTensorCores
|
val nTC = outer.numTensorCores
|
||||||
val tcPorts = 3
|
val tcPorts = 3
|
||||||
|
val tcCoreDataBits = 32 * 8
|
||||||
val tcDataBits = outer.tcSmemSize * 8
|
val tcDataBits = outer.tcSmemSize * 8
|
||||||
|
val tcSmemLineBits = outer.tcSmemLineSize * 8
|
||||||
val tmemAddrBits = 9
|
val tmemAddrBits = 9
|
||||||
val tmemDataBits = outer.numLsuLanes * 32
|
val tmemDataBits = tcDataBits
|
||||||
val tmemMaskBits = outer.numLsuLanes * 4
|
val tmemMaskBits = outer.tcSmemSize
|
||||||
|
val tcTlSize = log2Ceil(outer.tcSmemSize)
|
||||||
|
val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
|
||||||
|
|
||||||
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
|
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
|
||||||
def port(tc: Int, p: Int): Int = tc * tcPorts + p
|
def port(tc: Int, p: Int): Int = tc * tcPorts + p
|
||||||
|
def padToCoreData(u: UInt): UInt = {
|
||||||
|
if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
|
||||||
|
}
|
||||||
|
|
||||||
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
|
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
|
||||||
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
|
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
|
||||||
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W)))
|
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
|
||||||
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
|
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
|
||||||
tcAReady.foreach(_ := false.B)
|
tcAReady.foreach(_ := false.B)
|
||||||
tcDValid.foreach(_ := false.B)
|
tcDValid.foreach(_ := false.B)
|
||||||
tcDData.foreach(_ := 0.U)
|
tcDData.foreach(_ := 0.U)
|
||||||
tcDTag.foreach(_ := 0.U)
|
tcDTag.foreach(_ := 0.U)
|
||||||
|
|
||||||
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
|
// TMEM matrix: four banked 2R1W SRAMs. Tensor A/C reads and scalar
|
||||||
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
|
// reads can proceed together when bank placement avoids conflicts.
|
||||||
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
|
// Each warp owns 2KB: A tile and C tile are 1KB each. The row count
|
||||||
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
// scales with the physical fragment width (16B for 4 lanes, 32B for 8).
|
||||||
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
|
val tmemBytesPerWarp = 2048
|
||||||
|
val tmemDepth = outer.numWarps * (tmemBytesPerWarp / outer.tcSmemSize)
|
||||||
|
val tmemBanks = 4
|
||||||
|
val tmemBankBits = log2Ceil(tmemBanks)
|
||||||
|
val tmemBankDepth = tmemDepth / tmemBanks
|
||||||
|
require(isPow2(tmemBanks))
|
||||||
|
require(tmemDepth % tmemBanks == 0)
|
||||||
|
val tmem = Seq.fill(tmemBanks) {
|
||||||
|
Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
||||||
|
tmemBankDepth, UInt((outer.tcSmemSize * 8).W)))
|
||||||
|
}
|
||||||
|
|
||||||
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
class TmemReadReq extends Bundle {
|
||||||
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
val addr = UInt(tmemAddrBits.W)
|
||||||
|
val src = UInt(2.W)
|
||||||
|
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||||
|
}
|
||||||
|
|
||||||
class TmemWriteReq extends Bundle {
|
class TmemWriteReq extends Bundle {
|
||||||
val addr = UInt(tmemAddrBits.W)
|
val addr = UInt(tmemAddrBits.W)
|
||||||
val data = UInt(tmemDataBits.W)
|
val data = UInt(tmemDataBits.W)
|
||||||
val mask = UInt(tmemMaskBits.W)
|
val mask = UInt(tmemMaskBits.W)
|
||||||
}
|
val src = UInt(1.W)
|
||||||
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
|
val tc = UInt(log2Ceil(nTC max 2).W)
|
||||||
|
|
||||||
(0 until nTC).foreach { tc =>
|
|
||||||
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
|
|
||||||
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
|
||||||
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
|
|
||||||
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
|
||||||
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
|
|
||||||
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
|
||||||
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
|
||||||
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
aReadArb.io.out.ready := true.B
|
def bank(addr: UInt): UInt = addr(tmemBankBits - 1, 0)
|
||||||
cReadArb.io.out.ready := true.B
|
def row(addr: UInt): UInt = addr(tmemAddrBits - 1, tmemBankBits)
|
||||||
cWriteArb.io.out.ready := true.B
|
|
||||||
|
|
||||||
tmem.io.ren0 := aReadArb.io.out.fire
|
val aReady = Wire(Vec(nTC, Bool()))
|
||||||
tmem.io.raddr0 := aReadArb.io.out.bits
|
val cReady = Wire(Vec(nTC, Bool()))
|
||||||
tmem.io.ren1 := cReadArb.io.out.fire
|
val wReady = Wire(Vec(nTC, Bool()))
|
||||||
tmem.io.raddr1 := cReadArb.io.out.bits
|
val scReadReady = Wire(Bool())
|
||||||
tmem.io.wen := cWriteArb.io.out.fire
|
val scWriteReady = Wire(Bool())
|
||||||
tmem.io.waddr := cWriteArb.io.out.bits.addr
|
aReady.foreach(_ := false.B)
|
||||||
tmem.io.wdata := cWriteArb.io.out.bits.data
|
cReady.foreach(_ := false.B)
|
||||||
tmem.io.mask := cWriteArb.io.out.bits.mask
|
wReady.foreach(_ := false.B)
|
||||||
|
scReadReady := false.B
|
||||||
|
scWriteReady := false.B
|
||||||
|
|
||||||
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
val read0Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||||
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
val read1Grant = Wire(Vec(tmemBanks, new TmemReadReq))
|
||||||
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
|
val read0Valid = Wire(Vec(tmemBanks, Bool()))
|
||||||
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
|
val read1Valid = Wire(Vec(tmemBanks, Bool()))
|
||||||
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
|
val writeGrant = Wire(Vec(tmemBanks, new TmemWriteReq))
|
||||||
|
val writeValid = Wire(Vec(tmemBanks, Bool()))
|
||||||
|
read0Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||||
|
read1Grant.foreach(_ := 0.U.asTypeOf(new TmemReadReq))
|
||||||
|
read0Valid.foreach(_ := false.B)
|
||||||
|
read1Valid.foreach(_ := false.B)
|
||||||
|
writeGrant.foreach(_ := 0.U.asTypeOf(new TmemWriteReq))
|
||||||
|
writeValid.foreach(_ := false.B)
|
||||||
|
|
||||||
|
(0 until tmemBanks).foreach { b =>
|
||||||
|
val requests = (0 until nTC).flatMap { tc =>
|
||||||
|
val aAddr = slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
||||||
|
val cAddr = slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
||||||
|
Seq(
|
||||||
|
(core.io.tc_tmem_A_ren(tc).asBool && bank(aAddr) === b.U, aAddr, 0.U(2.W), tc.U),
|
||||||
|
(core.io.tc_tmem_C_ren(tc).asBool && bank(cAddr) === b.U, cAddr, 1.U(2.W), tc.U)
|
||||||
|
)
|
||||||
|
} ++ Seq(
|
||||||
|
(core.io.sc_tmem_ren.asBool && bank(core.io.sc_tmem_raddr) === b.U,
|
||||||
|
core.io.sc_tmem_raddr, 2.U(2.W), 0.U)
|
||||||
|
)
|
||||||
|
|
||||||
|
var used0 = false.B
|
||||||
|
var used1 = false.B
|
||||||
|
requests.foreach { case (valid, addr, src, tc) =>
|
||||||
|
val grant0 = valid && !used0
|
||||||
|
val grant1 = valid && used0 && !used1
|
||||||
|
when(grant0) {
|
||||||
|
read0Grant(b).addr := addr
|
||||||
|
read0Grant(b).src := src
|
||||||
|
read0Grant(b).tc := tc
|
||||||
|
}
|
||||||
|
when(grant1) {
|
||||||
|
read1Grant(b).addr := addr
|
||||||
|
read1Grant(b).src := src
|
||||||
|
read1Grant(b).tc := tc
|
||||||
|
}
|
||||||
|
used0 = used0 || grant0
|
||||||
|
used1 = used1 || grant1
|
||||||
|
when(grant0 || grant1) {
|
||||||
|
when(src === 0.U) { aReady(tc) := true.B }
|
||||||
|
when(src === 1.U) { cReady(tc) := true.B }
|
||||||
|
when(src === 2.U) { scReadReady := true.B }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
read0Valid(b) := used0
|
||||||
|
read1Valid(b) := used1
|
||||||
|
|
||||||
|
var writeUsed = false.B
|
||||||
|
(0 until nTC).foreach { tc =>
|
||||||
|
val addr = slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
||||||
|
val valid = core.io.tc_tmem_C_wen(tc).asBool && bank(addr) === b.U
|
||||||
|
val grant = valid && !writeUsed
|
||||||
|
when(grant) {
|
||||||
|
writeValid(b) := true.B
|
||||||
|
writeGrant(b).addr := addr
|
||||||
|
writeGrant(b).data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
||||||
|
writeGrant(b).mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
||||||
|
writeGrant(b).src := 0.U
|
||||||
|
writeGrant(b).tc := tc.U
|
||||||
|
wReady(tc) := true.B
|
||||||
|
}
|
||||||
|
writeUsed = writeUsed || grant
|
||||||
|
}
|
||||||
|
|
||||||
|
val scWValid = core.io.sc_tmem_wen.asBool && bank(core.io.sc_tmem_waddr) === b.U
|
||||||
|
val scWGrant = scWValid && !writeUsed
|
||||||
|
when(scWGrant) {
|
||||||
|
writeValid(b) := true.B
|
||||||
|
writeGrant(b).addr := core.io.sc_tmem_waddr
|
||||||
|
writeGrant(b).data := core.io.sc_tmem_wdata
|
||||||
|
writeGrant(b).mask := core.io.sc_tmem_mask
|
||||||
|
writeGrant(b).src := 1.U
|
||||||
|
writeGrant(b).tc := 0.U
|
||||||
|
scWriteReady := true.B
|
||||||
|
}
|
||||||
|
|
||||||
|
tmem(b).io.ren0 := read0Valid(b)
|
||||||
|
tmem(b).io.raddr0 := row(read0Grant(b).addr)
|
||||||
|
tmem(b).io.ren1 := read1Valid(b)
|
||||||
|
tmem(b).io.raddr1 := row(read1Grant(b).addr)
|
||||||
|
tmem(b).io.wen := writeValid(b)
|
||||||
|
tmem(b).io.waddr := row(writeGrant(b).addr)
|
||||||
|
tmem(b).io.wdata := writeGrant(b).data
|
||||||
|
tmem(b).io.mask := writeGrant(b).mask
|
||||||
|
}
|
||||||
|
|
||||||
|
val read0GrantReg = RegNext(read0Grant)
|
||||||
|
val read1GrantReg = RegNext(read1Grant)
|
||||||
|
val read0ValidReg = RegNext(read0Valid)
|
||||||
|
val read1ValidReg = RegNext(read1Valid)
|
||||||
|
core.io.tc_tmem_A_rready := aReady.asUInt
|
||||||
|
core.io.tc_tmem_C_rready := cReady.asUInt
|
||||||
|
core.io.tc_tmem_C_wready := wReady.asUInt
|
||||||
|
core.io.sc_tmem_rready := scReadReady.asUInt
|
||||||
|
core.io.sc_tmem_wready := scWriteReady.asUInt
|
||||||
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
|
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
|
||||||
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
|
VecInit((0 until tmemBanks).map { b =>
|
||||||
|
Mux(read0ValidReg(b) && read0GrantReg(b).src === 0.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||||
|
Mux(read1ValidReg(b) && read1GrantReg(b).src === 0.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||||
|
}).reduce(_ | _)
|
||||||
}).asUInt
|
}).asUInt
|
||||||
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
|
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
|
||||||
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
|
VecInit((0 until tmemBanks).map { b =>
|
||||||
|
Mux(read0ValidReg(b) && read0GrantReg(b).src === 1.U && read0GrantReg(b).tc === tc.U, tmem(b).io.rdata0,
|
||||||
|
Mux(read1ValidReg(b) && read1GrantReg(b).src === 1.U && read1GrantReg(b).tc === tc.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||||
|
}).reduce(_ | _)
|
||||||
}).asUInt
|
}).asUInt
|
||||||
|
core.io.sc_tmem_rdata := VecInit((0 until tmemBanks).map { b =>
|
||||||
|
Mux(read0ValidReg(b) && read0GrantReg(b).src === 2.U, tmem(b).io.rdata0,
|
||||||
|
Mux(read1ValidReg(b) && read1GrantReg(b).src === 2.U, tmem(b).io.rdata1, 0.U(tmemDataBits.W)))
|
||||||
|
}).reduce(_ | _)
|
||||||
|
|
||||||
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
|
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
|
||||||
(0 until nTC).foreach { tc =>
|
(0 until nTC).foreach { tc =>
|
||||||
val p2 = port(tc, 2)
|
val p2 = port(tc, 2)
|
||||||
val client = outer.tcSmemNodes(tc).out.head
|
val client = outer.tcSmemNodes(tc).out.head
|
||||||
|
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
|
||||||
|
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
|
||||||
val adapter = Module(new VortexTLAdapter(
|
val adapter = Module(new VortexTLAdapter(
|
||||||
outer.smemSourceWidth,
|
outer.smemSourceWidth,
|
||||||
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
|
||||||
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
|
||||||
client
|
client
|
||||||
))
|
))
|
||||||
adapter.io.inReq.bits <> DontCare
|
adapter.io.inReq.bits <> DontCare
|
||||||
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
|
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
|
||||||
adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
|
adapter.io.inReq.bits.address := lineAddress
|
||||||
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
|
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
|
||||||
adapter.io.inReq.bits.size := 5.U
|
adapter.io.inReq.bits.size := tcSmemLineTlSize.U
|
||||||
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
|
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||||
adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
|
adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
|
||||||
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
|
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
|
||||||
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
|
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
|
||||||
client._1.a <> adapter.io.outReq
|
client._1.a <> adapter.io.outReq
|
||||||
adapter.io.outResp <> client._1.d
|
adapter.io.outResp <> client._1.d
|
||||||
|
val lineData = adapter.io.inResp.bits.data
|
||||||
|
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
|
||||||
|
lineData
|
||||||
|
} else {
|
||||||
|
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
|
||||||
|
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
|
||||||
|
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
|
||||||
|
log2Ceil(outer.tcSmemSize)).asUInt
|
||||||
|
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
|
||||||
|
when(adapter.io.inReq.fire) {
|
||||||
|
fragmentIndex := requestFragmentIndex
|
||||||
|
}
|
||||||
|
lineFragments(fragmentIndex)
|
||||||
|
}
|
||||||
|
|
||||||
tcAReady(p2) := adapter.io.inReq.ready
|
tcAReady(p2) := adapter.io.inReq.ready
|
||||||
tcDValid(p2) := adapter.io.inResp.valid
|
tcDValid(p2) := adapter.io.inResp.valid
|
||||||
tcDData(p2) := adapter.io.inResp.bits.data
|
tcDData(p2) := padToCoreData(fragmentData)
|
||||||
tcDTag(p2) := adapter.io.inResp.bits.source
|
tcDTag(p2) := adapter.io.inResp.bits.source
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -961,17 +1107,17 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
|
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
|
||||||
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
|
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
|
||||||
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
|
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
|
||||||
gmemAdapter.io.inReq.bits.size := 5.U
|
gmemAdapter.io.inReq.bits.size := tcTlSize.U
|
||||||
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
|
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||||
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
|
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
|
||||||
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
|
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
|
||||||
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
|
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
|
||||||
gmemClient._1.a <> gmemAdapter.io.outReq
|
gmemClient._1.a <> gmemAdapter.io.outReq
|
||||||
gmemAdapter.io.outResp <> gmemClient._1.d
|
gmemAdapter.io.outResp <> gmemClient._1.d
|
||||||
|
|
||||||
tcAReady(p0) := gmemAdapter.io.inReq.ready
|
tcAReady(p0) := gmemAdapter.io.inReq.ready
|
||||||
tcDValid(p0) := gmemAdapter.io.inResp.valid
|
tcDValid(p0) := gmemAdapter.io.inResp.valid
|
||||||
tcDData(p0) := gmemAdapter.io.inResp.bits.data
|
tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
|
||||||
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
|
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -989,6 +1135,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
core.io.tc_tmem_C_rready := DontCare
|
core.io.tc_tmem_C_rready := DontCare
|
||||||
core.io.tc_tmem_C_rdata := DontCare
|
core.io.tc_tmem_C_rdata := DontCare
|
||||||
core.io.tc_tmem_C_wready := DontCare
|
core.io.tc_tmem_C_wready := DontCare
|
||||||
|
core.io.sc_tmem_rready := DontCare
|
||||||
|
core.io.sc_tmem_rdata := DontCare
|
||||||
|
core.io.sc_tmem_wready := DontCare
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1082,7 +1231,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
|
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
|
||||||
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
||||||
val tensor = Module(new radiance.core.TensorCoreBlackwell(
|
val tensor = Module(new radiance.core.TensorCoreBlackwell(
|
||||||
8, 8, half = true, tensorNumSourceIds))
|
outer.numWarps, outer.numLsuLanes, half = true, tensorNumSourceIds))
|
||||||
tensor.io.initiate.valid := false.B
|
tensor.io.initiate.valid := false.B
|
||||||
tensor.io.initiate.bits := DontCare
|
tensor.io.initiate.bits := DontCare
|
||||||
tensor.io.respA.valid := false.B
|
tensor.io.respA.valid := false.B
|
||||||
@@ -1161,18 +1310,6 @@ class VortexTLAdapter(
|
|||||||
val outResp = chiselTypeOf(outTL._1.d)
|
val outResp = chiselTypeOf(outTL._1.d)
|
||||||
})
|
})
|
||||||
val (bundle, edge) = outTL
|
val (bundle, edge) = outTL
|
||||||
val sourceGen = Module(
|
|
||||||
new SourceGenerator(
|
|
||||||
newSourceWidth,
|
|
||||||
Some(inReqT.source),
|
|
||||||
ignoreInUse = false
|
|
||||||
)
|
|
||||||
)
|
|
||||||
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
|
|
||||||
sourceGen.io.reclaim.valid := io.outResp.fire
|
|
||||||
sourceGen.io.reclaim.bits := io.outResp.bits.source
|
|
||||||
sourceGen.io.meta := io.inReq.bits.source
|
|
||||||
|
|
||||||
// io passthrough logic
|
// io passthrough logic
|
||||||
// TLBundleA <> VortexBundleA
|
// TLBundleA <> VortexBundleA
|
||||||
io.outReq.valid := io.inReq.valid
|
io.outReq.valid := io.inReq.valid
|
||||||
@@ -1181,29 +1318,70 @@ class VortexTLAdapter(
|
|||||||
io.outReq.bits.size := io.inReq.bits.size
|
io.outReq.bits.size := io.inReq.bits.size
|
||||||
io.outReq.bits.source := io.inReq.bits.source
|
io.outReq.bits.source := io.inReq.bits.source
|
||||||
io.outReq.bits.address := io.inReq.bits.address
|
io.outReq.bits.address := io.inReq.bits.address
|
||||||
// Get requires contiguous mask; only copy core's potentially-partial mask
|
val outMaskWidth = io.outReq.bits.mask.getWidth
|
||||||
// when writing
|
val inMaskWidth = io.inReq.bits.mask.getWidth
|
||||||
|
val outDataWidth = io.outReq.bits.data.getWidth
|
||||||
|
val inDataWidth = io.inReq.bits.data.getWidth
|
||||||
|
val byteOffset = io.inReq.bits.address(log2Ceil(outMaskWidth) - 1, 0)
|
||||||
|
val responseOffsetWidth = log2Ceil(outMaskWidth)
|
||||||
|
val responseSourceWidth = inReqT.source.getWidth
|
||||||
|
val sourceGen = Module(
|
||||||
|
new SourceGenerator(
|
||||||
|
newSourceWidth,
|
||||||
|
Some(UInt((responseSourceWidth + responseOffsetWidth).W)),
|
||||||
|
ignoreInUse = false
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
|
||||||
|
sourceGen.io.reclaim.valid := io.outResp.fire
|
||||||
|
sourceGen.io.reclaim.bits := io.outResp.bits.source
|
||||||
|
sourceGen.io.meta := Cat(byteOffset, io.inReq.bits.source)
|
||||||
|
|
||||||
|
val alignedMask = Wire(UInt(outMaskWidth.W))
|
||||||
|
val alignedData = Wire(UInt(outDataWidth.W))
|
||||||
|
if (outMaskWidth == inMaskWidth) {
|
||||||
|
alignedMask := io.inReq.bits.mask
|
||||||
|
} else {
|
||||||
|
val paddedMask = Wire(UInt(outMaskWidth.W))
|
||||||
|
paddedMask := io.inReq.bits.mask
|
||||||
|
alignedMask := (paddedMask << byteOffset)(outMaskWidth - 1, 0)
|
||||||
|
}
|
||||||
|
if (outDataWidth == inDataWidth) {
|
||||||
|
alignedData := io.inReq.bits.data
|
||||||
|
} else {
|
||||||
|
val paddedData = Wire(UInt(outDataWidth.W))
|
||||||
|
paddedData := io.inReq.bits.data
|
||||||
|
alignedData := (paddedData << (byteOffset << 3))(outDataWidth - 1, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutFull requires the TL-canonical full mask for address+size; PutPartial
|
||||||
|
// can carry the core-provided byte mask.
|
||||||
io.outReq.bits.mask := Mux(
|
io.outReq.bits.mask := Mux(
|
||||||
edge.hasData(io.outReq.bits),
|
io.outReq.bits.opcode === TLMessages.PutPartialData,
|
||||||
io.inReq.bits.mask,
|
alignedMask,
|
||||||
// generate TL-correct mask
|
|
||||||
edge.mask(io.inReq.bits.address, io.inReq.bits.size)
|
edge.mask(io.inReq.bits.address, io.inReq.bits.size)
|
||||||
)
|
)
|
||||||
io.outReq.bits.data := io.inReq.bits.data
|
io.outReq.bits.data := alignedData
|
||||||
io.outReq.bits.corrupt := 0.U
|
io.outReq.bits.corrupt := 0.U
|
||||||
io.inReq.ready := io.outReq.ready
|
io.inReq.ready := io.outReq.ready
|
||||||
// VortexBundleD <> TLBundleD
|
// VortexBundleD <> TLBundleD
|
||||||
io.inResp.valid := io.outResp.valid
|
io.inResp.valid := io.outResp.valid
|
||||||
io.inResp.bits.opcode := io.outResp.bits.opcode
|
io.inResp.bits.opcode := io.outResp.bits.opcode
|
||||||
io.inResp.bits.size := io.outResp.bits.size
|
io.inResp.bits.size := io.outResp.bits.size
|
||||||
io.inResp.bits.source := io.outResp.bits.source
|
val responseMeta = sourceGen.io.peek.asUInt
|
||||||
io.inResp.bits.data := io.outResp.bits.data
|
val responseSource = responseMeta(responseSourceWidth - 1, 0)
|
||||||
|
val responseByteOffset =
|
||||||
|
responseMeta(responseSourceWidth + responseOffsetWidth - 1, responseSourceWidth)
|
||||||
|
io.inResp.bits.source := responseSource
|
||||||
|
if (outDataWidth == inDataWidth) {
|
||||||
|
io.inResp.bits.data := io.outResp.bits.data
|
||||||
|
} else {
|
||||||
|
io.inResp.bits.data := (io.outResp.bits.data >> (responseByteOffset << 3))(inDataWidth - 1, 0)
|
||||||
|
}
|
||||||
io.outResp.ready := io.inResp.ready
|
io.outResp.ready := io.inResp.ready
|
||||||
|
|
||||||
// "man-in-the-middle"
|
// "man-in-the-middle"
|
||||||
io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid
|
io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid
|
||||||
io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid
|
io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid
|
||||||
io.outReq.bits.source := sourceGen.io.id.bits
|
io.outReq.bits.source := sourceGen.io.id.bits
|
||||||
// translate upstream response back to its old sourceId
|
|
||||||
io.inResp.bits.source := sourceGen.io.peek
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,6 +120,15 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
|
|||||||
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
|
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
|
||||||
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
|
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
|
||||||
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
|
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
|
||||||
|
val sc_tmem_ren = Output(UInt(1.W))
|
||||||
|
val sc_tmem_rready = Input(UInt(1.W))
|
||||||
|
val sc_tmem_raddr = Output(UInt(9.W))
|
||||||
|
val sc_tmem_rdata = Input(UInt((numLanes * 32).W))
|
||||||
|
val sc_tmem_wen = Output(UInt(1.W))
|
||||||
|
val sc_tmem_wready = Input(UInt(1.W))
|
||||||
|
val sc_tmem_waddr = Output(UInt(9.W))
|
||||||
|
val sc_tmem_wdata = Output(UInt((numLanes * 32).W))
|
||||||
|
val sc_tmem_mask = Output(UInt((numLanes * 4).W))
|
||||||
|
|
||||||
// FIXME: hardcoded
|
// FIXME: hardcoded
|
||||||
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
||||||
@@ -351,6 +360,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_div.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dpi.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_dsp.sv")
|
||||||
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_exp.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_fma.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_ncomp.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_fpu_rounding.sv")
|
||||||
|
|||||||
@@ -26,7 +26,11 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.io.reqB.ready.poke(false.B)
|
c.io.reqB.ready.poke(false.B)
|
||||||
c.io.respC.poke(0.U)
|
c.io.respC.poke(0.U)
|
||||||
c.io.writeback.ready.poke(false.B)
|
c.io.writeback.ready.poke(false.B)
|
||||||
c.io.tmemC.rdata.poke(0.U)
|
c.io.tmemC.aRready.poke(true.B)
|
||||||
|
c.io.tmemC.aRdata.poke(0.U)
|
||||||
|
c.io.tmemC.cRready.poke(true.B)
|
||||||
|
c.io.tmemC.cRdata.poke(0.U)
|
||||||
|
c.io.tmemC.cWready.poke(true.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||||
@@ -39,13 +43,17 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
||||||
|
|
||||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
||||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
if (c.io.tmemC.aRen.peek().litToBoolean) {
|
||||||
val addr = c.io.tmemC.raddr.peek().litValue
|
val addr = c.io.tmemC.aRaddr.peek().litValue
|
||||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
c.io.tmemC.aRdata.poke(tmem(addr).U)
|
||||||
}
|
}
|
||||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
if (c.io.tmemC.cRen.peek().litToBoolean) {
|
||||||
val addr = c.io.tmemC.waddr.peek().litValue
|
val addr = c.io.tmemC.cRaddr.peek().litValue
|
||||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
c.io.tmemC.cRdata.poke(tmem(addr).U)
|
||||||
|
}
|
||||||
|
if (c.io.tmemC.cWen.peek().litToBoolean) {
|
||||||
|
val addr = c.io.tmemC.cWaddr.peek().litValue
|
||||||
|
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,9 +162,9 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
// cpWrite: respA fires, tmemC written
|
// cpWrite: respA fires, tmemC written
|
||||||
c.io.respA.valid.poke(true.B)
|
c.io.respA.valid.poke(true.B)
|
||||||
c.io.respA.bits.data.poke(cpData.U)
|
c.io.respA.bits.data.poke(cpData.U)
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||||
c.io.tmemC.wdata.expect(cpData.U)
|
c.io.tmemC.cWdata.expect(cpData.U)
|
||||||
stepTmem(c, tmem)
|
stepTmem(c, tmem)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.respA.valid.poke(false.B)
|
c.io.respA.valid.poke(false.B)
|
||||||
@@ -171,10 +179,10 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
|
|
||||||
// ldReq: ren asserted, serve from tmem model
|
// ldReq: ren asserted, serve from tmem model
|
||||||
c.io.tmemC.ren.expect(true.B)
|
c.io.tmemC.cRen.expect(true.B)
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// writeback should carry cpData
|
// writeback should carry cpData
|
||||||
@@ -206,8 +214,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// stWrite: tmemC written
|
// stWrite: tmemC written
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.wdata.expect(stData.U)
|
c.io.tmemC.cWdata.expect(stData.U)
|
||||||
stepTmem(c, tmem)
|
stepTmem(c, tmem)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
@@ -217,13 +225,15 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||||
c.io.initiate.bits.addressB.poke("h20000000".U)
|
c.io.initiate.bits.addressB.poke("h20000000".U)
|
||||||
c.io.reqA.ready.poke(true.B)
|
c.io.reqA.ready.poke(true.B)
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
|
|
||||||
// cbRead: ren asserted
|
// cbRead: ren asserted
|
||||||
c.io.tmemC.ren.expect(true.B)
|
c.io.tmemC.cRen.expect(true.B)
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.tmemC.cRdata.poke(tmem(tmemAddr / fragBytes).U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// cbWrite: reqA write with stData
|
// cbWrite: reqA write with stData
|
||||||
@@ -280,7 +290,7 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.ready.expect(false.B)
|
c.io.initiate.ready.expect(false.B)
|
||||||
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.ready.expect(true.B)
|
c.io.initiate.ready.expect(true.B)
|
||||||
}
|
}
|
||||||
@@ -309,8 +319,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
c.io.reqC.valid.expect(true.B)
|
c.io.reqC.valid.expect(true.B)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U)
|
c.io.tmemC.cWaddr.expect((warp0TmemA / fragBytes).U)
|
||||||
stepTmem(c, tmem)
|
stepTmem(c, tmem)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
@@ -324,8 +334,8 @@ class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTe
|
|||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
c.io.reqC.valid.expect(true.B)
|
c.io.reqC.valid.expect(true.B)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U)
|
c.io.tmemC.cWaddr.expect((warp3TmemA / fragBytes).U)
|
||||||
stepTmem(c, tmem)
|
stepTmem(c, tmem)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,11 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
c.io.reqB.ready.poke(false.B)
|
c.io.reqB.ready.poke(false.B)
|
||||||
c.io.respC.poke(0.U)
|
c.io.respC.poke(0.U)
|
||||||
c.io.writeback.ready.poke(false.B)
|
c.io.writeback.ready.poke(false.B)
|
||||||
c.io.tmemC.rdata.poke(0.U)
|
c.io.tmemC.aRready.poke(true.B)
|
||||||
|
c.io.tmemC.aRdata.poke(0.U)
|
||||||
|
c.io.tmemC.cRready.poke(true.B)
|
||||||
|
c.io.tmemC.cRdata.poke(0.U)
|
||||||
|
c.io.tmemC.cWready.poke(true.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
||||||
@@ -38,15 +42,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
// Simple TMEM model: address → 256-bit row
|
// Simple TMEM model: address → 256-bit row
|
||||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
||||||
|
|
||||||
// Drive tmemC read response from model, handle write
|
// Drive TMEM read responses from model, handle C-port writes.
|
||||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
||||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
if (c.io.tmemC.aRen.peek().litToBoolean) {
|
||||||
val addr = c.io.tmemC.raddr.peek().litValue
|
val addr = c.io.tmemC.aRaddr.peek().litValue
|
||||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
c.io.tmemC.aRdata.poke(tmem(addr).U)
|
||||||
}
|
}
|
||||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
if (c.io.tmemC.cRen.peek().litToBoolean) {
|
||||||
val addr = c.io.tmemC.waddr.peek().litValue
|
val addr = c.io.tmemC.cRaddr.peek().litValue
|
||||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
c.io.tmemC.cRdata.poke(tmem(addr).U)
|
||||||
|
}
|
||||||
|
if (c.io.tmemC.cWen.peek().litToBoolean) {
|
||||||
|
val addr = c.io.tmemC.cWaddr.peek().litValue
|
||||||
|
tmem(addr) = c.io.tmemC.cWdata.peek().litValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,19 +73,19 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
c.io.initiate.bits.rd.poke(3.U)
|
c.io.initiate.bits.rd.poke(3.U)
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||||
c.io.writeback.ready.poke(true.B)
|
c.io.writeback.ready.poke(true.B)
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
c.io.tmemC.cRdata.poke(testData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
c.io.initiate.ready.expect(false.B)
|
c.io.initiate.ready.expect(false.B)
|
||||||
|
|
||||||
// ldReq: tmemC.ren asserted; rdata must be valid before next step
|
// ldReq: tmemC.ren asserted; rdata must be valid before next step
|
||||||
c.io.tmemC.ren.expect(true.B)
|
c.io.tmemC.cRen.expect(true.B)
|
||||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
c.io.tmemC.cRdata.poke(testData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// waitWb: wbValid gets set this cycle, step to let it register
|
// waitWb: wbValid gets set this cycle, step to let it register
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
c.io.tmemC.cRdata.poke(testData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// idle: writeback.valid now true
|
// idle: writeback.valid now true
|
||||||
@@ -91,6 +99,38 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "tcgen05_ld: support 4-lane 16-byte fragments" in {
|
||||||
|
val lanes = 4
|
||||||
|
test(new TensorCoreBlackwell(numWarps, lanes, half = true, numSourceIds = 4)) { c =>
|
||||||
|
idleIO(c)
|
||||||
|
val fragBytes = 16
|
||||||
|
val tmemAddr = BigInt(0x40)
|
||||||
|
val testData = packWords(Seq.tabulate(lanes)(i => BigInt(0x2000 + i)), 32)
|
||||||
|
|
||||||
|
c.io.initiate.valid.poke(true.B)
|
||||||
|
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
|
||||||
|
c.io.initiate.bits.wid.poke(0.U)
|
||||||
|
c.io.initiate.bits.rd.poke(3.U)
|
||||||
|
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||||
|
c.io.writeback.ready.poke(true.B)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.initiate.valid.poke(false.B)
|
||||||
|
|
||||||
|
c.io.tmemC.cRen.expect(true.B)
|
||||||
|
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||||
|
c.io.tmemC.cRdata.poke(testData.U)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.tmemC.cRdata.poke(testData.U)
|
||||||
|
c.clock.step()
|
||||||
|
|
||||||
|
c.io.writeback.valid.expect(true.B)
|
||||||
|
c.io.writeback.bits.rd.expect(3.U)
|
||||||
|
for (i <- 0 until lanes) {
|
||||||
|
c.io.writeback.bits.data(i).expect((0x2000 + i).U)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
it should "tcgen05_st: write from respC to TMEM" in {
|
it should "tcgen05_st: write from respC to TMEM" in {
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
||||||
idleIO(c)
|
idleIO(c)
|
||||||
@@ -114,9 +154,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
c.clock.step()
|
c.clock.step()
|
||||||
|
|
||||||
// stWrite: tmemC.wen asserted with storeData
|
// stWrite: tmemC.wen asserted with storeData
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||||
c.io.tmemC.wdata.expect(storeData.U)
|
c.io.tmemC.cWdata.expect(storeData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.ready.expect(true.B)
|
c.io.initiate.ready.expect(true.B)
|
||||||
}
|
}
|
||||||
@@ -151,9 +191,9 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
c.io.respA.bits.data.poke(cpData.U)
|
c.io.respA.bits.data.poke(cpData.U)
|
||||||
|
|
||||||
// tmemC write happens combinatorially when respA fires
|
// tmemC write happens combinatorially when respA fires
|
||||||
c.io.tmemC.wen.expect(true.B)
|
c.io.tmemC.cWen.expect(true.B)
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
c.io.tmemC.cWaddr.expect((tmemAddr / fragBytes).U)
|
||||||
c.io.tmemC.wdata.expect(cpData.U)
|
c.io.tmemC.cWdata.expect(cpData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.ready.expect(true.B)
|
c.io.initiate.ready.expect(true.B)
|
||||||
}
|
}
|
||||||
@@ -172,14 +212,16 @@ class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
||||||
c.io.initiate.bits.addressB.poke(gmemAddr.U)
|
c.io.initiate.bits.addressB.poke(gmemAddr.U)
|
||||||
c.io.reqA.ready.poke(true.B)
|
c.io.reqA.ready.poke(true.B)
|
||||||
c.io.tmemC.rdata.poke(cbData.U)
|
c.io.tmemC.cRdata.poke(cbData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.valid.poke(false.B)
|
c.io.initiate.valid.poke(false.B)
|
||||||
c.io.initiate.ready.expect(false.B)
|
c.io.initiate.ready.expect(false.B)
|
||||||
|
|
||||||
// cbRead: tmemC.ren asserted
|
// cbRead: tmemC.ren asserted
|
||||||
c.io.tmemC.ren.expect(true.B)
|
c.io.tmemC.cRen.expect(true.B)
|
||||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
c.io.tmemC.cRaddr.expect((tmemAddr / fragBytes).U)
|
||||||
|
c.clock.step()
|
||||||
|
c.io.tmemC.cRdata.poke(cbData.U)
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.initiate.ready.expect(false.B)
|
c.io.initiate.ready.expect(false.B)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user