tensor: Decouple warp in execute from access

This allows the access stage to accept a new initiate request back-to-back
without waiting for the previous writeback to finish.
This commit is contained in:
Hansung Kim
2024-10-22 22:44:33 -07:00
parent 072904a82b
commit 95ecc5180f

View File

@@ -85,7 +85,7 @@ class TensorCoreDecoupled(
val busy = RegInit(false.B) val busy = RegInit(false.B)
// Holds the warp id the core is currently working on. Note that we only // Holds the warp id the core is currently working on. Note that we only
// support one outstanding warp request // support one outstanding warp request
val warpReg = RegInit(0.U(numWarpBits.W)) val warpAccess = RegInit(0.U(numWarpBits.W))
// sets: k iteration // sets: k iteration
val numSets = (tilingParams.k / tilingParams.kc) val numSets = (tilingParams.k / tilingParams.kc)
@@ -101,7 +101,7 @@ class TensorCoreDecoupled(
when (io.initiate.fire) { when (io.initiate.fire) {
val wid = io.initiate.bits.wid val wid = io.initiate.bits.wid
busy := true.B busy := true.B
warpReg := wid warpAccess := wid
when(io.writeback.fire) { when(io.writeback.fire) {
assert( assert(
io.writeback.bits.wid =/= wid, io.writeback.bits.wid =/= wid,
@@ -170,28 +170,35 @@ class TensorCoreDecoupled(
val indexBits = log2Ceil(numIndices) val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1 val lastIndex = (1 << indexBits) - 1
class State extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
class TensorMemTag extends Bundle { class TensorMemTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W) val set = UInt(setBits.W)
val index = UInt(indexBits.W) val index = UInt(indexBits.W)
} }
val tagInit = Wire(new TensorMemTag) val stateInit = Wire(new State)
tagInit.set := 0.U stateInit.set := 0.U
tagInit.index := 0.U stateInit.index := 0.U
val tagA = RegInit(tagInit) val stateA = RegInit(stateInit)
val tagB = RegInit(tagInit) val stateB = RegInit(stateInit)
dontTouch(stateA)
dontTouch(stateB)
when (io.reqA.fire) { when (io.reqA.fire) {
when (tagA.index === lastIndex.U) { when (stateA.index === lastIndex.U) {
tagA.set := tagA.set + 1.U stateA.set := stateA.set + 1.U
} }
tagA.index := tagA.index + 1.U stateA.index := stateA.index + 1.U
} }
when (io.reqB.fire) { when (io.reqB.fire) {
when (tagB.index === lastIndex.U) { when (stateB.index === lastIndex.U) {
tagB.set := tagB.set + 1.U stateB.set := stateB.set + 1.U
} }
tagB.index := tagB.index + 1.U stateB.index := stateB.index + 1.U
} }
// Address generation // Address generation
@@ -222,12 +229,12 @@ class TensorCoreDecoupled(
} }
// FIXME: bogus base address // FIXME: bogus base address
val addressA = addressGen(0.U, tagA.set, tagA.index) val addressA = addressGen(0.U, stateA.set, stateA.index)
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank // SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
val addressB = addressGen(0x400.U, tagB.set, tagB.index) val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U) val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U) val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val doneReqA = RegInit(false.B) val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B) val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B } when (lastReqA && io.reqA.fire) { doneReqA := true.B }
@@ -237,16 +244,25 @@ class TensorCoreDecoupled(
when (state === AccessorState.finish) { when (state === AccessorState.finish) {
doneReqA := false.B doneReqA := false.B
doneReqB := false.B doneReqB := false.B
tagA.set := 0.U stateA.set := 0.U
tagA.index := 0.U stateA.index := 0.U
tagB.set := 0.U stateB.set := 0.U
tagB.index := 0.U stateB.index := 0.U
} }
allReqsDone := doneReqA && doneReqB allReqsDone := doneReqA && doneReqB
// Request generation // Request generation
// //
val tagA = Wire(new TensorMemTag)
tagA.warp := warpAccess
tagA.set := stateA.set
tagA.index := stateA.index
val tagB = Wire(new TensorMemTag)
tagB.warp := warpAccess
tagB.set := stateB.set
tagB.index := stateB.index
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)), Seq((io.reqA, (io.respA, respATagged)),
@@ -422,9 +438,12 @@ class TensorCoreDecoupled(
def assertAligned = { def assertAligned = {
val stepMask = (1 << numTilesMBits).U val stepMask = (1 << numTilesMBits).U
when (dpuFire) { when (dpuFire) {
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set, assert(operandATag.warp === operandBTag.warp &&
"A and B operands are pointing to different sets. " ++ operandATag.set === operandBTag.set,
"A and B operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")
assert(operandATag.set === setCompute,
"Operand arrived from memory is pointing at a different set than the FSM.")
} }
} }
assertAligned assertAligned
@@ -492,6 +511,7 @@ class TensorCoreDecoupled(
// These queues hold metadata needed for writeback in sync with the DPU. // These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle { class TensorComputeTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W) val set = UInt(setBits.W)
val step = UInt(stepBits.W) val step = UInt(stepBits.W)
val substep = UInt(1.W) val substep = UInt(1.W)
@@ -500,6 +520,7 @@ class TensorCoreDecoupled(
val queueDepth = 5 // needs to be at least the DPU latency val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth)) val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire tagQueue.io.enq.valid := dpuFire
tagQueue.io.enq.bits.warp := operandATag.warp
tagQueue.io.enq.bits.set := setCompute tagQueue.io.enq.bits.set := setCompute
tagQueue.io.enq.bits.step := stepCompute tagQueue.io.enq.bits.step := stepCompute
tagQueue.io.enq.bits.substep := substepCompute tagQueue.io.enq.bits.substep := substepCompute
@@ -518,12 +539,12 @@ class TensorCoreDecoupled(
(step << 1/*2 substeps*/) + substep (step << 1/*2 substeps*/) + substep
} }
val warpWriteback = tagQueue.io.deq.bits.warp
val setWriteback = tagQueue.io.deq.bits.set val setWriteback = tagQueue.io.deq.bits.set
val stepWriteback = tagQueue.io.deq.bits.step val stepWriteback = tagQueue.io.deq.bits.step
val substepWriteback = tagQueue.io.deq.bits.substep val substepWriteback = tagQueue.io.deq.bits.substep
io.writeback.valid := dpuValid io.writeback.valid := dpuValid
// TODO: decouple wid from frontend io.writeback.bits.wid := warpWriteback
io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback) io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) && io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
(substepWriteback === 1.U) (substepWriteback === 1.U)
@@ -685,7 +706,7 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
ramA.node := stutter := xbar.node ramA.node := stutter := xbar.node
ramB.node := xbar.node ramB.node := xbar.node
val fuzz = true val fuzz = false
lazy val module = new Impl lazy val module = new Impl
class Impl extends LazyModuleImp(this) with UnitTestModule { class Impl extends LazyModuleImp(this) with UnitTestModule {