Compare commits
2 Commits
c6c30ec0dc
...
wu-blackwe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f88085331e | ||
|
|
1e78574113 |
@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
|
||||
import org.chipsalliance.cde.config._
|
||||
import radiance.core._
|
||||
import radiance.memory._
|
||||
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs}
|
||||
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
|
||||
|
||||
/** For determining radiance core id. This may be different from
|
||||
* RadianceTileParams.tileId, when a cluster contains non-core tiles */
|
||||
@@ -289,6 +289,11 @@ class RadianceTile private (
|
||||
}
|
||||
|
||||
val tcSmemSize = numLsuLanes * 4
|
||||
val tcSmemLineSize = p(RadianceSharedMemKey)
|
||||
.map(k => k.numWords * k.wordSize)
|
||||
.getOrElse(tcSmemSize)
|
||||
val tcSmemClientMaxSize =
|
||||
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
|
||||
val numTensorWarps = radianceParams.core.numTensorWarps
|
||||
val numScalarWarps = numWarps - numTensorWarps
|
||||
require(numTensorWarps > 0 && numTensorWarps < numWarps,
|
||||
@@ -300,6 +305,8 @@ class RadianceTile private (
|
||||
require(numLsuLanes == 4 || numLsuLanes == 8,
|
||||
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
|
||||
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
|
||||
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
|
||||
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
|
||||
}
|
||||
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
||||
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
|
||||
@@ -309,9 +316,9 @@ class RadianceTile private (
|
||||
name = s"rad_tc_${radianceParams.coreId}_$i",
|
||||
sourceId = IdRange(0, 1 << smemSourceWidth),
|
||||
supports = TLSlaveToMasterTransferSizes(
|
||||
probe = TransferSizes(1, tcSmemSize),
|
||||
get = TransferSizes(1, tcSmemSize),
|
||||
putFull = TransferSizes(1, tcSmemSize),
|
||||
probe = TransferSizes(1, tcSmemClientMaxSize),
|
||||
get = TransferSizes(1, tcSmemClientMaxSize),
|
||||
putFull = TransferSizes(1, tcSmemClientMaxSize),
|
||||
),
|
||||
requestFifo = true
|
||||
))
|
||||
@@ -854,18 +861,24 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
|
||||
val nTC = outer.numTensorCores
|
||||
val tcPorts = 3
|
||||
val tcCoreDataBits = 32 * 8
|
||||
val tcDataBits = outer.tcSmemSize * 8
|
||||
val tcSmemLineBits = outer.tcSmemLineSize * 8
|
||||
val tmemAddrBits = 9
|
||||
val tmemDataBits = tcDataBits
|
||||
val tmemMaskBits = outer.tcSmemSize
|
||||
val tcTlSize = log2Ceil(outer.tcSmemSize)
|
||||
val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
|
||||
|
||||
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
|
||||
def port(tc: Int, p: Int): Int = tc * tcPorts + p
|
||||
def padToCoreData(u: UInt): UInt = {
|
||||
if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
|
||||
}
|
||||
|
||||
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
|
||||
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
|
||||
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W)))
|
||||
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
|
||||
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
|
||||
tcAReady.foreach(_ := false.B)
|
||||
tcDValid.foreach(_ := false.B)
|
||||
@@ -930,27 +943,43 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
(0 until nTC).foreach { tc =>
|
||||
val p2 = port(tc, 2)
|
||||
val client = outer.tcSmemNodes(tc).out.head
|
||||
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
|
||||
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
|
||||
val adapter = Module(new VortexTLAdapter(
|
||||
outer.smemSourceWidth,
|
||||
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
||||
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
||||
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
|
||||
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
|
||||
client
|
||||
))
|
||||
adapter.io.inReq.bits <> DontCare
|
||||
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
|
||||
adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
|
||||
adapter.io.inReq.bits.address := lineAddress
|
||||
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
|
||||
adapter.io.inReq.bits.size := tcTlSize.U
|
||||
adapter.io.inReq.bits.size := tcSmemLineTlSize.U
|
||||
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||
adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
|
||||
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
|
||||
adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
|
||||
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
|
||||
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
|
||||
client._1.a <> adapter.io.outReq
|
||||
adapter.io.outResp <> client._1.d
|
||||
val lineData = adapter.io.inResp.bits.data
|
||||
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
|
||||
lineData
|
||||
} else {
|
||||
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
|
||||
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
|
||||
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
|
||||
log2Ceil(outer.tcSmemSize)).asUInt
|
||||
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
|
||||
when(adapter.io.inReq.fire) {
|
||||
fragmentIndex := requestFragmentIndex
|
||||
}
|
||||
lineFragments(fragmentIndex)
|
||||
}
|
||||
|
||||
tcAReady(p2) := adapter.io.inReq.ready
|
||||
tcDValid(p2) := adapter.io.inResp.valid
|
||||
tcDData(p2) := adapter.io.inResp.bits.data
|
||||
tcDData(p2) := padToCoreData(fragmentData)
|
||||
tcDTag(p2) := adapter.io.inResp.bits.source
|
||||
}
|
||||
|
||||
@@ -970,15 +999,15 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
|
||||
gmemAdapter.io.inReq.bits.size := tcTlSize.U
|
||||
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
|
||||
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
|
||||
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
|
||||
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
|
||||
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
|
||||
gmemClient._1.a <> gmemAdapter.io.outReq
|
||||
gmemAdapter.io.outResp <> gmemClient._1.d
|
||||
|
||||
tcAReady(p0) := gmemAdapter.io.inReq.ready
|
||||
tcDValid(p0) := gmemAdapter.io.inResp.valid
|
||||
tcDData(p0) := gmemAdapter.io.inResp.bits.data
|
||||
tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
|
||||
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
|
||||
}
|
||||
|
||||
@@ -1168,18 +1197,6 @@ class VortexTLAdapter(
|
||||
val outResp = chiselTypeOf(outTL._1.d)
|
||||
})
|
||||
val (bundle, edge) = outTL
|
||||
val sourceGen = Module(
|
||||
new SourceGenerator(
|
||||
newSourceWidth,
|
||||
Some(inReqT.source),
|
||||
ignoreInUse = false
|
||||
)
|
||||
)
|
||||
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
|
||||
sourceGen.io.reclaim.valid := io.outResp.fire
|
||||
sourceGen.io.reclaim.bits := io.outResp.bits.source
|
||||
sourceGen.io.meta := io.inReq.bits.source
|
||||
|
||||
// io passthrough logic
|
||||
// TLBundleA <> VortexBundleA
|
||||
io.outReq.valid := io.inReq.valid
|
||||
@@ -1188,29 +1205,70 @@ class VortexTLAdapter(
|
||||
io.outReq.bits.size := io.inReq.bits.size
|
||||
io.outReq.bits.source := io.inReq.bits.source
|
||||
io.outReq.bits.address := io.inReq.bits.address
|
||||
// Get requires contiguous mask; only copy core's potentially-partial mask
|
||||
// when writing
|
||||
val outMaskWidth = io.outReq.bits.mask.getWidth
|
||||
val inMaskWidth = io.inReq.bits.mask.getWidth
|
||||
val outDataWidth = io.outReq.bits.data.getWidth
|
||||
val inDataWidth = io.inReq.bits.data.getWidth
|
||||
val byteOffset = io.inReq.bits.address(log2Ceil(outMaskWidth) - 1, 0)
|
||||
val responseOffsetWidth = log2Ceil(outMaskWidth)
|
||||
val responseSourceWidth = inReqT.source.getWidth
|
||||
val sourceGen = Module(
|
||||
new SourceGenerator(
|
||||
newSourceWidth,
|
||||
Some(UInt((responseSourceWidth + responseOffsetWidth).W)),
|
||||
ignoreInUse = false
|
||||
)
|
||||
)
|
||||
sourceGen.io.gen := io.outReq.fire // use up a source ID only when request is created
|
||||
sourceGen.io.reclaim.valid := io.outResp.fire
|
||||
sourceGen.io.reclaim.bits := io.outResp.bits.source
|
||||
sourceGen.io.meta := Cat(byteOffset, io.inReq.bits.source)
|
||||
|
||||
val alignedMask = Wire(UInt(outMaskWidth.W))
|
||||
val alignedData = Wire(UInt(outDataWidth.W))
|
||||
if (outMaskWidth == inMaskWidth) {
|
||||
alignedMask := io.inReq.bits.mask
|
||||
} else {
|
||||
val paddedMask = Wire(UInt(outMaskWidth.W))
|
||||
paddedMask := io.inReq.bits.mask
|
||||
alignedMask := (paddedMask << byteOffset)(outMaskWidth - 1, 0)
|
||||
}
|
||||
if (outDataWidth == inDataWidth) {
|
||||
alignedData := io.inReq.bits.data
|
||||
} else {
|
||||
val paddedData = Wire(UInt(outDataWidth.W))
|
||||
paddedData := io.inReq.bits.data
|
||||
alignedData := (paddedData << (byteOffset << 3))(outDataWidth - 1, 0)
|
||||
}
|
||||
|
||||
// PutFull requires the TL-canonical full mask for address+size; PutPartial
|
||||
// can carry the core-provided byte mask.
|
||||
io.outReq.bits.mask := Mux(
|
||||
edge.hasData(io.outReq.bits),
|
||||
io.inReq.bits.mask,
|
||||
// generate TL-correct mask
|
||||
io.outReq.bits.opcode === TLMessages.PutPartialData,
|
||||
alignedMask,
|
||||
edge.mask(io.inReq.bits.address, io.inReq.bits.size)
|
||||
)
|
||||
io.outReq.bits.data := io.inReq.bits.data
|
||||
io.outReq.bits.data := alignedData
|
||||
io.outReq.bits.corrupt := 0.U
|
||||
io.inReq.ready := io.outReq.ready
|
||||
// VortexBundleD <> TLBundleD
|
||||
io.inResp.valid := io.outResp.valid
|
||||
io.inResp.bits.opcode := io.outResp.bits.opcode
|
||||
io.inResp.bits.size := io.outResp.bits.size
|
||||
io.inResp.bits.source := io.outResp.bits.source
|
||||
val responseMeta = sourceGen.io.peek.asUInt
|
||||
val responseSource = responseMeta(responseSourceWidth - 1, 0)
|
||||
val responseByteOffset =
|
||||
responseMeta(responseSourceWidth + responseOffsetWidth - 1, responseSourceWidth)
|
||||
io.inResp.bits.source := responseSource
|
||||
if (outDataWidth == inDataWidth) {
|
||||
io.inResp.bits.data := io.outResp.bits.data
|
||||
} else {
|
||||
io.inResp.bits.data := (io.outResp.bits.data >> (responseByteOffset << 3))(inDataWidth - 1, 0)
|
||||
}
|
||||
io.outResp.ready := io.inResp.ready
|
||||
|
||||
// "man-in-the-middle"
|
||||
io.inReq.ready := io.outReq.ready && sourceGen.io.id.valid
|
||||
io.outReq.valid := io.inReq.valid && sourceGen.io.id.valid
|
||||
io.outReq.bits.source := sourceGen.io.id.bits
|
||||
// translate upstream response back to its old sourceId
|
||||
io.inResp.bits.source := sourceGen.io.peek
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user