tensor: Reassert initiate.ready as soon as access ready
This commit is contained in:
@@ -69,6 +69,11 @@ class TensorCoreDecoupled(
|
|||||||
val source = UInt(sourceWidth.W)
|
val source = UInt(sourceWidth.W)
|
||||||
val data = UInt(dataWidth.W)
|
val data = UInt(dataWidth.W)
|
||||||
}
|
}
|
||||||
|
class TensorMemTag extends Bundle {
|
||||||
|
val warp = UInt(numWarpBits.W)
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val index = UInt(indexBits.W)
|
||||||
|
}
|
||||||
// mem response after translation from TL source to set/step tag
|
// mem response after translation from TL source to set/step tag
|
||||||
class TensorMemRespWithTag(
|
class TensorMemRespWithTag(
|
||||||
dataWidth: Int
|
dataWidth: Int
|
||||||
@@ -77,15 +82,11 @@ class TensorCoreDecoupled(
|
|||||||
val data = UInt(dataWidth.W)
|
val data = UInt(dataWidth.W)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FSM
|
// ===========================================================================
|
||||||
// ---
|
// Access stage
|
||||||
// This drives the overall pipeline of memory requests, dot-product unit
|
// ===========================================================================
|
||||||
// operations and regfile writeback.
|
//
|
||||||
|
// Frontend of the decoupled access/execute pipeline.
|
||||||
val busy = RegInit(false.B)
|
|
||||||
// Holds the warp id the core is currently working on. Note that we only
|
|
||||||
// support one outstanding warp request
|
|
||||||
val warpAccess = RegInit(0.U(numWarpBits.W))
|
|
||||||
|
|
||||||
// sets: k iteration
|
// sets: k iteration
|
||||||
val numSets = (tilingParams.k / tilingParams.kc)
|
val numSets = (tilingParams.k / tilingParams.kc)
|
||||||
@@ -97,39 +98,15 @@ class TensorCoreDecoupled(
|
|||||||
val lastStep = ((1 << stepBits) - 1)
|
val lastStep = ((1 << stepBits) - 1)
|
||||||
def setDone(set: UInt) = (set === lastSet.U)
|
def setDone(set: UInt) = (set === lastSet.U)
|
||||||
def stepDone(step: UInt) = (step === lastStep.U)
|
def stepDone(step: UInt) = (step === lastStep.U)
|
||||||
|
// 'index' is the index of a memory request among the sequence of requests
|
||||||
|
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||||
|
// or [0,n/2), where 2 is the stride that can be read in a single request.
|
||||||
|
require(tilingParams.m == tilingParams.n,
|
||||||
|
"currently only supports square SMEM tile")
|
||||||
|
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||||
|
val indexBits = log2Ceil(numIndices)
|
||||||
|
val lastIndex = (1 << indexBits) - 1
|
||||||
|
|
||||||
when (io.initiate.fire) {
|
|
||||||
val wid = io.initiate.bits.wid
|
|
||||||
busy := true.B
|
|
||||||
warpAccess := wid
|
|
||||||
when(io.writeback.fire) {
|
|
||||||
assert(
|
|
||||||
io.writeback.bits.wid =/= wid,
|
|
||||||
"unsupported concurrent initiate and writeback to the same warp"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: @perf: Instead of waiting until the last writeback, release busy as
|
|
||||||
// soon as the access frontend is complete so that there's a better chance to
|
|
||||||
// saturate the backend with back-to-back HGMMAs. This would require sending
|
|
||||||
// the 'wid' register to backend instead of having it shared with the
|
|
||||||
// frontend.
|
|
||||||
when(io.writeback.fire && io.writeback.bits.last) {
|
|
||||||
busy := false.B
|
|
||||||
}
|
|
||||||
|
|
||||||
// serialize every HGMMA request
|
|
||||||
io.initiate.ready := !busy
|
|
||||||
|
|
||||||
// ===========================================================================
|
|
||||||
// Access stage
|
|
||||||
// ===========================================================================
|
|
||||||
//
|
|
||||||
// Frontend of the decoupled access/execute pipeline.
|
|
||||||
|
|
||||||
// States
|
|
||||||
//
|
|
||||||
object AccessorState extends ChiselEnum {
|
object AccessorState extends ChiselEnum {
|
||||||
val idle = Value(0.U)
|
val idle = Value(0.U)
|
||||||
val access = Value(1.U)
|
val access = Value(1.U)
|
||||||
@@ -142,6 +119,30 @@ class TensorCoreDecoupled(
|
|||||||
val allReqsDone = WireInit(false.B)
|
val allReqsDone = WireInit(false.B)
|
||||||
dontTouch(allReqsDone)
|
dontTouch(allReqsDone)
|
||||||
|
|
||||||
|
val warpAccess = RegInit(0.U(numWarpBits.W))
|
||||||
|
|
||||||
|
class BlockState extends Bundle {
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val index = UInt(indexBits.W)
|
||||||
|
}
|
||||||
|
val stateInit = Wire(new BlockState)
|
||||||
|
stateInit.set := 0.U
|
||||||
|
stateInit.index := 0.U
|
||||||
|
val stateA = RegInit(stateInit)
|
||||||
|
val stateB = RegInit(stateInit)
|
||||||
|
dontTouch(stateA)
|
||||||
|
dontTouch(stateA.index)
|
||||||
|
dontTouch(stateB)
|
||||||
|
dontTouch(stateB.index)
|
||||||
|
|
||||||
|
io.initiate.ready := (state === AccessorState.idle)
|
||||||
|
when (io.initiate.fire) {
|
||||||
|
warpAccess := io.initiate.bits.wid
|
||||||
|
assert(stateA.set === 0.U && stateA.index === 0.U &&
|
||||||
|
stateB.set === 0.U && stateB.index === 0.U,
|
||||||
|
"stateA and stateB not initialized to zero")
|
||||||
|
}
|
||||||
|
|
||||||
switch(state) {
|
switch(state) {
|
||||||
is(AccessorState.idle) {
|
is(AccessorState.idle) {
|
||||||
when(io.initiate.fire) {
|
when(io.initiate.fire) {
|
||||||
@@ -154,40 +155,11 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
is(AccessorState.finish) {
|
is(AccessorState.finish) {
|
||||||
// FIXME: decouple writeback
|
// FIXME: is finish state needed?
|
||||||
when(io.writeback.fire) {
|
state := AccessorState.idle
|
||||||
state := AccessorState.idle
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 'index' is the index of a memory request among the sequence of requests
|
|
||||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
|
||||||
// or [0,n/2), where 2 is the stride that can be read in a single request.
|
|
||||||
require(tilingParams.m == tilingParams.n,
|
|
||||||
"currently only supports square SMEM tile")
|
|
||||||
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
|
||||||
val indexBits = log2Ceil(numIndices)
|
|
||||||
val lastIndex = (1 << indexBits) - 1
|
|
||||||
|
|
||||||
class State extends Bundle {
|
|
||||||
val set = UInt(setBits.W)
|
|
||||||
val index = UInt(indexBits.W)
|
|
||||||
}
|
|
||||||
class TensorMemTag extends Bundle {
|
|
||||||
val warp = UInt(numWarpBits.W)
|
|
||||||
val set = UInt(setBits.W)
|
|
||||||
val index = UInt(indexBits.W)
|
|
||||||
}
|
|
||||||
|
|
||||||
val stateInit = Wire(new State)
|
|
||||||
stateInit.set := 0.U
|
|
||||||
stateInit.index := 0.U
|
|
||||||
val stateA = RegInit(stateInit)
|
|
||||||
val stateB = RegInit(stateInit)
|
|
||||||
dontTouch(stateA)
|
|
||||||
dontTouch(stateB)
|
|
||||||
|
|
||||||
when (io.reqA.fire) {
|
when (io.reqA.fire) {
|
||||||
when (stateA.index === lastIndex.U) {
|
when (stateA.index === lastIndex.U) {
|
||||||
stateA.set := stateA.set + 1.U
|
stateA.set := stateA.set + 1.U
|
||||||
|
|||||||
Reference in New Issue
Block a user