From 9fb861a8738b183b9e7f8979e374ab3aa137d36b Mon Sep 17 00:00:00 2001 From: Richard Yan Date: Tue, 26 Mar 2024 23:13:30 -0700 Subject: [PATCH] make unified memory node modular --- .../radiance/memory/RWSplitterNode.scala | 164 ++++++++++++++++++ .../scala/radiance/tile/RadianceCluster.scala | 157 ++--------------- 2 files changed, 178 insertions(+), 143 deletions(-) create mode 100644 src/main/scala/radiance/memory/RWSplitterNode.scala diff --git a/src/main/scala/radiance/memory/RWSplitterNode.scala b/src/main/scala/radiance/memory/RWSplitterNode.scala new file mode 100644 index 0000000..b47c641 --- /dev/null +++ b/src/main/scala/radiance/memory/RWSplitterNode.scala @@ -0,0 +1,164 @@ +package radiance.memory + +import chisel3._ +import chisel3.experimental.SourceInfo +import chisel3.util._ +import freechips.rocketchip.diplomacy._ +import freechips.rocketchip.tilelink._ +import freechips.rocketchip.util.BundleField +import org.chipsalliance.cde.config.Parameters + + +class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) extends LazyModule { + // this node accepts both read and write requests, + // splits & arbitrates them into one client node per type of operation; + // there will be N incoming edges, two outgoing edges, with two N:1 muxes; + // it keeps the read and write channels fully separate to allow parallel processing. + val node = TLNexusNode( + clientFn = { seq => + val in_mapping = TLXbar.mapInputIds(seq) + val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max) + assert((read_src_range.start == 0) && isPow2(read_src_range.end)) + val write_src_range = read_src_range.shift(read_src_range.size) + val visibilities = seq.flatMap(_.masters.flatMap(_.visibility)) + val vis_min = visibilities.map(_.base).min + val vis_max = visibilities.map{ x => x.base + x.mask }.max + val vis_mask = vis_max - vis_min + require(isPow2(vis_mask + 1) || vis_mask == -1) + println(f"combined visibilities of splitter memory node clients: ${vis_min}, ${vis_mask}") + + seq(0).v1copy( + echoFields = BundleField.union(seq.flatMap(_.echoFields)), + requestFields = BundleField.union(seq.flatMap(_.requestFields)), + responseKeys = seq.flatMap(_.responseKeys).distinct, + minLatency = seq.map(_.minLatency).min, + clients = Seq( + TLMasterParameters.v1( + name = s"${name}_read_client", + sourceId = read_src_range, + visibility = Seq(AddressSet(vis_min, vis_mask)), + supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)), + supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)), + supportsPutFull = TransferSizes.none, + supportsPutPartial = TransferSizes.none + ), + TLMasterParameters.v1( + name = s"${name}_write_client", + sourceId = write_src_range, + visibility = Seq(AddressSet(vis_min, vis_mask)), + supportsProbe = TransferSizes.mincover( + seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)), + supportsGet = TransferSizes.none, + supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)), + supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial)) + ) + ) + ) + }, + managerFn = { seq => + println(seq.flatMap(_.slaves.map(_.supports))) + // val fifoIdFactory = TLXbar.relabeler() + seq(0).v1copy( + responseFields = BundleField.union(seq.flatMap(_.responseFields)), + requestKeys = seq.flatMap(_.requestKeys).distinct, + minLatency = seq.map(_.minLatency).min, + endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max, + managers = Seq(TLSlaveParameters.v2( + name = Some(s"${name}_manager"), + address = AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address))), + supports = TLMasterToSlaveTransferSizes( + get = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsGet))), + putFull = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutFull))), + putPartial = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutPartial))) + ), + fifoId = Some(0), + )) + ) + } + ) + + lazy val module = new LazyModuleImp(this) { + val u_out = node.out + val u_in = node.in + assert(u_out.length == 2) + println(f"gemmini unified memory node has ${u_in.length} incoming client(s)") + + val r_out = u_out.head + val w_out = u_out.last + + val in_src = TLXbar.mapInputIds(u_in.map(_._2.client)) + val in_src_size = in_src.map(_.end).max + assert(isPow2(in_src_size)) // should be checked already, but just to be sure + + // arbitrate all reads into one read while assigning source prefix, same for write + val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) => + val in_r: DecoupledIO[TLBundleA] = + WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy( + sourceBits = log2Up(in_src_size) + 1 + ))))) + val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType)) + + val req_is_read = in_node.a.bits.opcode === TLMessages.Get + + (Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size, + in_r.bits.mask, in_r.bits.param, in_r.bits.data) + zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size, + in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data)) + .foreach { case (x, y) => x := y } + in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U) + in_w.bits := in_r.bits + + in_r.valid := in_node.a.valid && req_is_read + in_w.valid := in_node.a.valid && !req_is_read + in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready) + + (in_r, in_w) + } + // we cannot use round robin because it might reorder requests, even from the same client + val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip + TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*) + TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*) + + def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar + // for each unified mem node client, arbitrate read/write responses on d channel + (u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) => + // assign d channel back based on source, invalid if source prefix mismatch + val resp = Seq(r_out._1.d, w_out._1.d) + val source_match = resp.zipWithIndex.map { case (r, i) => + (r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1) + src_range.contains(trim(r.bits.source, in_src_size)) + } + val d_arbiter_in = resp.map(r => WireDefault( + 0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy( + sourceBits = in_node.d.bits.source.getWidth, + sizeBits = in_node.d.bits.size.getWidth + )))) + )) + + (d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in: DecoupledIO[TLBundleD], r, sm) => + (Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param, + arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt) + zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param, + r.bits.sink, r.bits.denied, r.bits.corrupt)) + .foreach { case (x, y) => x := y } + arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix) + arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation + + arb_in.valid := r.valid && sm + r.ready := arb_in.ready + } + + TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*) + } + } +} + +object RWSplitterNode { + def apply()(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = { + LazyModule(new RWSplitterNode(name = valName.name)).node + } + + def apply(name: String)(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = { + LazyModule(new RWSplitterNode(name = name)).node + } +} diff --git a/src/main/scala/radiance/tile/RadianceCluster.scala b/src/main/scala/radiance/tile/RadianceCluster.scala index 180253d..502ae87 100644 --- a/src/main/scala/radiance/tile/RadianceCluster.scala +++ b/src/main/scala/radiance/tile/RadianceCluster.scala @@ -13,6 +13,7 @@ import freechips.rocketchip.tilelink._ import freechips.rocketchip.util.BundleField import gemmini._ import org.chipsalliance.cde.config.Parameters +import radiance.memory.RWSplitterNode case class RadianceClusterParams( val clusterId: Int, @@ -77,78 +78,24 @@ class RadianceCluster ( val smem_banks = gemminiConfig.sp_banks val smem_subbanks = 1 + val splitter_node = RWSplitterNode() + unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_read_nodes - // unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_write_nodes - // unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes unified_mem_write_node := gemmini.spad.spad_writer.node // this is the dma write node + // unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes + // unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes - // this node accepts both read and write requests, - // splits & arbitrates them into one client node per type of operation - // it keeps the read and write channels fully separate to allow parallel processing - val unified_mem_node = TLNexusNode( - clientFn = { seq => - val in_mapping = TLXbar.mapInputIds(seq) - val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max) - assert((read_src_range.start == 0) && isPow2(read_src_range.end)) - val write_src_range = read_src_range.shift(read_src_range.size) - val visibilities = seq.flatMap(_.masters.flatMap(_.visibility)) - val vis_min = visibilities.map(_.base).min - val vis_max = visibilities.map{ x => x.base + x.mask }.max - val vis_mask = vis_max - vis_min - require(isPow2(vis_mask + 1) || vis_mask == -1) - println(f"combined visibilities of unified memory node clients: ${vis_min}, ${vis_mask}") + // assert(splitter_node.in.map(_._2.slave.slaves.flatMap(_.supports.get))) - seq(0).v1copy( - echoFields = BundleField.union(seq.flatMap(_.echoFields)), - requestFields = BundleField.union(seq.flatMap(_.requestFields)), - responseKeys = seq.flatMap(_.responseKeys).distinct, - minLatency = seq.map(_.minLatency).min, - clients = Seq( - TLMasterParameters.v1( - name = "unified_mem_read_client", - sourceId = read_src_range, - visibility = Seq(AddressSet(vis_min, vis_mask)), - supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)), - supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)), - supportsPutFull = TransferSizes.none, - supportsPutPartial = TransferSizes.none - ), - TLMasterParameters.v1( - name = "unified_mem_write_client", - sourceId = write_src_range, - visibility = Seq(AddressSet(vis_min, vis_mask)), - supportsProbe = TransferSizes.mincover( - seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)), - supportsGet = TransferSizes.none, - supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)), - supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial)) - ) - ) - ) - }, - managerFn = { seq => - // val fifoIdFactory = TLXbar.relabeler() - seq(0).v1copy( - responseFields = BundleField.union(seq.flatMap(_.responseFields)), - requestKeys = seq.flatMap(_.requestKeys).distinct, - minLatency = seq.map(_.minLatency).min, - endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max, - managers = Seq(TLSlaveParameters.v2( - name = Some(f"unified_mem_manager"), - address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)), - supports = TLMasterToSlaveTransferSizes( - get = TransferSizes(1, smem_width), - putFull = TransferSizes(1, smem_width), - putPartial = TransferSizes(1, smem_width)), - fifoId = Some(0) - )) - ) - } - ) + /* address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)), + supports = TLMasterToSlaveTransferSizes( + get = TransferSizes(1, smem_width), + putFull = TransferSizes(1, smem_width), + putPartial = TransferSizes(1, smem_width)),*/ - unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node - unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node + unified_mem_read_node := TLWidthWidget(spad_data_len) := splitter_node + unified_mem_write_node := TLWidthWidget(spad_data_len) := splitter_node val stride_by_word = false // collection of read and write managers for each sram (sub)bank @@ -206,7 +153,7 @@ class RadianceCluster ( // connect tile smem nodes to xbar, and xbar to banks // val smem_xbar = TLXbar() - unified_mem_node :=* TLWidthWidget(4) :=* clbus.outwardNode + splitter_node :=* TLWidthWidget(4) :=* clbus.outwardNode gemminiTile.slaveNode :=* TLWidthWidget(4) :=* clbus.outwardNode // printf and perf counter buffer FIXME: make configurable TLRAM(AddressSet(x"ff004000", numCores * 0x200 - 1)) := TLFragmenter(4, 4) := clbus.outwardNode @@ -345,83 +292,7 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp( } } - def connectUnifiedMemNode: Unit = { - val u_out = outer.unified_mem_node.out - val u_in = outer.unified_mem_node.in - assert(u_out.length == 2) - println(f"gemmini unified memory node has ${u_in.length} incoming client(s)") - - val r_out = u_out.head - val w_out = u_out.last - - val in_src = TLXbar.mapInputIds(u_in.map(_._2.client)) - val in_src_size = in_src.map(_.end).max - assert(isPow2(in_src_size)) // should be checked already, but just to be sure - - // arbitrate all reads into one read while assigning source prefix, same for write - val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) => - val in_r: DecoupledIO[TLBundleA] = - WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy( - sourceBits = log2Up(in_src_size) + 1 - ))))) - val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType)) - - val req_is_read = in_node.a.bits.opcode === TLMessages.Get - - (Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size, - in_r.bits.mask, in_r.bits.param, in_r.bits.data) - zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size, - in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data)) - .foreach { case (x, y) => x := y } - in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U) - in_w.bits := in_r.bits - - in_r.valid := in_node.a.valid && req_is_read - in_w.valid := in_node.a.valid && !req_is_read - in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready) - - (in_r, in_w) - } - // we cannot use round robin because it might reorder requests, even from the same client - val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip - TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*) - TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*) - - def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar - // for each unified mem node client, arbitrate read/write responses on d channel - (u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) => - // assign d channel back based on source, invalid if source prefix mismatch - val resp = Seq(r_out._1.d, w_out._1.d) - val source_match = resp.zipWithIndex.map { case (r, i) => - (r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1) - src_range.contains(trim(r.bits.source, in_src_size)) - } - val d_arbiter_in = resp.map(r => WireDefault( - 0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy( - sourceBits = in_node.d.bits.source.getWidth, - sizeBits = in_node.d.bits.size.getWidth - )))) - )) - - (d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) => - (Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param, - arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt) - zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param, - r.bits.sink, r.bits.denied, r.bits.corrupt)) - .foreach { case (x, y) => x := y } - arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix) - arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation - - arb_in.valid := r.valid && sm - r.ready := arb_in.ready - } - - TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*) - } - } - makeSmemBanks - connectUnifiedMemNode println(s"======== barrierSlaveNode: ${outer.barrierSlaveNode.in(0)._2.barrierIdBits}") }