make unified memory node modular

This commit is contained in:
Richard Yan
2024-03-26 23:13:30 -07:00
parent cb0a4c526e
commit 9fb861a873
2 changed files with 178 additions and 143 deletions

View File

@@ -0,0 +1,164 @@
package radiance.memory
import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.util._
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tilelink._
import freechips.rocketchip.util.BundleField
import org.chipsalliance.cde.config.Parameters
/** Splits a unified TileLink request stream into one read client and one write client.
  *
  * The nexus node accepts both read and write requests on N incoming edges and must be
  * connected to exactly two outgoing edges: the first outgoing edge receives `Get`
  * requests, the second receives every other A-channel opcode. Each direction has its own
  * N:1 arbiter, so the read and write channels stay fully separate and can be processed in
  * parallel downstream.
  *
  * Source ID layout on the outgoing edges: incoming source IDs are packed with
  * `TLXbar.mapInputIds` into `[0, in_src_size)`; write requests additionally get
  * `in_src_size` OR-ed in, so the bit above the read range tells reads (0) from writes (1).
  * D-channel responses are routed back to the originating client using that bit plus the
  * client's source range.
  *
  * @param name prefix used for the names of the generated client and manager parameters
  */
class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) extends LazyModule {
  val node = TLNexusNode(
    clientFn = { seq =>
      val in_mapping = TLXbar.mapInputIds(seq)
      val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
      // the write range is the read range shifted by its own size, which only yields a clean
      // "one extra prefix bit" encoding when the read range is [0, 2^k)
      require((read_src_range.start == 0) && isPow2(read_src_range.end),
        s"combined input source range must be [0, 2^k), got [${read_src_range.start}, ${read_src_range.end})")
      val write_src_range = read_src_range.shift(read_src_range.size)

      // collapse all client visibilities into a single base/mask pair
      val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
      val vis_min = visibilities.map(_.base).min
      val vis_max = visibilities.map{ x => x.base + x.mask }.max
      val vis_mask = vis_max - vis_min
      require(isPow2(vis_mask + 1) || vis_mask == -1)
      println(f"combined visibilities of splitter memory node clients: ${vis_min}, ${vis_mask}")

      // NOTE(review): the transfer sizes below are passed through the v1 `supports*` fields,
      // including `supportsProbe`; this mirrors the original code and is left untouched.
      seq(0).v1copy(
        echoFields = BundleField.union(seq.flatMap(_.echoFields)),
        requestFields = BundleField.union(seq.flatMap(_.requestFields)),
        responseKeys = seq.flatMap(_.responseKeys).distinct,
        minLatency = seq.map(_.minLatency).min,
        clients = Seq(
          TLMasterParameters.v1(
            name = s"${name}_read_client",
            sourceId = read_src_range,
            visibility = Seq(AddressSet(vis_min, vis_mask)),
            supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
            supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
            supportsPutFull = TransferSizes.none,
            supportsPutPartial = TransferSizes.none
          ),
          TLMasterParameters.v1(
            name = s"${name}_write_client",
            sourceId = write_src_range,
            visibility = Seq(AddressSet(vis_min, vis_mask)),
            supportsProbe = TransferSizes.mincover(
              seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)),
            supportsGet = TransferSizes.none,
            supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
            supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
          )
        )
      )
    },
    managerFn = { seq =>
      println(seq.flatMap(_.slaves.map(_.supports)))
      // val fifoIdFactory = TLXbar.relabeler()
      // present the downstream managers as a single manager covering the union of their
      // addresses and the mincover of their get/put transfer sizes
      seq(0).v1copy(
        responseFields = BundleField.union(seq.flatMap(_.responseFields)),
        requestKeys = seq.flatMap(_.requestKeys).distinct,
        minLatency = seq.map(_.minLatency).min,
        endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
        managers = Seq(TLSlaveParameters.v2(
          name = Some(s"${name}_manager"),
          address = AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address))),
          supports = TLMasterToSlaveTransferSizes(
            get = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsGet))),
            putFull = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutFull))),
            putPartial = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutPartial)))
          ),
          fifoId = Some(0),
        ))
      )
    }
  )

  lazy val module = new LazyModuleImp(this) {
    val u_out = node.out
    val u_in = node.in
    require(u_out.length == 2,
      s"splitter node needs exactly two outgoing edges (read, write), got ${u_out.length}")
    println(f"splitter memory node has ${u_in.length} incoming client(s)")

    val r_out = u_out.head // first outgoing edge: reads
    val w_out = u_out.last // second outgoing edge: writes

    val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
    val in_src_size = in_src.map(_.end).max
    // should be checked already in clientFn, but just to be sure
    require(isPow2(in_src_size), s"combined input source id space must be a power of two, got ${in_src_size}")

    // arbitrate all reads into one read while assigning source prefix, same for write
    val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
      // one extra source bit on top of the packed input range holds the read/write prefix
      val in_r: DecoupledIO[TLBundleA] =
        WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
          sourceBits = log2Up(in_src_size) + 1
        )))))
      val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))

      val req_is_read = in_node.a.bits.opcode === TLMessages.Get

      // copy the listed A-channel fields one by one; fields not listed here keep the zero default
      (Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
        in_r.bits.mask, in_r.bits.param, in_r.bits.data)
        zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
          in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
        .foreach { case (x, y) => x := y }
      in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
      in_w.bits := in_r.bits

      // exactly one of the two copies is valid, selected by the opcode
      in_r.valid := in_node.a.valid && req_is_read
      in_w.valid := in_node.a.valid && !req_is_read
      in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)

      (in_r, in_w)
    }

    // we cannot use round robin because it might reorder requests, even from the same client
    val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
    TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
    TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)

    def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar

    // downstream response channels, index 0 = read, index 1 = write
    val resp = Seq(r_out._1.d, w_out._1.d)

    // for each client, arbitrate read/write responses on d channel; the result holds, per
    // client, one "this client takes the beat" signal for each entry of `resp`
    val d_accepts: Seq[Seq[Bool]] = (u_in zip in_src).map { case ((in_node, in_edge), src_range) =>
      // assign d channel back based on source, invalid if source prefix mismatch
      val source_match = resp.zipWithIndex.map { case (r, rw_idx) =>
        (r.bits.source(r.bits.source.getWidth - 1) === rw_idx.U(1.W)) && // MSB indicates read(0)/write(1)
          src_range.contains(trim(r.bits.source, in_src_size))
      }
      val d_arbiter_in = resp.map(r => WireDefault(
        0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
          sourceBits = in_node.d.bits.source.getWidth,
          sizeBits = in_node.d.bits.size.getWidth
        ))))
      ))
      (d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in: DecoupledIO[TLBundleD], r, sm) =>
        (Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
          arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
          zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
            r.bits.sink, r.bits.denied, r.bits.corrupt))
          .foreach { case (x, y) => x := y }
        arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
        arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation
        arb_in.valid := r.valid && sm
      }
      TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)

      // a response beat is consumed by this client only if it is addressed to it AND the
      // client-side arbiter accepts it
      (d_arbiter_in zip source_match).map { case (arb_in, sm) => sm && arb_in.ready }
    }

    // Drive each downstream d.ready exactly once, from whichever client the response targets.
    // Assigning it inside the per-client loop would let the last client override all others
    // (last-connect semantics), acknowledging responses that were never delivered.
    resp.zipWithIndex.foreach { case (r, rw_idx) =>
      r.ready := d_accepts.map(_(rw_idx)).foldLeft(false.B)(_ || _)
    }
  }
}
/** Factory shortcuts that instantiate an [[RWSplitterNode]] as a `LazyModule` and return
  * its diplomatic node, so the splitter can be dropped straight into a `:=` chain.
  */
object RWSplitterNode {
  /** Creates a splitter whose name is taken from the implicit `ValName` in scope. */
  def apply()(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode =
    apply(valName.name)

  /** Creates a splitter with an explicitly supplied name.
    *
    * @param name prefix handed to the [[RWSplitterNode]] constructor
    */
  def apply(name: String)(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode =
    LazyModule(new RWSplitterNode(name = name)).node
}

View File

@@ -13,6 +13,7 @@ import freechips.rocketchip.tilelink._
import freechips.rocketchip.util.BundleField
import gemmini._
import org.chipsalliance.cde.config.Parameters
import radiance.memory.RWSplitterNode
case class RadianceClusterParams(
val clusterId: Int,
@@ -77,78 +78,24 @@ class RadianceCluster (
val smem_banks = gemminiConfig.sp_banks
val smem_subbanks = 1
val splitter_node = RWSplitterNode()
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_read_nodes
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_write_nodes
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
unified_mem_write_node := gemmini.spad.spad_writer.node // this is the dma write node
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
// this node accepts both read and write requests,
// splits & arbitrates them into one client node per type of operation
// it keeps the read and write channels fully separate to allow parallel processing
val unified_mem_node = TLNexusNode(
clientFn = { seq =>
val in_mapping = TLXbar.mapInputIds(seq)
val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
val write_src_range = read_src_range.shift(read_src_range.size)
val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
val vis_min = visibilities.map(_.base).min
val vis_max = visibilities.map{ x => x.base + x.mask }.max
val vis_mask = vis_max - vis_min
require(isPow2(vis_mask + 1) || vis_mask == -1)
println(f"combined visibilities of unified memory node clients: ${vis_min}, ${vis_mask}")
// assert(splitter_node.in.map(_._2.slave.slaves.flatMap(_.supports.get)))
seq(0).v1copy(
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
responseKeys = seq.flatMap(_.responseKeys).distinct,
minLatency = seq.map(_.minLatency).min,
clients = Seq(
TLMasterParameters.v1(
name = "unified_mem_read_client",
sourceId = read_src_range,
visibility = Seq(AddressSet(vis_min, vis_mask)),
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsPutFull = TransferSizes.none,
supportsPutPartial = TransferSizes.none
),
TLMasterParameters.v1(
name = "unified_mem_write_client",
sourceId = write_src_range,
visibility = Seq(AddressSet(vis_min, vis_mask)),
supportsProbe = TransferSizes.mincover(
seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)),
supportsGet = TransferSizes.none,
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
)
)
)
},
managerFn = { seq =>
// val fifoIdFactory = TLXbar.relabeler()
seq(0).v1copy(
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
requestKeys = seq.flatMap(_.requestKeys).distinct,
minLatency = seq.map(_.minLatency).min,
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
managers = Seq(TLSlaveParameters.v2(
name = Some(f"unified_mem_manager"),
address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, smem_width),
putFull = TransferSizes(1, smem_width),
putPartial = TransferSizes(1, smem_width)),
fifoId = Some(0)
))
)
}
)
/* address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, smem_width),
putFull = TransferSizes(1, smem_width),
putPartial = TransferSizes(1, smem_width)),*/
unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node
unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node
unified_mem_read_node := TLWidthWidget(spad_data_len) := splitter_node
unified_mem_write_node := TLWidthWidget(spad_data_len) := splitter_node
val stride_by_word = false
// collection of read and write managers for each sram (sub)bank
@@ -206,7 +153,7 @@ class RadianceCluster (
// connect tile smem nodes to xbar, and xbar to banks
// val smem_xbar = TLXbar()
unified_mem_node :=* TLWidthWidget(4) :=* clbus.outwardNode
splitter_node :=* TLWidthWidget(4) :=* clbus.outwardNode
gemminiTile.slaveNode :=* TLWidthWidget(4) :=* clbus.outwardNode
// printf and perf counter buffer FIXME: make configurable
TLRAM(AddressSet(x"ff004000", numCores * 0x200 - 1)) := TLFragmenter(4, 4) := clbus.outwardNode
@@ -345,83 +292,7 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
}
}
def connectUnifiedMemNode: Unit = {
val u_out = outer.unified_mem_node.out
val u_in = outer.unified_mem_node.in
assert(u_out.length == 2)
println(f"gemmini unified memory node has ${u_in.length} incoming client(s)")
val r_out = u_out.head
val w_out = u_out.last
val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
val in_src_size = in_src.map(_.end).max
assert(isPow2(in_src_size)) // should be checked already, but just to be sure
// arbitrate all reads into one read while assigning source prefix, same for write
val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
val in_r: DecoupledIO[TLBundleA] =
WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
sourceBits = log2Up(in_src_size) + 1
)))))
val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))
val req_is_read = in_node.a.bits.opcode === TLMessages.Get
(Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
in_r.bits.mask, in_r.bits.param, in_r.bits.data)
zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
.foreach { case (x, y) => x := y }
in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
in_w.bits := in_r.bits
in_r.valid := in_node.a.valid && req_is_read
in_w.valid := in_node.a.valid && !req_is_read
in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)
(in_r, in_w)
}
// we cannot use round robin because it might reorder requests, even from the same client
val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)
def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar
// for each unified mem node client, arbitrate read/write responses on d channel
(u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) =>
// assign d channel back based on source, invalid if source prefix mismatch
val resp = Seq(r_out._1.d, w_out._1.d)
val source_match = resp.zipWithIndex.map { case (r, i) =>
(r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1)
src_range.contains(trim(r.bits.source, in_src_size))
}
val d_arbiter_in = resp.map(r => WireDefault(
0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
sourceBits = in_node.d.bits.source.getWidth,
sizeBits = in_node.d.bits.size.getWidth
))))
))
(d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) =>
(Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
r.bits.sink, r.bits.denied, r.bits.corrupt))
.foreach { case (x, y) => x := y }
arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation
arb_in.valid := r.valid && sm
r.ready := arb_in.ready
}
TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)
}
}
makeSmemBanks
connectUnifiedMemNode
println(s"======== barrierSlaveNode: ${outer.barrierSlaveNode.in(0)._2.barrierIdBits}")
}