fix ext policy xbar, add rectangular tile support

This commit is contained in:
Richard Yan
2024-09-08 13:21:31 -07:00
parent 378b3531d4
commit afc6ba7eca
6 changed files with 66 additions and 30 deletions

View File

@@ -22,7 +22,11 @@ class XbarWithExtPolicy(nameSuffix: Option[String] = None)
val policySlaveNode = ExtPolicySlaveNode() val policySlaveNode = ExtPolicySlaveNode()
class ImplChild extends Impl { class ImplChild extends Impl {
val policy: TLArbiter.Policy = (_, _, _) => policySlaveNode.in.head._1 println(s"policy slave node input width ${policySlaveNode.in.head._1.getWidth}")
val policy: TLArbiter.Policy = (width, _, _) => {
println(s"evaluated policy width: ${width}")
policySlaveNode.in.head._1
}
// val wide_bundle = TLBundleParameters.union((node.in ++ node.out).map(_._2.bundle)) // val wide_bundle = TLBundleParameters.union((node.in ++ node.out).map(_._2.bundle))
// override def desiredName = (Seq("TLXbar") ++ nameSuffix ++ Seq(s"i${node.in.size}_o${node.out.size}_${wide_bundle.shortName}")).mkString("_") // override def desiredName = (Seq("TLXbar") ++ nameSuffix ++ Seq(s"i${node.in.size}_o${node.out.size}_${wide_bundle.shortName}")).mkString("_")
TLXbar.circuit(policy, node.in, node.out) TLXbar.circuit(policy, node.in, node.out)

View File

@@ -93,7 +93,7 @@ object RadianceGemminiDataType extends Enumeration {
} }
class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams, class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams,
dim: Int, accSizeInKB: Int, tileSize: Int, dim: Int, accSizeInKB: Int, tileSize: Either[(Int, Int, Int), Int],
dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => { dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => {
case TilesLocated(`location`) => { case TilesLocated(`location`) => {
val prev = up(TilesLocated(`location`)) val prev = up(TilesLocated(`location`))
@@ -120,7 +120,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
)), )),
mvin_scale_args = Some(ScaleArguments( mvin_scale_args = Some(ScaleArguments(
(t: Float, u: Float) => t * u, (t: Float, u: Float) => t * u,
1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))" 1, Float(5, 11), -1, identity = "0x3c00", c_str="((x) * (scale))"
)), )),
mvin_scale_acc_args = None, mvin_scale_acc_args = None,
has_training_convs = false, has_training_convs = false,
@@ -164,7 +164,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
} }
case NumTiles => up(NumTiles) + 1 case NumTiles => up(NumTiles) + 1
}) { }) {
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int, def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: Either[(Int, Int, Int), Int],
dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) = dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) =
this(location, RocketCrossingParams( this(location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location), master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -174,6 +174,13 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi
case InCluster(clusterId) => CCBUS(clusterId) case InCluster(clusterId) => CCBUS(clusterId)
} }
), dim, accSizeInKB, tileSize, dataType, dmaBytes) ), dim, accSizeInKB, tileSize, dataType, dmaBytes)
def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: Int) =
this(location, dim, accSizeInKB, Right(tileSize))
def this(location: HierarchicalLocation, dim: Int, accSizeInKB: Int, tileSize: (Int, Int, Int),
dataType: RadianceGemminiDataType.Type) =
this(location, dim, accSizeInKB, Left(tileSize), dataType)
} }
class WithRadianceSharedMem(address: BigInt, class WithRadianceSharedMem(address: BigInt,

View File

@@ -63,7 +63,7 @@ case class GemminiCoreParams(
case class GemminiTileParams( case class GemminiTileParams(
tileId: Int = 0, tileId: Int = 0,
gemminiConfig: GemminiArrayConfig[Float, Float, Float], gemminiConfig: GemminiArrayConfig[Float, Float, Float],
tileSize: Int = 4, tileSize: Either[(Int, Int, Int), Int] = Right(4),
slaveAddress: BigInt slaveAddress: BigInt
) extends InstantiableTileParams[GemminiTile] { ) extends InstantiableTileParams[GemminiTile] {
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByHartIdImpl)( def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByHartIdImpl)(
@@ -188,15 +188,27 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
ciscInst := 0.U.asTypeOf(ciscInstT) ciscInst := 0.U.asTypeOf(ciscInstT)
val tileSize = outer.gemminiParams.tileSize val (tileSizeM, tileSizeN, tileSizeK) = outer.gemminiParams.tileSize match {
val (boundsInst, spadQuartile) = (ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U, case Left(v: (Int, Int, Int)) => v
_.rs2 -> (tileSize | (tileSize << 16) | (BigInt(tileSize) << 32)).U), case Right(v: Int) => (v, v, v)
tileSize * tileSize * outer.gemminiParams.gemminiConfig.DIM) }
println(s"gemmini cisc initialized with DIM=${outer.gemminiParams.gemminiConfig.DIM}, tileSize=${tileSize}") val config = outer.gemminiParams.gemminiConfig
println(f"boundsInst=${boundsInst.litValue}%x, tileSize=${tileSize}, quartile=${spadQuartile}") val spadQuartile = config.sp_bank_entries * config.sp_banks / 4
// TODO: as a temporary hack, bit 7 of the cisc opcode
// TODO: will force the tile size to be a square base on M.
val rectBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSizeM | (tileSizeN << 16) | (BigInt(tileSizeK) << 32)).U)
val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
println(s"gemmini cisc initialized with DIM=${config.DIM}, tileSize=${tileSizeM},${tileSizeN},${tileSizeK}")
println(f"boundsInst=${rectBoundsInst.litValue}%x, quartile=${spadQuartile}")
when (ciscValid) { when (ciscValid) {
assert(!accSlave.cmd.valid, "cisc state machine already busy") assert(!accSlave.cmd.valid, "cisc state machine already busy")
switch (ciscId) { switch (ciscId(6, 0)) {
is (0.U) { is (0.U) {
ciscInst := microcodeEntry(Seq(boundsInst, ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U), // set A, B address ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U), // set A, B address

View File

@@ -301,8 +301,8 @@ class RadianceCluster (
r := subbank_r_xbar.node r := subbank_r_xbar.node
w := subbank_w_xbar.node w := subbank_w_xbar.node
val ur_xbar = XbarWithExtPolicy(Some("ur")) val ur_xbar = XbarWithExtPolicy(Some(s"ur_b${bid}_w${wid}"))
val uw_xbar = XbarWithExtPolicy(Some("uw")) val uw_xbar = XbarWithExtPolicy(Some(s"uw_b${bid}_w${wid}"))
val r_policy_node = ExtPolicyMasterNode(uniform_r_nodes(bid)(wid).length) val r_policy_node = ExtPolicyMasterNode(uniform_r_nodes(bid)(wid).length)
val w_policy_node = ExtPolicyMasterNode(uniform_w_nodes(bid)(wid).length) val w_policy_node = ExtPolicyMasterNode(uniform_w_nodes(bid)(wid).length)
ur_xbar.policySlaveNode := r_policy_node ur_xbar.policySlaveNode := r_policy_node
@@ -543,9 +543,7 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
dontTouch(smemWriteCounter) dontTouch(smemWriteCounter)
if (outer.stride_by_word) { if (outer.stride_by_word) {
val (uniform_r_nodes, uniform_w_nodes) = (outer.uniform_r_nodes.get, outer.uniform_w_nodes.get)
val uniform_fires = Seq.fill(2)(VecInit.fill(outer.smem_banks)(VecInit.fill(outer.smem_subbanks)(false.B))) val uniform_fires = Seq.fill(2)(VecInit.fill(outer.smem_banks)(VecInit.fill(outer.smem_subbanks)(false.B)))
val word_selects_1h = Seq.fill(2)(VecInit.fill(outer.smem_banks)(0.U))
outer.smem_bank_mgrs.grouped(outer.smem_subbanks).zipWithIndex.foreach { case (bank_mgrs, bid) => outer.smem_bank_mgrs.grouped(outer.smem_subbanks).zipWithIndex.foreach { case (bank_mgrs, bid) =>
// TODO move this loop out // TODO move this loop out
@@ -554,10 +552,13 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
// VecInit(words_with_same_idx.toSeq).asUInt.orR // VecInit(words_with_same_idx.toSeq).asUInt.orR
// }.toSeq).asUInt // }.toSeq).asUInt
// } // }
val Seq(valid_r_sources, valid_w_sources) = outer.uniform_nodes_in.map { banks => val word_selects_1h = Seq(
banks(bid).map(_.map(_.in.head._1.a.valid)).transpose.map { words_in_idx => Wire(UInt(outer.uniform_nodes_in.head(bid).head.length.W)).suggestName(s"ws_r_b${bid}"),
Wire(UInt(outer.uniform_nodes_in.last(bid).head.length.W)).suggestName(s"ws_w_b${bid}"))
val Seq(valid_r_sources, valid_w_sources) = outer.uniform_nodes_in.zipWithIndex.map { case (banks, rw) =>
VecInit(banks(bid).map(_.map(_.in.head._1.a.valid)).transpose.map { words_in_idx =>
VecInit(words_in_idx.toSeq).asUInt.orR VecInit(words_in_idx.toSeq).asUInt.orR
} }.toSeq).asUInt.suggestName(s"valid_sources_rw${rw}_b${bid}")
} }
assert(bank_mgrs.flatten.size == 2/* read and write */ * outer.smem_subbanks) assert(bank_mgrs.flatten.size == 2/* read and write */ * outer.smem_subbanks)
@@ -603,18 +604,29 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire
} }
} }
// use round robin to decide uniform select
println(f"valid r_sources ${valid_r_sources.length}, ${valid_r_sources}")
(word_selects_1h zip Seq(valid_r_sources, valid_w_sources)).zipWithIndex.foreach { case ((ws, vs), rw) => (word_selects_1h zip Seq(valid_r_sources, valid_w_sources)).zipWithIndex.foreach { case ((ws, vs), rw) =>
ws(bid) := TLArbiter.roundRobin(vs.length, VecInit(vs.toSeq).asUInt, uniform_fires(rw)(bid).asUInt.orR) ws := TLArbiter.roundRobin(vs.getWidth, vs, uniform_fires(rw)(bid).asUInt.orR)
}
// mask valid into xbar to prevent triggering assertion
(word_selects_1h zip outer.uniform_nodes_in).foreach { case (ws, ui) =>
ui(bid).foreach { sources =>
val in_valid = sources.map(_.in.head._1.a.valid)
val out_valid = sources.map(_.out.head._1.a.valid)
(in_valid lazyZip out_valid lazyZip ws.asBools).foreach { case (iv, ov, sel) =>
ov := iv && sel // only present output valid if input is selected
}
}
}
(outer.uniform_policy_nodes zip word_selects_1h).zipWithIndex.foreach { case ((nodes_bw, ws), rw) =>
nodes_bw(bid).foreach { policy =>
println(s"policy out ${policy.out.head._1.getWidth}, word select ${ws.getWidth}")
policy.out.head._1 := ws
}
} }
} }
(outer.uniform_policy_nodes zip word_selects_1h).zipWithIndex.foreach { case ((nodes_bw, ws_b), rw) =>
(nodes_bw zip ws_b).zipWithIndex.foreach { case ((nodes_w, ws), bid) =>
nodes_w.foreach { _.out.head._1 := ws }
}
}
} else { } else {
outer.smem_bank_mgrs.foreach { case Seq(r, w) => outer.smem_bank_mgrs.foreach { case Seq(r, w) =>
val mem_depth = outer.smem_depth val mem_depth = outer.smem_depth

View File

@@ -8,6 +8,7 @@ import chisel3.util._
import chisel3.experimental._ import chisel3.experimental._
import org.chipsalliance.cde.config.Parameters import org.chipsalliance.cde.config.Parameters
import freechips.rocketchip.tile._ import freechips.rocketchip.tile._
import radiance.subsystem.RadianceGemminiDataType
class VortexBundleA( class VortexBundleA(
tagWidth: Int, tagWidth: Int,
@@ -332,8 +333,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
// tensor core // tensor core
// this module is referenced from inside the Verilog RTL of the core // this module is referenced from inside the Verilog RTL of the core
// pipeline. // pipeline.
addResource("/vsrc/TensorDotProductUnitFP32.sv") // addResource("/vsrc/TensorDotProductUnitFP32.sv")
// addResource("/vsrc/TensorDotProductUnit.sv") addResource("/vsrc/TensorDotProductUnit.sv")
// fpnew // fpnew
// compile order matters; package definitions (ex. fpnew_pkg) should be // compile order matters; package definitions (ex. fpnew_pkg) should be