fix ext policy xbar, add rectangular tile support

This commit is contained in:
Richard Yan
2024-09-08 13:21:31 -07:00
parent 378b3531d4
commit afc6ba7eca
6 changed files with 66 additions and 30 deletions

View File

@@ -22,7 +22,11 @@ class XbarWithExtPolicy(nameSuffix: Option[String] = None)
val policySlaveNode = ExtPolicySlaveNode()
class ImplChild extends Impl {
val policy: TLArbiter.Policy = (_, _, _) => policySlaveNode.in.head._1
println(s"policy slave node input width ${policySlaveNode.in.head._1.getWidth}")
val policy: TLArbiter.Policy = (width, _, _) => {
println(s"evaluated policy width: ${width}")
policySlaveNode.in.head._1
}
// val wide_bundle = TLBundleParameters.union((node.in ++ node.out).map(_._2.bundle))
// override def desiredName = (Seq("TLXbar") ++ nameSuffix ++ Seq(s"i${node.in.size}_o${node.out.size}_${wide_bundle.shortName}")).mkString("_")
TLXbar.circuit(policy, node.in, node.out)

View File

@@ -93,7 +93,7 @@ object RadianceGemminiDataType extends Enumeration {
}
class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams,
dim: Int, accSizeInKB: Int, tileSize: Int,
dim: Int, accSizeInKB: Int, tileSize: Either[(Int, Int, Int), Int],
dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => {
case TilesLocated(`location`) => {
val prev = up(TilesLocated(`location`))
@@ -120,7 +120,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
)),
mvin_scale_args = Some(ScaleArguments(
(t: Float, u: Float) => t * u,
1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))"
1, Float(5, 11), -1, identity = "0x3c00", c_str="((x) * (scale))"
)),
mvin_scale_acc_args = None,
has_training_convs = false,
@@ -164,7 +164,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
}
case NumTiles => up(NumTiles) + 1
}) {
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int,
def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: Either[(Int, Int, Int), Int],
dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) =
this(location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -174,6 +174,13 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
case InCluster(clusterId) => CCBUS(clusterId)
}
), dim, accSizeInKB, tileSize, dataType, dmaBytes)
def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: Int) =
this(location, dim, accSizeInKB, Right(tileSize))
def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: (Int, Int, Int),
dataType: RadianceGemminiDataType.Type) =
this(location, dim, accSizeInKB, Left(tileSize), dataType)
}
class WithRadianceSharedMem(address: BigInt,

View File

@@ -63,7 +63,7 @@ case class GemminiCoreParams(
case class GemminiTileParams(
tileId: Int = 0,
gemminiConfig: GemminiArrayConfig[Float, Float, Float],
tileSize: Int = 4,
tileSize: Either[(Int, Int, Int), Int] = Right(4),
slaveAddress: BigInt
) extends InstantiableTileParams[GemminiTile] {
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByHartIdImpl)(
@@ -188,15 +188,27 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
ciscInst := 0.U.asTypeOf(ciscInstT)
val tileSize = outer.gemminiParams.tileSize
val (boundsInst, spadQuartile) = (ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSize | (tileSize << 16) | (BigInt(tileSize) << 32)).U),
tileSize * tileSize * outer.gemminiParams.gemminiConfig.DIM)
println(s"gemmini cisc initialized with DIM=${outer.gemminiParams.gemminiConfig.DIM}, tileSize=${tileSize}")
println(f"boundsInst=${boundsInst.litValue}%x, tileSize=${tileSize}, quartile=${spadQuartile}")
val (tileSizeM, tileSizeN, tileSizeK) = outer.gemminiParams.tileSize match {
case Left(v: (Int, Int, Int)) => v
case Right(v: Int) => (v, v, v)
}
val config = outer.gemminiParams.gemminiConfig
val spadQuartile = config.sp_bank_entries * config.sp_banks / 4
// TODO: as a temporary hack, bit 7 of the cisc opcode
// TODO: will force the tile size to be a square base on M.
val rectBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSizeM | (tileSizeN << 16) | (BigInt(tileSizeK) << 32)).U)
val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
println(s"gemmini cisc initialized with DIM=${config.DIM}, tileSize=${tileSizeM},${tileSizeN},${tileSizeK}")
println(f"boundsInst=${rectBoundsInst.litValue}%x, quartile=${spadQuartile}")
when (ciscValid) {
assert(!accSlave.cmd.valid, "cisc state machine already busy")
switch (ciscId) {
switch (ciscId(6, 0)) {
is (0.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U), // set A, B address

View File

@@ -301,8 +301,8 @@ class RadianceCluster (
r := subbank_r_xbar.node
w := subbank_w_xbar.node
val ur_xbar = XbarWithExtPolicy(Some("ur"))
val uw_xbar = XbarWithExtPolicy(Some("uw"))
val ur_xbar = XbarWithExtPolicy(Some(s"ur_b${bid}_w${wid}"))
val uw_xbar = XbarWithExtPolicy(Some(s"uw_b${bid}_w${wid}"))
val r_policy_node = ExtPolicyMasterNode(uniform_r_nodes(bid)(wid).length)
val w_policy_node = ExtPolicyMasterNode(uniform_w_nodes(bid)(wid).length)
ur_xbar.policySlaveNode := r_policy_node
@@ -543,9 +543,7 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
dontTouch(smemWriteCounter)
if (outer.stride_by_word) {
val (uniform_r_nodes, uniform_w_nodes) = (outer.uniform_r_nodes.get, outer.uniform_w_nodes.get)
val uniform_fires = Seq.fill(2)(VecInit.fill(outer.smem_banks)(VecInit.fill(outer.smem_subbanks)(false.B)))
val word_selects_1h = Seq.fill(2)(VecInit.fill(outer.smem_banks)(0.U))
outer.smem_bank_mgrs.grouped(outer.smem_subbanks).zipWithIndex.foreach { case (bank_mgrs, bid) =>
// TODO move this loop out
@@ -554,10 +552,13 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
// VecInit(words_with_same_idx.toSeq).asUInt.orR
// }.toSeq).asUInt
// }
val Seq(valid_r_sources, valid_w_sources) = outer.uniform_nodes_in.map { banks =>
banks(bid).map(_.map(_.in.head._1.a.valid)).transpose.map { words_in_idx =>
val word_selects_1h = Seq(
Wire(UInt(outer.uniform_nodes_in.head(bid).head.length.W)).suggestName(s"ws_r_b${bid}"),
Wire(UInt(outer.uniform_nodes_in.last(bid).head.length.W)).suggestName(s"ws_w_b${bid}"))
val Seq(valid_r_sources, valid_w_sources) = outer.uniform_nodes_in.zipWithIndex.map { case (banks, rw) =>
VecInit(banks(bid).map(_.map(_.in.head._1.a.valid)).transpose.map { words_in_idx =>
VecInit(words_in_idx.toSeq).asUInt.orR
}
}.toSeq).asUInt.suggestName(s"valid_sources_rw${rw}_b${bid}")
}
assert(bank_mgrs.flatten.size == 2/* read and write */ * outer.smem_subbanks)
@@ -603,18 +604,29 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire
}
}
println(f"valid r_sources ${valid_r_sources.length}, ${valid_r_sources}")
// use round robin to decide uniform select
(word_selects_1h zip Seq(valid_r_sources, valid_w_sources)).zipWithIndex.foreach { case ((ws, vs), rw) =>
ws(bid) := TLArbiter.roundRobin(vs.length, VecInit(vs.toSeq).asUInt, uniform_fires(rw)(bid).asUInt.orR)
ws := TLArbiter.roundRobin(vs.getWidth, vs, uniform_fires(rw)(bid).asUInt.orR)
}
// mask valid into xbar to prevent triggering assertion
(word_selects_1h zip outer.uniform_nodes_in).foreach { case (ws, ui) =>
ui(bid).foreach { sources =>
val in_valid = sources.map(_.in.head._1.a.valid)
val out_valid = sources.map(_.out.head._1.a.valid)
(in_valid lazyZip out_valid lazyZip ws.asBools).foreach { case (iv, ov, sel) =>
ov := iv && sel // only present output valid if input is selected
}
}
}
(outer.uniform_policy_nodes zip word_selects_1h).zipWithIndex.foreach { case ((nodes_bw, ws), rw) =>
nodes_bw(bid).foreach { policy =>
println(s"policy out ${policy.out.head._1.getWidth}, word select ${ws.getWidth}")
policy.out.head._1 := ws
}
}
}
(outer.uniform_policy_nodes zip word_selects_1h).zipWithIndex.foreach { case ((nodes_bw, ws_b), rw) =>
(nodes_bw zip ws_b).zipWithIndex.foreach { case ((nodes_w, ws), bid) =>
nodes_w.foreach { _.out.head._1 := ws }
}
}
} else {
outer.smem_bank_mgrs.foreach { case Seq(r, w) =>
val mem_depth = outer.smem_depth

View File

@@ -8,6 +8,7 @@ import chisel3.util._
import chisel3.experimental._
import org.chipsalliance.cde.config.Parameters
import freechips.rocketchip.tile._
import radiance.subsystem.RadianceGemminiDataType
class VortexBundleA(
tagWidth: Int,
@@ -332,8 +333,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
// tensor core
// this module is referenced from inside the Verilog RTL of the core
// pipeline.
addResource("/vsrc/TensorDotProductUnitFP32.sv")
// addResource("/vsrc/TensorDotProductUnit.sv")
// addResource("/vsrc/TensorDotProductUnitFP32.sv")
addResource("/vsrc/TensorDotProductUnit.sv")
// fpnew
// compile order matters; package definitions (ex. fpnew_pkg) should be