make unified memory node modular
This commit is contained in:
164
src/main/scala/radiance/memory/RWSplitterNode.scala
Normal file
164
src/main/scala/radiance/memory/RWSplitterNode.scala
Normal file
@@ -0,0 +1,164 @@
|
||||
package radiance.memory
|
||||
|
||||
import chisel3._
|
||||
import chisel3.experimental.SourceInfo
|
||||
import chisel3.util._
|
||||
import freechips.rocketchip.diplomacy._
|
||||
import freechips.rocketchip.tilelink._
|
||||
import freechips.rocketchip.util.BundleField
|
||||
import org.chipsalliance.cde.config.Parameters
|
||||
|
||||
|
||||
/** TileLink nexus that separates read and write traffic.
  *
  * The node accepts both read (Get) and write (every other A-channel opcode)
  * requests from N incoming edges and arbitrates them onto exactly two outgoing
  * edges: the first carries reads, the second carries writes. There are two
  * N:1 muxes, and the read and write channels are kept fully separate so they
  * can be processed in parallel.
  *
  * Source-ID scheme: incoming clients are packed with `TLXbar.mapInputIds`
  * into a combined range `[0, S)`, where `S` must be a power of two. The read
  * client advertises `[0, S)` and the write client advertises `[S, 2S)`, so
  * the MSB of a downstream source ID tells reads (0) from writes (1).
  *
  * @param name prefix used for the generated client and manager parameter names
  */
class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) extends LazyModule {
  val node = TLNexusNode(
    clientFn = { seq =>
      val in_mapping = TLXbar.mapInputIds(seq)
      val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
      // a power-of-two range starting at 0 lets the write range be formed by setting one extra MSB
      assert((read_src_range.start == 0) && isPow2(read_src_range.end))
      val write_src_range = read_src_range.shift(read_src_range.size)

      // merge the visibilities of all upstream masters into one covering address set
      val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
      val vis_min = visibilities.map(_.base).min
      val vis_max = visibilities.map { x => x.base + x.mask }.max
      val vis_mask = vis_max - vis_min
      require(isPow2(vis_mask + 1) || vis_mask == -1)
      println(f"combined visibilities of splitter memory node clients: ${vis_min}, ${vis_mask}")

      seq(0).v1copy(
        echoFields = BundleField.union(seq.flatMap(_.echoFields)),
        requestFields = BundleField.union(seq.flatMap(_.requestFields)),
        responseKeys = seq.flatMap(_.responseKeys).distinct,
        minLatency = seq.map(_.minLatency).min,
        clients = Seq(
          // read-only client: lower half of the source-ID space
          TLMasterParameters.v1(
            name = s"${name}_read_client",
            sourceId = read_src_range,
            visibility = Seq(AddressSet(vis_min, vis_mask)),
            supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
            supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
            supportsPutFull = TransferSizes.none,
            supportsPutPartial = TransferSizes.none
          ),
          // write-only client: upper half of the source-ID space
          TLMasterParameters.v1(
            name = s"${name}_write_client",
            sourceId = write_src_range,
            visibility = Seq(AddressSet(vis_min, vis_mask)),
            supportsProbe = TransferSizes.mincover(
              seq.map(_.anyEmitClaims.putFull) ++ seq.map(_.anyEmitClaims.putPartial)),
            supportsGet = TransferSizes.none,
            supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
            supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
          )
        )
      )
    },
    managerFn = { seq =>
      // present the downstream managers to upstream clients as one combined manager
      seq(0).v1copy(
        responseFields = BundleField.union(seq.flatMap(_.responseFields)),
        requestKeys = seq.flatMap(_.requestKeys).distinct,
        minLatency = seq.map(_.minLatency).min,
        endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
        managers = Seq(TLSlaveParameters.v2(
          name = Some(s"${name}_manager"),
          address = AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address))),
          supports = TLMasterToSlaveTransferSizes(
            get = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsGet))),
            putFull = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutFull))),
            putPartial = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutPartial)))
          ),
          fifoId = Some(0)
        ))
      )
    }
  )

  lazy val module = new LazyModuleImp(this) {
    val u_out = node.out
    val u_in = node.in
    // exactly one read edge and one write edge are expected downstream
    assert(u_out.length == 2)
    println(f"splitter memory node has ${u_in.length} incoming client(s)")

    val r_out = u_out.head
    val w_out = u_out.last

    val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
    val in_src_size = in_src.map(_.end).max
    assert(isPow2(in_src_size)) // should be checked already, but just to be sure

    // arbitrate all reads into one read while assigning source prefix, same for write
    val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
      // widened A bundles: one extra source bit carries the read/write tag
      val in_r: DecoupledIO[TLBundleA] =
        WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
          sourceBits = log2Up(in_src_size) + 1
        )))))
      val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))

      val req_is_read = in_node.a.bits.opcode === TLMessages.Get

      (Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
        in_r.bits.mask, in_r.bits.param, in_r.bits.data)
        zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
          in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
        .foreach { case (x, y) => x := y }
      // client offset ORed in, plus the write tag (bit log2(in_src_size)) for non-Get requests
      in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
      in_w.bits := in_r.bits

      // steer the request to exactly one of the two arbiters
      in_r.valid := in_node.a.valid && req_is_read
      in_w.valid := in_node.a.valid && !req_is_read
      in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)

      (in_r, in_w)
    }
    // we cannot use round robin because it might reorder requests, even from the same client
    val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
    TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
    TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)

    def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar

    // downstream response channels, index 0 = read, index 1 = write
    val resp = Seq(r_out._1.d, w_out._1.d)

    // for each splitter client, arbitrate read/write responses on d channel;
    // yields, per client, whether it accepts the current beat of each response channel
    val d_accepts = (u_in zip in_src).map { case ((in_node, in_edge), src_range) =>
      // assign d channel back based on source, invalid if source prefix mismatch
      val source_match = resp.zipWithIndex.map { case (r, rw) =>
        (r.bits.source(r.bits.source.getWidth - 1) === rw.U(1.W)) && // MSB indicates read(0)/write(1)
          src_range.contains(trim(r.bits.source, in_src_size))
      }
      val d_arbiter_in = resp.map(r => WireDefault(
        0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
          sourceBits = in_node.d.bits.source.getWidth,
          sizeBits = in_node.d.bits.size.getWidth
        ))))
      ))

      (d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in: DecoupledIO[TLBundleD], r, sm) =>
        (Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
          arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
          zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
            r.bits.sink, r.bits.denied, r.bits.corrupt))
          .foreach { case (x, y) => x := y }
        arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
        arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation

        arb_in.valid := r.valid && sm
      }

      TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)

      (d_arbiter_in zip source_match).map { case (arb_in, sm) => sm && arb_in.ready }
    }

    // The response channels are shared by every client, so `ready` must be driven
    // exactly once: a beat is consumed only when the client whose source range it
    // matches accepts it. Assigning `ready` inside the per-client loop would let
    // the last client's arbiter override all the others (last-connect semantics).
    resp.zipWithIndex.foreach { case (r, rw) =>
      r.ready := d_accepts.map(_(rw)).foldLeft(false.B)(_ || _)
    }
  }
}
|
||||
|
||||
/** Factory helpers that build an [[RWSplitterNode]] and hand back its diplomatic node. */
object RWSplitterNode {
  /** Creates a splitter named after the implicit `valName` and returns its nexus node. */
  def apply()(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode =
    apply(valName.name)

  /** Creates a splitter with an explicit `name` prefix and returns its nexus node. */
  def apply(name: String)(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
    val splitter = LazyModule(new RWSplitterNode(name = name))
    splitter.node
  }
}
|
||||
@@ -13,6 +13,7 @@ import freechips.rocketchip.tilelink._
|
||||
import freechips.rocketchip.util.BundleField
|
||||
import gemmini._
|
||||
import org.chipsalliance.cde.config.Parameters
|
||||
import radiance.memory.RWSplitterNode
|
||||
|
||||
case class RadianceClusterParams(
|
||||
val clusterId: Int,
|
||||
@@ -77,78 +78,24 @@ class RadianceCluster (
|
||||
val smem_banks = gemminiConfig.sp_banks
|
||||
val smem_subbanks = 1
|
||||
|
||||
val splitter_node = RWSplitterNode()
|
||||
|
||||
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_read_nodes
|
||||
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
|
||||
unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* gemmini.spad_write_nodes
|
||||
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
|
||||
unified_mem_write_node := gemmini.spad.spad_writer.node // this is the dma write node
|
||||
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
|
||||
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
|
||||
|
||||
// this node accepts both read and write requests,
|
||||
// splits & arbitrates them into one client node per type of operation
|
||||
// it keeps the read and write channels fully separate to allow parallel processing
|
||||
val unified_mem_node = TLNexusNode(
|
||||
clientFn = { seq =>
|
||||
val in_mapping = TLXbar.mapInputIds(seq)
|
||||
val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
|
||||
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
|
||||
val write_src_range = read_src_range.shift(read_src_range.size)
|
||||
val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
|
||||
val vis_min = visibilities.map(_.base).min
|
||||
val vis_max = visibilities.map{ x => x.base + x.mask }.max
|
||||
val vis_mask = vis_max - vis_min
|
||||
require(isPow2(vis_mask + 1) || vis_mask == -1)
|
||||
println(f"combined visibilities of unified memory node clients: ${vis_min}, ${vis_mask}")
|
||||
// assert(splitter_node.in.map(_._2.slave.slaves.flatMap(_.supports.get)))
|
||||
|
||||
seq(0).v1copy(
|
||||
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
|
||||
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
|
||||
responseKeys = seq.flatMap(_.responseKeys).distinct,
|
||||
minLatency = seq.map(_.minLatency).min,
|
||||
clients = Seq(
|
||||
TLMasterParameters.v1(
|
||||
name = "unified_mem_read_client",
|
||||
sourceId = read_src_range,
|
||||
visibility = Seq(AddressSet(vis_min, vis_mask)),
|
||||
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
|
||||
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
|
||||
supportsPutFull = TransferSizes.none,
|
||||
supportsPutPartial = TransferSizes.none
|
||||
),
|
||||
TLMasterParameters.v1(
|
||||
name = "unified_mem_write_client",
|
||||
sourceId = write_src_range,
|
||||
visibility = Seq(AddressSet(vis_min, vis_mask)),
|
||||
supportsProbe = TransferSizes.mincover(
|
||||
seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)),
|
||||
supportsGet = TransferSizes.none,
|
||||
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
|
||||
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
|
||||
)
|
||||
)
|
||||
)
|
||||
},
|
||||
managerFn = { seq =>
|
||||
// val fifoIdFactory = TLXbar.relabeler()
|
||||
seq(0).v1copy(
|
||||
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
|
||||
requestKeys = seq.flatMap(_.requestKeys).distinct,
|
||||
minLatency = seq.map(_.minLatency).min,
|
||||
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
|
||||
managers = Seq(TLSlaveParameters.v2(
|
||||
name = Some(f"unified_mem_manager"),
|
||||
address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)),
|
||||
supports = TLMasterToSlaveTransferSizes(
|
||||
get = TransferSizes(1, smem_width),
|
||||
putFull = TransferSizes(1, smem_width),
|
||||
putPartial = TransferSizes(1, smem_width)),
|
||||
fifoId = Some(0)
|
||||
))
|
||||
)
|
||||
}
|
||||
)
|
||||
/* address = Seq(AddressSet(gemmini.spad_base, smem_depth * smem_width * smem_banks - 1)),
|
||||
supports = TLMasterToSlaveTransferSizes(
|
||||
get = TransferSizes(1, smem_width),
|
||||
putFull = TransferSizes(1, smem_width),
|
||||
putPartial = TransferSizes(1, smem_width)),*/
|
||||
|
||||
unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node
|
||||
unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node
|
||||
unified_mem_read_node := TLWidthWidget(spad_data_len) := splitter_node
|
||||
unified_mem_write_node := TLWidthWidget(spad_data_len) := splitter_node
|
||||
|
||||
val stride_by_word = false
|
||||
// collection of read and write managers for each sram (sub)bank
|
||||
@@ -206,7 +153,7 @@ class RadianceCluster (
|
||||
|
||||
// connect tile smem nodes to xbar, and xbar to banks
|
||||
// val smem_xbar = TLXbar()
|
||||
unified_mem_node :=* TLWidthWidget(4) :=* clbus.outwardNode
|
||||
splitter_node :=* TLWidthWidget(4) :=* clbus.outwardNode
|
||||
gemminiTile.slaveNode :=* TLWidthWidget(4) :=* clbus.outwardNode
|
||||
// printf and perf counter buffer FIXME: make configurable
|
||||
TLRAM(AddressSet(x"ff004000", numCores * 0x200 - 1)) := TLFragmenter(4, 4) := clbus.outwardNode
|
||||
@@ -345,83 +292,7 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
|
||||
}
|
||||
}
|
||||
|
||||
def connectUnifiedMemNode: Unit = {
|
||||
val u_out = outer.unified_mem_node.out
|
||||
val u_in = outer.unified_mem_node.in
|
||||
assert(u_out.length == 2)
|
||||
println(f"gemmini unified memory node has ${u_in.length} incoming client(s)")
|
||||
|
||||
val r_out = u_out.head
|
||||
val w_out = u_out.last
|
||||
|
||||
val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
|
||||
val in_src_size = in_src.map(_.end).max
|
||||
assert(isPow2(in_src_size)) // should be checked already, but just to be sure
|
||||
|
||||
// arbitrate all reads into one read while assigning source prefix, same for write
|
||||
val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
|
||||
val in_r: DecoupledIO[TLBundleA] =
|
||||
WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
|
||||
sourceBits = log2Up(in_src_size) + 1
|
||||
)))))
|
||||
val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))
|
||||
|
||||
val req_is_read = in_node.a.bits.opcode === TLMessages.Get
|
||||
|
||||
(Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
|
||||
in_r.bits.mask, in_r.bits.param, in_r.bits.data)
|
||||
zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
|
||||
in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
|
||||
.foreach { case (x, y) => x := y }
|
||||
in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
|
||||
in_w.bits := in_r.bits
|
||||
|
||||
in_r.valid := in_node.a.valid && req_is_read
|
||||
in_w.valid := in_node.a.valid && !req_is_read
|
||||
in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)
|
||||
|
||||
(in_r, in_w)
|
||||
}
|
||||
// we cannot use round robin because it might reorder requests, even from the same client
|
||||
val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
|
||||
TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
|
||||
TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)
|
||||
|
||||
def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar
|
||||
// for each unified mem node client, arbitrate read/write responses on d channel
|
||||
(u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) =>
|
||||
// assign d channel back based on source, invalid if source prefix mismatch
|
||||
val resp = Seq(r_out._1.d, w_out._1.d)
|
||||
val source_match = resp.zipWithIndex.map { case (r, i) =>
|
||||
(r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1)
|
||||
src_range.contains(trim(r.bits.source, in_src_size))
|
||||
}
|
||||
val d_arbiter_in = resp.map(r => WireDefault(
|
||||
0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
|
||||
sourceBits = in_node.d.bits.source.getWidth,
|
||||
sizeBits = in_node.d.bits.size.getWidth
|
||||
))))
|
||||
))
|
||||
|
||||
(d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) =>
|
||||
(Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
|
||||
arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
|
||||
zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
|
||||
r.bits.sink, r.bits.denied, r.bits.corrupt))
|
||||
.foreach { case (x, y) => x := y }
|
||||
arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
|
||||
arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation
|
||||
|
||||
arb_in.valid := r.valid && sm
|
||||
r.ready := arb_in.ready
|
||||
}
|
||||
|
||||
TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)
|
||||
}
|
||||
}
|
||||
|
||||
makeSmemBanks
|
||||
connectUnifiedMemNode
|
||||
|
||||
println(s"======== barrierSlaveNode: ${outer.barrierSlaveNode.in(0)._2.barrierIdBits}")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user