diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala
index 317ade6..6bfbff6 100644
--- a/src/main/scala/radiance/tile/RadianceTile.scala
+++ b/src/main/scala/radiance/tile/RadianceTile.scala
@@ -21,7 +21,7 @@ import midas.targetutils.SynthesizePrintf
 import org.chipsalliance.cde.config._
 import radiance.core._
 import radiance.memory._
-import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSimArgs}
+import radiance.subsystem.{GPUMemParams, GPUMemory, RadianceSharedMemKey, RadianceSimArgs}
 
 /** For determining radiance core id. This may be different from
   * RadianceTileParams.tileId, when a cluster contains non-core tiles */
@@ -289,6 +289,11 @@ class RadianceTile private (
   }
 
   val tcSmemSize = numLsuLanes * 4
+  val tcSmemLineSize = p(RadianceSharedMemKey)
+    .map(k => k.numWords * k.wordSize)
+    .getOrElse(tcSmemSize)
+  val tcSmemClientMaxSize =
+    if (radianceParams.core.tensorCoreBlackwell) math.max(tcSmemSize, tcSmemLineSize) else tcSmemSize
   val numTensorWarps = radianceParams.core.numTensorWarps
   val numScalarWarps = numWarps - numTensorWarps
   require(numTensorWarps > 0 && numTensorWarps < numWarps,
@@ -300,6 +305,8 @@ class RadianceTile private (
     require(numLsuLanes == 4 || numLsuLanes == 8,
       s"Wu Blackwell binding supports 4 or 8 lanes, got ${numLsuLanes}")
     require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
+    require(isPow2(tcSmemLineSize) && tcSmemLineSize >= tcSmemSize && (tcSmemLineSize % tcSmemSize) == 0,
+      s"Wu Blackwell SMEM line size (${tcSmemLineSize}) must be a power-of-two multiple of TC fragment size (${tcSmemSize})")
   }
   val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
   val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
@@ -309,9 +316,9 @@ class RadianceTile private (
         name = s"rad_tc_${radianceParams.coreId}_$i",
         sourceId = IdRange(0, 1 << smemSourceWidth),
         supports = TLSlaveToMasterTransferSizes(
-          probe = TransferSizes(1, tcSmemSize),
-          get = TransferSizes(1, tcSmemSize),
-          putFull = TransferSizes(1, tcSmemSize),
+          probe = TransferSizes(1, tcSmemClientMaxSize),
+          get = TransferSizes(1, tcSmemClientMaxSize),
+          putFull = TransferSizes(1, tcSmemClientMaxSize),
         ),
         requestFifo = true
       ))
@@ -854,18 +861,25 @@ class RadianceTileModuleImp(outer: RadianceTile)
 
   val nTC = outer.numTensorCores
   val tcPorts = 3
+  val tcCoreDataBits = 32 * 8
   val tcDataBits = outer.tcSmemSize * 8
+  val tcSmemLineBits = outer.tcSmemLineSize * 8
   val tmemAddrBits = 9
   val tmemDataBits = tcDataBits
   val tmemMaskBits = outer.tcSmemSize
   val tcTlSize = log2Ceil(outer.tcSmemSize)
+  val tcSmemLineTlSize = log2Ceil(outer.tcSmemLineSize)
 
   def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
   def port(tc: Int, p: Int): Int = tc * tcPorts + p
+  /** Zero-extends `u` to `tcCoreDataBits`; a value that is already that wide passes through unchanged. */
+  def padToCoreData(u: UInt): UInt = {
+    if (u.getWidth == tcCoreDataBits) u else Cat(0.U((tcCoreDataBits - u.getWidth).W), u)
+  }
 
   val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
   val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
-  val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W)))
+  val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcCoreDataBits.W)))
   val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
   tcAReady.foreach(_ := false.B)
   tcDValid.foreach(_ := false.B)
@@ -930,27 +944,54 @@ class RadianceTileModuleImp(outer: RadianceTile)
     (0 until nTC).foreach { tc =>
       val p2 = port(tc, 2)
       val client = outer.tcSmemNodes(tc).out.head
+      // Every SMEM request moves one whole, line-aligned SMEM line; the fragment the core
+      // addressed is cut back out of the returned line below.
+      val rawAddress = slice(core.io.tc_a_bits_address, 32, p2)
+      val lineAddress = rawAddress & (~((outer.tcSmemLineSize - 1).U(32.W))).asUInt
+      val fragmentsPerLine = outer.tcSmemLineSize / outer.tcSmemSize
+      val fragmentIndexBits = log2Ceil(fragmentsPerLine)
+      // The fragment index travels in extra tag bits above the core tag, so each response
+      // selects its own fragment even with several requests in flight.
+      // NOTE(review): assumes the adapter echoes `source` at the full tagWidth -- confirm.
+      val adapterTagWidth = outer.tensorTagWidth + fragmentIndexBits
       val adapter = Module(new VortexTLAdapter(
         outer.smemSourceWidth,
-        new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
-        new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
+        new VortexBundleA(tagWidth = adapterTagWidth, dataWidth = tcSmemLineBits),
+        new VortexBundleD(tagWidth = adapterTagWidth, dataWidth = tcSmemLineBits),
         client
       ))
+      val coreTag = slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
+      val coreWriteData = slice(core.io.tc_a_bits_data, tcCoreDataBits, p2)
       adapter.io.inReq.bits <> DontCare
       adapter.io.inReq.valid := core.io.tc_a_valid(p2)
-      adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
-      adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
-      adapter.io.inReq.bits.size := tcTlSize.U
+      adapter.io.inReq.bits.address := lineAddress
+      adapter.io.inReq.bits.source :=
+        (if (fragmentIndexBits == 0) coreTag
+         else Cat(rawAddress(log2Ceil(outer.tcSmemLineSize) - 1, log2Ceil(outer.tcSmemSize)), coreTag))
+      adapter.io.inReq.bits.size := tcSmemLineTlSize.U
       adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
-      adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
-      adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
+      adapter.io.inReq.bits.mask := Fill(outer.tcSmemLineSize, 1.U(1.W))
+      // Core data is only tcCoreDataBits wide: zero-extend instead of over-slicing when the line is wider.
+      // NOTE(review): a write here is a full-line PutFullData and overwrites the whole line -- confirm intent.
+      adapter.io.inReq.bits.data :=
+        (if (tcSmemLineBits <= tcCoreDataBits) coreWriteData(tcSmemLineBits - 1, 0)
+         else coreWriteData.pad(tcSmemLineBits))
       adapter.io.inResp.ready := core.io.tc_d_ready(p2)
       client._1.a <> adapter.io.outReq
       adapter.io.outResp <> client._1.d
+      val lineData = adapter.io.inResp.bits.data
+      val respSource = adapter.io.inResp.bits.source
+      val fragmentData = if (fragmentIndexBits == 0) {
+        lineData
+      } else {
+        val lineFragments = lineData.asTypeOf(Vec(fragmentsPerLine, UInt(tcDataBits.W)))
+        lineFragments(respSource(adapterTagWidth - 1, outer.tensorTagWidth))
+      }
 
       tcAReady(p2) := adapter.io.inReq.ready
       tcDValid(p2) := adapter.io.inResp.valid
-      tcDData(p2) := adapter.io.inResp.bits.data
-      tcDTag(p2) := adapter.io.inResp.bits.source
+      tcDData(p2) := padToCoreData(fragmentData)
+      tcDTag(p2) :=
+        (if (fragmentIndexBits == 0) respSource else respSource(outer.tensorTagWidth - 1, 0))
     }
 
@@ -970,15 +1011,15 @@ class RadianceTileModuleImp(outer: RadianceTile)
       gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
       gmemAdapter.io.inReq.bits.size := tcTlSize.U
       gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
-      gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
-      gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
+      gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)(outer.tcSmemSize - 1, 0)
+      gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcCoreDataBits, p0)(tcDataBits - 1, 0)
       gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
       gmemClient._1.a <> gmemAdapter.io.outReq
       gmemAdapter.io.outResp <> gmemClient._1.d
 
       tcAReady(p0) := gmemAdapter.io.inReq.ready
       tcDValid(p0) := gmemAdapter.io.inResp.valid
-      tcDData(p0) := gmemAdapter.io.inResp.bits.data
+      tcDData(p0) := padToCoreData(gmemAdapter.io.inResp.bits.data)
       tcDTag(p0) := gmemAdapter.io.inResp.bits.source
     }
 