Fix Blackwell SMEM fragment alignment

This commit is contained in:
Zhongdi LUO
2026-05-27 08:43:36 +00:00
parent c6c30ec0dc
commit 1e78574113

View File

@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
import org.chipsalliance.cde.config._ import org.chipsalliance.cde.config._
import radiance.core._ import radiance.core._
import radiance.memory._ import radiance.memory._
import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs} import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
/** For determining radiance core id. This may be different from /** For determining radiance core id. This may be different from
* RadianceTileParams.tileId, when a cluster contains non-core tiles */ * RadianceTileParams.tileId, when a cluster contains non-core tiles */
@@ -289,6 +289,11 @@ class RadianceTile private (
} }
val tcSmemSize = numLsuLanes * 4 val tcSmemSize = numLsuLanes * 4
val tcSmemLineSize = p(RadianceSharedMemKey)
.map(k => k.numWords * k.wordSize)
.getOrElse(tcSmemSize)
val tcSmemClientMaxSize =
if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
val numTensorWarps = radianceParams.core.numTensorWarps val numTensorWarps = radianceParams.core.numTensorWarps
val numScalarWarps = numWarps - numTensorWarps val numScalarWarps = numWarps - numTensorWarps
require(numTensorWarps > 0 && numTensorWarps < numWarps, require(numTensorWarps > 0 && numTensorWarps < numWarps,
@@ -300,6 +305,8 @@ class RadianceTile private (
require(numLsuLanes == 4 || numLsuLanes == 8, require(numLsuLanes == 4 || numLsuLanes == 8,
s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}") s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp") require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
} }
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0 val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
@@ -309,9 +316,9 @@ class RadianceTile private (
name = s"rad_tc_${radianceParams.coreId}_$i", name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth), sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize), probe = TransferSizes(1, tcSmemClientMaxSize),
get = TransferSizes(1, tcSmemSize), get = TransferSizes(1, tcSmemClientMaxSize),
putFull = TransferSizes(1, tcSmemSize), putFull = TransferSizes(1, tcSmemClientMaxSize),
), ),
requestFifo = true requestFifo = true
)) ))
@@ -854,18 +861,24 @@ class RadianceTileModuleImp(outer: RadianceTile)
val nTC = outer.numTensorCores val nTC = outer.numTensorCores
val tcPorts = 3 val tcPorts = 3
val tcCoreDataBits = 32 * 8
val tcDataBits = outer.tcSmemSize * 8 val tcDataBits = outer.tcSmemSize * 8
val tcSmemLineBits = outer.tcSmemLineSize * 8
val tmemAddrBits = 9 val tmemAddrBits = 9
val tmemDataBits = tcDataBits val tmemDataBits = tcDataBits
val tmemMaskBits = outer.tcSmemSize val tmemMaskBits = outer.tcSmemSize
val tcTlSize = log2Ceil(outer.tcSmemSize) val tcTlSize = log2Ceil(outer.tcSmemSize)
val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx) def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
def port(tc: Int, p: Int): Int = tc * tcPorts + p def port(tc: Int, p: Int): Int = tc * tcPorts + p
def padToCoreData(u: UInt): UInt = {
if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
}
val tcAReady = Wire(Vec(nTC * tcPorts, Bool())) val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
val tcDValid = Wire(Vec(nTC * tcPorts, Bool())) val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W))) val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W))) val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
tcAReady.foreach(_ := false.B) tcAReady.foreach(_ := false.B)
tcDValid.foreach(_ := false.B) tcDValid.foreach(_ := false.B)
@@ -930,27 +943,43 @@ class RadianceTileModuleImp(outer: RadianceTile)
(0 until nTC).foreach { tc => (0 until nTC).foreach { tc =>
val p2 = port(tc, 2) val p2 = port(tc, 2)
val client = outer.tcSmemNodes(tc).out.head val client = outer.tcSmemNodes(tc).out.head
val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
val adapter = Module(new VortexTLAdapter( val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth, outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcSmemLineBits),
client client
)) ))
adapter.io.inReq.bits <> DontCare adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := core.io.tc_a_valid(p2) adapter.io.inReq.valid := core.io.tc_a_valid(p2)
adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2) adapter.io.inReq.bits.address := lineAddress
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2) adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
adapter.io.inReq.bits.size := tcTlSize.U adapter.io.inReq.bits.size := tcSmemLineTlSize.U
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get) adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2) adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2) adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)(tcSmemLineBits - 1, 0)
adapter.io.inResp.ready := core.io.tc_d_ready(p2) adapter.io.inResp.ready := core.io.tc_d_ready(p2)
client._1.a <> adapter.io.outReq client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d adapter.io.outResp <> client._1.d
val lineData = adapter.io.inResp.bits.data
val fragmentData = if (outer.tcSmemLineSize == outer.tcSmemSize) {
lineData
} else {
val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
val fragmentIndex = RegInit(0.U(log2Ceil(fragmentsPerLine).W))
val requestFragmentIndex = ((rawAddress & (outer.tcSmemLineSize - 1).U) >>
log2Ceil(outer.tcSmemSize)).asUInt
val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
when(adapter.io.inReq.fire) {
fragmentIndex := requestFragmentIndex
}
lineFragments(fragmentIndex)
}
tcAReady(p2) := adapter.io.inReq.ready tcAReady(p2) := adapter.io.inReq.ready
tcDValid(p2) := adapter.io.inResp.valid tcDValid(p2) := adapter.io.inResp.valid
tcDData(p2) := adapter.io.inResp.bits.data tcDData(p2) := padToCoreData(fragmentData)
tcDTag(p2) := adapter.io.inResp.bits.source tcDTag(p2) := adapter.io.inResp.bits.source
} }
@@ -970,15 +999,15 @@ class RadianceTileModuleImp(outer: RadianceTile)
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0) gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
gmemAdapter.io.inReq.bits.size := tcTlSize.U gmemAdapter.io.inReq.bits.size := tcTlSize.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get) gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0) gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0) gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0) gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
gmemClient._1.a <> gmemAdapter.io.outReq gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d gmemAdapter.io.outResp <> gmemClient._1.d
tcAReady(p0) := gmemAdapter.io.inReq.ready tcAReady(p0) := gmemAdapter.io.inReq.ready
tcDValid(p0) := gmemAdapter.io.inResp.valid tcDValid(p0) := gmemAdapter.io.inResp.valid
tcDData(p0) := gmemAdapter.io.inResp.bits.data tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
tcDTag(p0) := gmemAdapter.io.inResp.bits.source tcDTag(p0) := gmemAdapter.io.inResp.bits.source
} }