diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index eefd491..821e154 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -274,6 +274,19 @@ class RadianceTile private ( ) } + val tcSmemSize = 32 + val tcSmemNodes = Seq(TLClientNode(Seq(TLMasterPortParameters.v2( + masters = Seq(TLMasterParameters.v2( + name = s"rad_tc_${radianceParams.coreId}", + sourceId = IdRange(0, 1 << smemSourceWidth), + supports = TLSlaveToMasterTransferSizes( + get = TransferSizes(1, tcSmemSize), + putFull = TransferSizes(1, tcSmemSize), + putPartial = TransferSizes(1, tcSmemSize) + ) + )) + )))) + // combine outgoing per-lane dmemNode into 1 idenity node // // NOTE: We need TLWidthWidget here because there might be a data width diff --git a/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala b/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala index a3fde96..c72fc7f 100644 --- a/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala +++ b/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala @@ -54,6 +54,7 @@ class VirgoSharedMemComponents( smemFanoutXbar.node } } + val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes).map(connectXbarName(_, Some("tc_fanout"))) val clBusClients: Seq[TLNode] = radianceSmemFanout val (uniformRNodes, uniformWNodes, nonuniformRNodes, nonuniformWNodes) = @@ -84,6 +85,12 @@ class VirgoSharedMemComponents( val spadSpWriteNodesSingleBank = distAndDuplicate(gemminis.map(_.spad.spad_writer.node), "ws") val spadSpWriteNodes = Seq.fill(smemBanks)(spadSpWriteNodesSingleBank) // executed only once + // tensor core read nodes + val tcDistNodes = Seq.fill(smemBanks)(tcNodeFanouts.map(connectOne(_, () => DistributorNode(smemWidth, wordSize)))) + val tcNodes = tcDistNodes.map { tcBank => + Seq.fill(smemSubbanks)(tcBank.map(connectXbarName(_, Some("tc_dist_fanout")))) + } // (banks, subbanks, tc client) + if (filterAligned) { val numLsuLanes = radianceTiles.head.numLsuLanes val numLaneDupes = Math.max(1, smemSubbanks / numLsuLanes) @@ -186,8 +193,8 @@ class VirgoSharedMemComponents( } - val uniformRNodes: Seq[Seq[Seq[TLNexusNode]]] = spadReadNodes.map { rb => - (rb zip fAligned.head).map { case (rw, fa) => rw ++ fa } + val uniformRNodes: Seq[Seq[Seq[TLNexusNode]]] = (spadReadNodes zip tcNodes).map { case (rb, tcrb) => + (rb lazyZip tcrb lazyZip fAligned.head).map { case (rw, tcrw, fa) => rw ++ tcrw ++ fa } } val uniformWNodes: Seq[Seq[Seq[TLNexusNode]]] = (spadWriteNodes zip spadSpWriteNodes).map { case (wb, wsb) => (wb lazyZip wsb lazyZip fAligned.last).map {