tensor: Decouple warp in execute from access

This allows the access stage to accept new initiate back-to-back without
waiting for the previous writeback to finish.
This commit is contained in:
Hansung Kim
2024-10-22 22:44:33 -07:00
parent 072904a82b
commit 95ecc5180f

View File

@@ -85,7 +85,7 @@ class TensorCoreDecoupled(
val busy = RegInit(false.B)
// Holds the warp id the core is currently working on. Note that we only
// support one outstanding warp request
val warpReg = RegInit(0.U(numWarpBits.W))
val warpAccess = RegInit(0.U(numWarpBits.W))
// sets: k iteration
val numSets = (tilingParams.k / tilingParams.kc)
@@ -101,7 +101,7 @@ class TensorCoreDecoupled(
when (io.initiate.fire) {
val wid = io.initiate.bits.wid
busy := true.B
warpReg := wid
warpAccess := wid
when(io.writeback.fire) {
assert(
io.writeback.bits.wid =/= wid,
@@ -170,28 +170,35 @@ class TensorCoreDecoupled(
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
class State extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
class TensorMemTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
val tagInit = Wire(new TensorMemTag)
tagInit.set := 0.U
tagInit.index := 0.U
val tagA = RegInit(tagInit)
val tagB = RegInit(tagInit)
val stateInit = Wire(new State)
stateInit.set := 0.U
stateInit.index := 0.U
val stateA = RegInit(stateInit)
val stateB = RegInit(stateInit)
dontTouch(stateA)
dontTouch(stateB)
when (io.reqA.fire) {
when (tagA.index === lastIndex.U) {
tagA.set := tagA.set + 1.U
when (stateA.index === lastIndex.U) {
stateA.set := stateA.set + 1.U
}
tagA.index := tagA.index + 1.U
stateA.index := stateA.index + 1.U
}
when (io.reqB.fire) {
when (tagB.index === lastIndex.U) {
tagB.set := tagB.set + 1.U
when (stateB.index === lastIndex.U) {
stateB.set := stateB.set + 1.U
}
tagB.index := tagB.index + 1.U
stateB.index := stateB.index + 1.U
}
// Address generation
@@ -222,12 +229,12 @@ class TensorCoreDecoupled(
}
// FIXME: bogus base address
val addressA = addressGen(0.U, tagA.set, tagA.index)
val addressA = addressGen(0.U, stateA.set, stateA.index)
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
@@ -237,16 +244,25 @@ class TensorCoreDecoupled(
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
tagA.set := 0.U
tagA.index := 0.U
tagB.set := 0.U
tagB.index := 0.U
stateA.set := 0.U
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
}
allReqsDone := doneReqA && doneReqB
// Request generation
//
val tagA = Wire(new TensorMemTag)
tagA.warp := warpAccess
tagA.set := stateA.set
tagA.index := stateA.index
val tagB = Wire(new TensorMemTag)
tagB.warp := warpAccess
tagB.set := stateB.set
tagB.index := stateB.index
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)),
@@ -422,9 +438,12 @@ class TensorCoreDecoupled(
def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) {
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
"A and B operands are pointing to different sets. " ++
assert(operandATag.warp === operandBTag.warp &&
operandATag.set === operandBTag.set,
"A and B operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.")
assert(operandATag.set === setCompute,
"Operand arrived from memory is pointing at a different set than the FSM.")
}
}
assertAligned
@@ -492,6 +511,7 @@ class TensorCoreDecoupled(
// These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
val substep = UInt(1.W)
@@ -500,6 +520,7 @@ class TensorCoreDecoupled(
val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire
tagQueue.io.enq.bits.warp := operandATag.warp
tagQueue.io.enq.bits.set := setCompute
tagQueue.io.enq.bits.step := stepCompute
tagQueue.io.enq.bits.substep := substepCompute
@@ -518,12 +539,12 @@ class TensorCoreDecoupled(
(step << 1/*2 substeps*/) + substep
}
val warpWriteback = tagQueue.io.deq.bits.warp
val setWriteback = tagQueue.io.deq.bits.set
val stepWriteback = tagQueue.io.deq.bits.step
val substepWriteback = tagQueue.io.deq.bits.substep
io.writeback.valid := dpuValid
// TODO: decouple wid from frontend
io.writeback.bits.wid := warpReg
io.writeback.bits.wid := warpWriteback
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
(substepWriteback === 1.U)
@@ -685,7 +706,7 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
ramA.node := stutter := xbar.node
ramB.node := xbar.node
val fuzz = true
val fuzz = false
lazy val module = new Impl
class Impl extends LazyModuleImp(this) with UnitTestModule {