actually support large smem subbanks

This commit is contained in:
Richard Yan
2024-09-10 23:24:02 -07:00
parent 13142ab0b9
commit f1a1b77828
2 changed files with 103 additions and 43 deletions

View File

@@ -20,20 +20,23 @@ class AlignFilterNode(filters: Seq[AddressSet])(implicit p: Parameters) extends
val master = seq.head.masters.head val master = seq.head.masters.head
// TODO: to implement multiple filters, source Id mapping needs to be redone // TODO: to implement multiple filters, source Id mapping needs to be redone
assert(filters.length == 1, "multiple filters currently not supported") // assert(filters.length == 1, "multiple filters currently not supported")
val in_mapping = TLXbar.mapInputIds(Seq.fill(filters.length + 1)(seq.head))
val unaligned_src_range = in_mapping.last
seq.head.v1copy( seq.head.v1copy(
clients = filters.map { filter => clients = filters.zipWithIndex.map { case (filter, i) =>
master.v2copy( master.v2copy(
name = s"${name}_filter_aligned", name = s"${name}_filter_aligned",
sourceId = master.sourceId, sourceId = in_mapping(i),
visibility = Seq(filter), visibility = Seq(filter),
emits = seq.map(_.anyEmitClaims).reduce(_ mincover _) emits = seq.map(_.anyEmitClaims).reduce(_ mincover _)
) )
} ++ Seq( } ++ Seq(
master.v2copy( master.v2copy(
name = s"${name}_filter_unaligned", name = s"${name}_filter_unaligned",
sourceId = master.sourceId.shift(master.sourceId.size), sourceId = unaligned_src_range,
visibility = Seq(AddressSet.everything), visibility = Seq(AddressSet.everything),
emits = seq.map(_.anyEmitClaims).reduce(_ mincover _) emits = seq.map(_.anyEmitClaims).reduce(_ mincover _)
), ),
@@ -81,14 +84,18 @@ class AlignFilterNode(filters: Seq[AddressSet])(implicit p: Parameters) extends
val a = node.out.init.map(_._1) val a = node.out.init.map(_._1)
val ua = node.out.last._1 val ua = node.out.last._1
val in_mapping = TLXbar.mapInputIds(Seq.fill(filters.length + 1)(node.in.head._2.client))
val unaligned_src = in_mapping.last
val a_aligned = filters.map(_.contains(c.a.bits.address)) val a_aligned = filters.map(_.contains(c.a.bits.address))
(a zip a_aligned).foreach { case (a, aligned) => (a zip a_aligned).zipWithIndex.foreach { case ((a, aligned), idx) =>
a.a.bits := c.a.bits a.a.bits := c.a.bits
a.a.bits.source := in_mapping(idx).start.U + c.a.bits.source
a.a.valid := c.a.valid && aligned a.a.valid := c.a.valid && aligned
} }
ua.a.bits := c.a.bits ua.a.bits := c.a.bits
ua.a.bits.source := c.a.bits.source + (1.U << c.a.bits.source.getWidth) ua.a.bits.source := unaligned_src.start.U + c.a.bits.source // + (1.U << c.a.bits.source.getWidth)
ua.a.valid := c.a.valid && !a_aligned.reduce(_ || _) ua.a.valid := c.a.valid && !a_aligned.reduce(_ || _)
c.a.ready := MuxCase(ua.a.ready, (a zip a_aligned).map { case (a, aligned) => aligned -> a.a.ready }) c.a.ready := MuxCase(ua.a.ready, (a zip a_aligned).map { case (a, aligned) => aligned -> a.a.ready })

View File

@@ -223,52 +223,105 @@ class RadianceCluster (
val num_lsu_lanes = radianceTiles.head.numLsuLanes val num_lsu_lanes = radianceTiles.head.numLsuLanes
val num_lane_dupes = Math.max(1, smem_subbanks / num_lsu_lanes) val num_lane_dupes = Math.max(1, smem_subbanks / num_lsu_lanes)
val filter_range = smem_subbanks / num_lane_dupes val filter_range = Math.min(smem_subbanks, num_lsu_lanes)
println(s"num_lsu_lanes ${num_lsu_lanes} num_lane_dupes ${num_lane_dupes} filter_range ${filter_range}")
// (subbank, source, rw) // (subbank, sources, aligned) = rw node
val filter_nodes: Seq[Seq[(TLNode, TLNode)]] = Seq.tabulate(num_lane_dupes) { did => val (f_aligned, f_unaligned) = if (num_lsu_lanes >= smem_subbanks) {
Seq.tabulate(filter_range) { wid => val filter_nodes: Seq[Seq[(TLNode, TLNode)]] = Seq.tabulate(num_lane_dupes) { did =>
val true_wid = did * filter_range + wid Seq.tabulate(filter_range) { wid =>
val address = AddressSet(smem_base + wordSize * true_wid, (smem_size - 1) - (smem_subbanks - 1) * wordSize) val true_wid = did * filter_range + wid
val address = AddressSet(smem_base + wordSize * true_wid, (smem_size - 1) - (smem_subbanks - 1) * wordSize)
radiance_smem_fanout.grouped(num_lsu_lanes).toList.zipWithIndex.flatMap { case (lanes, cid) => radiance_smem_fanout.grouped(num_lsu_lanes).toList.zipWithIndex.flatMap { case (lanes, cid) =>
lanes.zipWithIndex.flatMap { case (lane, lid) => lanes.zipWithIndex.flatMap { case (lane, lid) =>
if ((lid % filter_range) == wid) { if ((lid % filter_range) == wid) {
println(f"c${cid}_l${lid} connected to d${did}w${wid}") println(f"c${cid}_l${lid} connected to d${did}w${wid}")
val filter_node = AlignFilterNode(Seq(address))(p, ValName(s"filter_l${lid}_w${true_wid}"), info) val filter_node = AlignFilterNode(Seq(address))(p, ValName(s"filter_l${lid}_w${true_wid}"), info)
DisableMonitors { implicit p => filter_node := lane } DisableMonitors { implicit p => filter_node := lane }
// Seq((aligned splitter, unaligned splitter)) // Seq((aligned splitter, unaligned splitter))
Seq(( Seq((
connect_one(filter_node, () => connect_one(filter_node, () =>
RWSplitterNode(address, s"aligned_splitter_c${cid}_l${lid}_w${true_wid}")), RWSplitterNode(address, s"aligned_splitter_c${cid}_l${lid}_w${true_wid}")),
connect_one(filter_node, () => connect_one(filter_node, () =>
RWSplitterNode(AddressSet.everything, s"unaligned_splitter_c${cid}_l${lid}_w${true_wid}")) RWSplitterNode(AddressSet.everything, s"unaligned_splitter_c${cid}_l${lid}"))
)) ))
} else Seq() } else Seq()
}
} }
} }
} }.flatten
}.flatten
val f_aligned = Seq.fill(2)(filter_nodes.map(_.map(_._1).map(connect_xbar_name(_, Some("rad_aligned")))))
val f_unaligned = if (serialize_unaligned) { val f_aligned = Seq.fill(2)(filter_nodes.map(_.map(_._1).map(connect_xbar_name(_, Some("rad_aligned")))))
Seq.fill(2) { val f_unaligned = if (serialize_unaligned) {
val serialized_node = TLEphemeralNode() Seq.fill(2) {
val serialized_in_xbar = LazyModule(new TLXbar()) val serialized_node = TLEphemeralNode()
val serialized_out_xbar = LazyModule(new TLXbar()) val serialized_in_xbar = LazyModule(new TLXbar())
serialized_in_xbar.suggestName("unaligned_serialized_in_xbar") val serialized_out_xbar = LazyModule(new TLXbar())
serialized_out_xbar.suggestName("unaligned_serialized_out_xbar") serialized_in_xbar.suggestName("unaligned_serialized_in_xbar")
guard_monitors { implicit p => serialized_out_xbar.suggestName("unaligned_serialized_out_xbar")
filter_nodes.foreach(_.map(_._2).foreach(serialized_in_xbar.node := _)) guard_monitors { implicit p =>
serialized_node := serialized_in_xbar.node filter_nodes.foreach(_.map(_._2).foreach(serialized_in_xbar.node := _))
serialized_out_xbar.node := serialized_node serialized_node := serialized_in_xbar.node
serialized_out_xbar.node := serialized_node
}
Seq(serialized_out_xbar.node)
} }
Seq(serialized_out_xbar.node) } else {
Seq.fill(2)(filter_nodes.flatMap(_.map(_._2).map(connect_xbar)))
} }
} else { (f_aligned, f_unaligned)
Seq.fill(2)(filter_nodes.flatMap(_.map(_._2).map(connect_xbar))) } else { // aligned: (subbanks, cores) = rw node
// (lanes, cores) = filter_node
val filter_nodes = Seq.tabulate(filter_range) { wid =>
val addresses = Seq.tabulate(num_lane_dupes) { did =>
AddressSet(smem_base + (did * filter_range + wid) * wordSize,
(smem_size - 1) - (smem_subbanks - 1) * wordSize)
}
radiance_smem_fanout.grouped(num_lsu_lanes).toSeq.zipWithIndex.map { case (lanes, cid) =>
val lane = lanes(wid)
val filter_node = AlignFilterNode(addresses)(p, ValName(s"filter_c${cid}_w${wid}"), info)
guard_monitors { implicit p =>
filter_node := lane
}
filter_node
}
}
val f_aligned_rw = Seq.tabulate(num_lane_dupes) { did =>
filter_nodes.zipWithIndex.map { case (cores, lid) =>
cores.zipWithIndex.map { case (fn, cid) =>
val address = AddressSet(smem_base + (did * filter_range + lid) * wordSize,
(smem_size - 1) - (smem_subbanks - 1) * wordSize)
connect_one(fn, () => RWSplitterNode(address, s"aligned_split_c${cid}_l${lid}_d${did}"))
}
}
}.flatten
val f_unaligned_rw = filter_nodes.zipWithIndex.flatMap { case (cores, lid) =>
cores.zipWithIndex.map { case (fn, cid) =>
connect_one(fn, () => RWSplitterNode(AddressSet.everything, s"unaligned_split_c${cid}_l${lid}"))
}
}
val f_aligned = Seq.fill(2)(f_aligned_rw.map(_.map(connect_xbar_name(_, Some("rad_aligned")))))
val f_unaligned = if (serialize_unaligned) {
Seq.fill(2) {
val serialized_node = TLEphemeralNode()
val serialized_in_xbar = TLXbar(nameSuffix = Some("unaligned_ser_in"))
val serialized_out_xbar = TLXbar(nameSuffix = Some("unaligned_ser_out"))
guard_monitors { implicit p =>
f_unaligned_rw.foreach(serialized_in_xbar := _)
serialized_node := serialized_in_xbar
serialized_out_xbar := serialized_node
}
Seq(serialized_out_xbar)
}
} else {
Seq.fill(2)(f_unaligned_rw.map(connect_xbar))
}
(f_aligned, f_unaligned)
} }
val uniform_r_nodes: Seq[Seq[Seq[TLNexusNode]]] = spad_read_nodes.map { rb => val uniform_r_nodes: Seq[Seq[Seq[TLNexusNode]]] = spad_read_nodes.map { rb =>
(rb zip f_aligned.head).map { case (rw, fa) => rw ++ fa } (rb zip f_aligned.head).map { case (rw, fa) => rw ++ fa }
} }