From 378b3531d45cfd3d92d09f229a23b1f3fb16ea1a Mon Sep 17 00:00:00 2001 From: Richard Yan Date: Sat, 7 Sep 2024 20:29:27 -0700 Subject: [PATCH] balanced shared memory across cores --- .../radiance/memory/XbarWithExtPolicy.scala | 40 ++++++++ .../scala/radiance/tile/GemminiTile.scala | 4 +- .../scala/radiance/tile/RadianceCluster.scala | 95 ++++++++++++++++--- 3 files changed, 126 insertions(+), 13 deletions(-) create mode 100644 src/main/scala/radiance/memory/XbarWithExtPolicy.scala diff --git a/src/main/scala/radiance/memory/XbarWithExtPolicy.scala b/src/main/scala/radiance/memory/XbarWithExtPolicy.scala new file mode 100644 index 0000000..4eaea47 --- /dev/null +++ b/src/main/scala/radiance/memory/XbarWithExtPolicy.scala @@ -0,0 +1,40 @@ +package radiance.memory + +import chisel3._ +import chisel3.experimental.SourceInfo +import chisel3.util.Valid +import freechips.rocketchip.tilelink._ +import org.chipsalliance.cde.config.Parameters +import org.chipsalliance.diplomacy.ValName +import org.chipsalliance.diplomacy.lazymodule._ +import org.chipsalliance.diplomacy.nodes.{RenderedEdge, SimpleNodeImp, SinkNode, SourceNode} + +object ExtPolicyNodeImp extends SimpleNodeImp[Int, Int, Int, UInt] { + def bundle(x: Int) = UInt(x.W) + def edge(x: Int, y: Int, p: Parameters, sourceInfo: SourceInfo): Int = x + def render(x: Int): RenderedEdge = RenderedEdge("ffffff") +} +case class ExtPolicyMasterNode(w: Int)(implicit valName: ValName) extends SourceNode(ExtPolicyNodeImp)(Seq(w)) +case class ExtPolicySlaveNode()(implicit valName: ValName) extends SinkNode(ExtPolicyNodeImp)(Seq(0)) + +class XbarWithExtPolicy(nameSuffix: Option[String] = None) + (implicit p: Parameters) extends TLXbar(nameSuffix = nameSuffix) { + val policySlaveNode = ExtPolicySlaveNode() + + class ImplChild extends Impl { + val policy: TLArbiter.Policy = (_, _, _) => policySlaveNode.in.head._1 + // val wide_bundle = TLBundleParameters.union((node.in ++ node.out).map(_._2.bundle)) + // override def desiredName = (Seq("TLXbar") ++ nameSuffix ++ Seq(s"i${node.in.size}_o${node.out.size}_${wide_bundle.shortName}")).mkString("_") + TLXbar.circuit(policy, node.in, node.out) + } + + override lazy val module = new ImplChild +} + +object XbarWithExtPolicy { + def apply(nameSuffix: Option[String] = None) + (implicit p: Parameters): XbarWithExtPolicy = { + val xbar = LazyModule(new XbarWithExtPolicy(nameSuffix)) + xbar + } +} \ No newline at end of file diff --git a/src/main/scala/radiance/tile/GemminiTile.scala b/src/main/scala/radiance/tile/GemminiTile.scala index 9b32470..638c317 100644 --- a/src/main/scala/radiance/tile/GemminiTile.scala +++ b/src/main/scala/radiance/tile/GemminiTile.scala @@ -224,8 +224,8 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) is (8.U) { val inst = Wire(ciscInstT) inst.inst := 0x1820b07b.U - inst.rs1 := ciscArgs(7, 0) - inst.rs2 := ciscArgs(15, 8) + inst.rs1 := ciscArgs(11, 0) + inst.rs2 := ciscArgs(23, 12) ciscInst := microcodeEntry(Seq(inst)) } is (9.U) { diff --git a/src/main/scala/radiance/tile/RadianceCluster.scala b/src/main/scala/radiance/tile/RadianceCluster.scala index 3ccfeab..7ab2385 100644 --- a/src/main/scala/radiance/tile/RadianceCluster.scala +++ b/src/main/scala/radiance/tile/RadianceCluster.scala @@ -18,6 +18,8 @@ import org.chipsalliance.diplomacy.{DisableMonitors, ValName} import radiance.memory._ import radiance.subsystem.{RadianceFrameBufferKey, RadianceSharedMemKey} +import scala.collection.mutable.ArrayBuffer + case class RadianceClusterParams( val clusterId: Int, val clockSinkParams: ClockSinkParameters = ClockSinkParameters() @@ -98,13 +100,14 @@ class RadianceCluster ( guard_monitors { implicit p => t := from } t } - def connect_xbar_name(from: TLNode, name: Option[String] = None): TLNode = { - val t = LazyModule(new TLXbar(TLArbiter.roundRobin)) + def connect_xbar_name(from: TLNode, name: Option[String] = None, + policy: TLArbiter.Policy = TLArbiter.roundRobin): TLNexusNode = { + val t = LazyModule(new TLXbar(policy)) name.map(t.suggestName) guard_monitors { implicit p => t.node := from } t.node } - def connect_xbar(from: TLNode): TLNode = { + def connect_xbar(from: TLNode): TLNexusNode = { connect_xbar_name(from, None) } @@ -180,8 +183,17 @@ class RadianceCluster ( } } + val uniform_policy_nodes: Seq[ArrayBuffer[ArrayBuffer[ExtPolicyMasterNode]]] = // mutable + Seq.fill(2)(ArrayBuffer.fill(smem_banks)(ArrayBuffer.fill(smem_subbanks)(null))) + val uniform_nodes_in: Seq[ArrayBuffer[ArrayBuffer[Seq[TLIdentityNode]]]] = + Seq.fill(2)(ArrayBuffer.fill(smem_banks)(ArrayBuffer.fill(smem_subbanks)(Seq()))) + val uniform_nodes_out: Seq[ArrayBuffer[ArrayBuffer[TLIdentityNode]]] = + Seq.fill(2)(ArrayBuffer.fill(smem_banks)(ArrayBuffer.fill(smem_subbanks)(null))) + + val (uniform_r_nodes, uniform_w_nodes, _, _) = + if (stride_by_word) { - def dist_and_duplicate(nodes: Seq[TLNode], suffix: String): Seq[Seq[TLNode]] = { + def dist_and_duplicate(nodes: Seq[TLNode], suffix: String): Seq[Seq[TLNexusNode]] = { val word_fanout_nodes = gemminis.zip(nodes).zipWithIndex.map { case ((gemmini, node), gemmini_idx) => val sp_width_bytes = gemmini.config.sp_width / 8 val sp_subbanks = sp_width_bytes / wordSize @@ -205,7 +217,7 @@ class RadianceCluster ( val spad_sp_write_nodes = Seq.fill(smem_banks)(spad_sp_write_nodes_single_bank) // executed only once val (uniform_r_nodes, uniform_w_nodes, nonuniform_r_nodes, nonuniform_w_nodes): - (Seq[Seq[Seq[TLNode]]], Seq[Seq[Seq[TLNode]]], Seq[TLNode], Seq[TLNode]) = if (filter_aligned) { + (Seq[Seq[Seq[TLNexusNode]]], Seq[Seq[Seq[TLNexusNode]]], Seq[TLNode], Seq[TLNode]) = if (filter_aligned) { val num_lanes = radianceTiles.head.numCoreLanes val num_lsu_lanes = radianceTiles.head.numLsuLanes @@ -248,10 +260,10 @@ class RadianceCluster ( Seq.fill(2)(filter_nodes.flatMap(_.map(_._2).map(connect_xbar))) } - val uniform_r_nodes: Seq[Seq[Seq[TLNode]]] = spad_read_nodes.map { rb => + val uniform_r_nodes: Seq[Seq[Seq[TLNexusNode]]] = spad_read_nodes.map { rb => (rb zip f_aligned.head).map { case (rw, fa) => rw ++ fa } } - val uniform_w_nodes: Seq[Seq[Seq[TLNode]]] = (spad_write_nodes zip spad_sp_write_nodes).map { case (wb, wsb) => + val uniform_w_nodes: Seq[Seq[Seq[TLNexusNode]]] = (spad_write_nodes zip spad_sp_write_nodes).map { case (wb, wsb) => (wb lazyZip wsb lazyZip f_aligned.last).map { case (ww, wsw, fa) => ww ++ wsw ++ fa } @@ -264,8 +276,8 @@ class RadianceCluster ( } else { val splitter_nodes = radiance_smem_fanout.map { connect_one(_, RWSplitterNode.apply) } // these nodes access an entire line simultaneously - val uniform_r_nodes: Seq[Seq[Seq[TLNode]]] = spad_read_nodes - val uniform_w_nodes: Seq[Seq[Seq[TLNode]]] = (spad_write_nodes zip spad_sp_write_nodes).map { case (wb, wsb) => + val uniform_r_nodes: Seq[Seq[Seq[TLNexusNode]]] = spad_read_nodes + val uniform_w_nodes: Seq[Seq[Seq[TLNexusNode]]] = (spad_write_nodes zip spad_sp_write_nodes).map { case (wb, wsb) => (wb zip wsb).map { case (ww, wsw) => ww ++ wsw } } // these nodes are random access @@ -288,14 +300,39 @@ class RadianceCluster ( guard_monitors { implicit p => r := subbank_r_xbar.node w := subbank_w_xbar.node - uniform_r_nodes(bid)(wid).foreach( subbank_r_xbar.node := _ ) - uniform_w_nodes(bid)(wid).foreach( subbank_w_xbar.node := _ ) + + val ur_xbar = XbarWithExtPolicy(Some("ur")) + val uw_xbar = XbarWithExtPolicy(Some("uw")) + val r_policy_node = ExtPolicyMasterNode(uniform_r_nodes(bid)(wid).length) + val w_policy_node = ExtPolicyMasterNode(uniform_w_nodes(bid)(wid).length) + ur_xbar.policySlaveNode := r_policy_node + uw_xbar.policySlaveNode := w_policy_node + uniform_policy_nodes.head(bid)(wid) = r_policy_node + uniform_policy_nodes.last(bid)(wid) = w_policy_node + + (Seq(ur_xbar, uw_xbar) lazyZip uniform_nodes_in lazyZip Seq(uniform_r_nodes, uniform_w_nodes)) + .foreach { case (xbar, id_buf, u_nodes) => + + id_buf(bid)(wid) = u_nodes(bid)(wid).map { u => + val id = TLIdentityNode() + xbar.node := id := u + id + } + } + + // uniform_w_nodes(bid)(wid).foreach( uw_xbar.node := _ ) + uniform_nodes_out.head(bid)(wid) = TLIdentityNode() + uniform_nodes_out.last(bid)(wid) = TLIdentityNode() + subbank_r_xbar.node := uniform_nodes_out.head(bid)(wid) := ur_xbar.node + subbank_w_xbar.node := uniform_nodes_out.last(bid)(wid) := uw_xbar.node nonuniform_r_nodes.foreach( subbank_r_xbar.node := _ ) nonuniform_w_nodes.foreach( subbank_w_xbar.node := _ ) } } } + + (Some(uniform_r_nodes), Some(uniform_w_nodes), Some(nonuniform_r_nodes), Some(nonuniform_w_nodes)) } else { gemminis.foreach { gemmini => unified_mem_read_node :=* TLWidthWidget(smem_width) :=* gemmini.spad_read_nodes @@ -322,6 +359,8 @@ class RadianceCluster ( mem.head := smem_r_xbar mem.last := smem_w_xbar } + + (None, None, None, None) } // ******************************************************* @@ -504,7 +543,23 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp( dontTouch(smemWriteCounter) if (outer.stride_by_word) { + val (uniform_r_nodes, uniform_w_nodes) = (outer.uniform_r_nodes.get, outer.uniform_w_nodes.get) + val uniform_fires = Seq.fill(2)(VecInit.fill(outer.smem_banks)(VecInit.fill(outer.smem_subbanks)(false.B))) + val word_selects_1h = Seq.fill(2)(VecInit.fill(outer.smem_banks)(0.U)) + outer.smem_bank_mgrs.grouped(outer.smem_subbanks).zipWithIndex.foreach { case (bank_mgrs, bid) => + // TODO move this loop out + // val Seq(valid_r_sources, valid_w_sources) = uniform_xbar_nodes.map(_(bid)).map { words => + // VecInit(words.map(_.out.map(_._1.a.valid)).transpose.map { words_with_same_idx => + // VecInit(words_with_same_idx.toSeq).asUInt.orR + // }.toSeq).asUInt + // } + val Seq(valid_r_sources, valid_w_sources) = outer.uniform_nodes_in.map { banks => + banks(bid).map(_.map(_.in.head._1.a.valid)).transpose.map { words_in_idx => + VecInit(words_in_idx.toSeq).asUInt.orR + } + } + assert(bank_mgrs.flatten.size == 2/* read and write */ * outer.smem_subbanks) bank_mgrs.zipWithIndex.foreach { case (Seq(r, w), wid) => assert(!r.portParams.map(_.anySupportPutFull).reduce(_ || _)) @@ -540,6 +595,24 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp( // add access counters to banks smemReadsPerBankPerCycle(bid)(wid) := (r_node.a.fire === true.B) smemWritesPerBankPerCycle(bid)(wid) := (w_node.a.fire === true.B) + + // (uniform_fires zip Seq(uniform_r_nodes, uniform_w_nodes)).foreach { case (uf, n) => + // uf(bid)(wid) := VecInit(n(bid)(wid).map(_.out.head._1.a.fire)).asUInt.orR + // } + (uniform_fires zip outer.uniform_nodes_out).foreach { case (uf, n) => + uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire + } + } + + println(f"valid r_sources ${valid_r_sources.length}, ${valid_r_sources}") + (word_selects_1h zip Seq(valid_r_sources, valid_w_sources)).zipWithIndex.foreach { case ((ws, vs), rw) => + ws(bid) := TLArbiter.roundRobin(vs.length, VecInit(vs.toSeq).asUInt, uniform_fires(rw)(bid).asUInt.orR) + } + } + + (outer.uniform_policy_nodes zip word_selects_1h).zipWithIndex.foreach { case ((nodes_bw, ws_b), rw) => + (nodes_bw zip ws_b).zipWithIndex.foreach { case ((nodes_w, ws), bid) => + nodes_w.foreach { _.out.head._1 := ws } } } } else {