tensor: Decouple warp in execute from access
This allows the access stage to accept new initiate back-to-back without waiting for the previous writeback to finish.
This commit is contained in:
@@ -85,7 +85,7 @@ class TensorCoreDecoupled(
|
||||
val busy = RegInit(false.B)
|
||||
// Holds the warp id the core is currently working on. Note that we only
|
||||
// support one outstanding warp request
|
||||
val warpReg = RegInit(0.U(numWarpBits.W))
|
||||
val warpAccess = RegInit(0.U(numWarpBits.W))
|
||||
|
||||
// sets: k iteration
|
||||
val numSets = (tilingParams.k / tilingParams.kc)
|
||||
@@ -101,7 +101,7 @@ class TensorCoreDecoupled(
|
||||
when (io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
busy := true.B
|
||||
warpReg := wid
|
||||
warpAccess := wid
|
||||
when(io.writeback.fire) {
|
||||
assert(
|
||||
io.writeback.bits.wid =/= wid,
|
||||
@@ -170,28 +170,35 @@ class TensorCoreDecoupled(
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
class State extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
class TensorMemTag extends Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
|
||||
val tagInit = Wire(new TensorMemTag)
|
||||
tagInit.set := 0.U
|
||||
tagInit.index := 0.U
|
||||
val tagA = RegInit(tagInit)
|
||||
val tagB = RegInit(tagInit)
|
||||
val stateInit = Wire(new State)
|
||||
stateInit.set := 0.U
|
||||
stateInit.index := 0.U
|
||||
val stateA = RegInit(stateInit)
|
||||
val stateB = RegInit(stateInit)
|
||||
dontTouch(stateA)
|
||||
dontTouch(stateB)
|
||||
|
||||
when (io.reqA.fire) {
|
||||
when (tagA.index === lastIndex.U) {
|
||||
tagA.set := tagA.set + 1.U
|
||||
when (stateA.index === lastIndex.U) {
|
||||
stateA.set := stateA.set + 1.U
|
||||
}
|
||||
tagA.index := tagA.index + 1.U
|
||||
stateA.index := stateA.index + 1.U
|
||||
}
|
||||
when (io.reqB.fire) {
|
||||
when (tagB.index === lastIndex.U) {
|
||||
tagB.set := tagB.set + 1.U
|
||||
when (stateB.index === lastIndex.U) {
|
||||
stateB.set := stateB.set + 1.U
|
||||
}
|
||||
tagB.index := tagB.index + 1.U
|
||||
stateB.index := stateB.index + 1.U
|
||||
}
|
||||
|
||||
// Address generation
|
||||
@@ -222,12 +229,12 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
|
||||
// FIXME: bogus base address
|
||||
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
||||
val addressA = addressGen(0.U, stateA.set, stateA.index)
|
||||
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
||||
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
|
||||
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
|
||||
|
||||
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
||||
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
|
||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||
val doneReqA = RegInit(false.B)
|
||||
val doneReqB = RegInit(false.B)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
@@ -237,16 +244,25 @@ class TensorCoreDecoupled(
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
tagA.set := 0.U
|
||||
tagA.index := 0.U
|
||||
tagB.set := 0.U
|
||||
tagB.index := 0.U
|
||||
stateA.set := 0.U
|
||||
stateA.index := 0.U
|
||||
stateB.set := 0.U
|
||||
stateB.index := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB
|
||||
|
||||
// Request generation
|
||||
//
|
||||
val tagA = Wire(new TensorMemTag)
|
||||
tagA.warp := warpAccess
|
||||
tagA.set := stateA.set
|
||||
tagA.index := stateA.index
|
||||
val tagB = Wire(new TensorMemTag)
|
||||
tagB.warp := warpAccess
|
||||
tagB.set := stateB.set
|
||||
tagB.index := stateB.index
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
@@ -422,9 +438,12 @@ class TensorCoreDecoupled(
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
|
||||
"A and B operands are pointing to different sets. " ++
|
||||
assert(operandATag.warp === operandBTag.warp &&
|
||||
operandATag.set === operandBTag.set,
|
||||
"A and B operands are pointing to different warps and sets. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
assert(operandATag.set === setCompute,
|
||||
"Operand arrived from memory is pointing at a different set than the FSM.")
|
||||
}
|
||||
}
|
||||
assertAligned
|
||||
@@ -492,6 +511,7 @@ class TensorCoreDecoupled(
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
class TensorComputeTag extends Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
val substep = UInt(1.W)
|
||||
@@ -500,6 +520,7 @@ class TensorCoreDecoupled(
|
||||
val queueDepth = 5 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||
tagQueue.io.enq.valid := dpuFire
|
||||
tagQueue.io.enq.bits.warp := operandATag.warp
|
||||
tagQueue.io.enq.bits.set := setCompute
|
||||
tagQueue.io.enq.bits.step := stepCompute
|
||||
tagQueue.io.enq.bits.substep := substepCompute
|
||||
@@ -518,12 +539,12 @@ class TensorCoreDecoupled(
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
|
||||
val warpWriteback = tagQueue.io.deq.bits.warp
|
||||
val setWriteback = tagQueue.io.deq.bits.set
|
||||
val stepWriteback = tagQueue.io.deq.bits.step
|
||||
val substepWriteback = tagQueue.io.deq.bits.substep
|
||||
io.writeback.valid := dpuValid
|
||||
// TODO: decouple wid from frontend
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.wid := warpWriteback
|
||||
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
|
||||
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
|
||||
(substepWriteback === 1.U)
|
||||
@@ -685,7 +706,7 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
ramA.node := stutter := xbar.node
|
||||
ramB.node := xbar.node
|
||||
|
||||
val fuzz = true
|
||||
val fuzz = false
|
||||
|
||||
lazy val module = new Impl
|
||||
class Impl extends LazyModuleImp(this) with UnitTestModule {
|
||||
|
||||
Reference in New Issue
Block a user