tensor: Decouple warp in execute from access
This allows the access stage to accept new initiate back-to-back without waiting for the previous writeback to finish.
This commit is contained in:
@@ -85,7 +85,7 @@ class TensorCoreDecoupled(
|
|||||||
val busy = RegInit(false.B)
|
val busy = RegInit(false.B)
|
||||||
// Holds the warp id the core is currently working on. Note that we only
|
// Holds the warp id the core is currently working on. Note that we only
|
||||||
// support one outstanding warp request
|
// support one outstanding warp request
|
||||||
val warpReg = RegInit(0.U(numWarpBits.W))
|
val warpAccess = RegInit(0.U(numWarpBits.W))
|
||||||
|
|
||||||
// sets: k iteration
|
// sets: k iteration
|
||||||
val numSets = (tilingParams.k / tilingParams.kc)
|
val numSets = (tilingParams.k / tilingParams.kc)
|
||||||
@@ -101,7 +101,7 @@ class TensorCoreDecoupled(
|
|||||||
when (io.initiate.fire) {
|
when (io.initiate.fire) {
|
||||||
val wid = io.initiate.bits.wid
|
val wid = io.initiate.bits.wid
|
||||||
busy := true.B
|
busy := true.B
|
||||||
warpReg := wid
|
warpAccess := wid
|
||||||
when(io.writeback.fire) {
|
when(io.writeback.fire) {
|
||||||
assert(
|
assert(
|
||||||
io.writeback.bits.wid =/= wid,
|
io.writeback.bits.wid =/= wid,
|
||||||
@@ -170,28 +170,35 @@ class TensorCoreDecoupled(
|
|||||||
val indexBits = log2Ceil(numIndices)
|
val indexBits = log2Ceil(numIndices)
|
||||||
val lastIndex = (1 << indexBits) - 1
|
val lastIndex = (1 << indexBits) - 1
|
||||||
|
|
||||||
|
class State extends Bundle {
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val index = UInt(indexBits.W)
|
||||||
|
}
|
||||||
class TensorMemTag extends Bundle {
|
class TensorMemTag extends Bundle {
|
||||||
|
val warp = UInt(numWarpBits.W)
|
||||||
val set = UInt(setBits.W)
|
val set = UInt(setBits.W)
|
||||||
val index = UInt(indexBits.W)
|
val index = UInt(indexBits.W)
|
||||||
}
|
}
|
||||||
|
|
||||||
val tagInit = Wire(new TensorMemTag)
|
val stateInit = Wire(new State)
|
||||||
tagInit.set := 0.U
|
stateInit.set := 0.U
|
||||||
tagInit.index := 0.U
|
stateInit.index := 0.U
|
||||||
val tagA = RegInit(tagInit)
|
val stateA = RegInit(stateInit)
|
||||||
val tagB = RegInit(tagInit)
|
val stateB = RegInit(stateInit)
|
||||||
|
dontTouch(stateA)
|
||||||
|
dontTouch(stateB)
|
||||||
|
|
||||||
when (io.reqA.fire) {
|
when (io.reqA.fire) {
|
||||||
when (tagA.index === lastIndex.U) {
|
when (stateA.index === lastIndex.U) {
|
||||||
tagA.set := tagA.set + 1.U
|
stateA.set := stateA.set + 1.U
|
||||||
}
|
}
|
||||||
tagA.index := tagA.index + 1.U
|
stateA.index := stateA.index + 1.U
|
||||||
}
|
}
|
||||||
when (io.reqB.fire) {
|
when (io.reqB.fire) {
|
||||||
when (tagB.index === lastIndex.U) {
|
when (stateB.index === lastIndex.U) {
|
||||||
tagB.set := tagB.set + 1.U
|
stateB.set := stateB.set + 1.U
|
||||||
}
|
}
|
||||||
tagB.index := tagB.index + 1.U
|
stateB.index := stateB.index + 1.U
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address generation
|
// Address generation
|
||||||
@@ -222,12 +229,12 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: bogus base address
|
// FIXME: bogus base address
|
||||||
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
val addressA = addressGen(0.U, stateA.set, stateA.index)
|
||||||
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
||||||
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
|
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
|
||||||
|
|
||||||
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||||
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
|
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||||
val doneReqA = RegInit(false.B)
|
val doneReqA = RegInit(false.B)
|
||||||
val doneReqB = RegInit(false.B)
|
val doneReqB = RegInit(false.B)
|
||||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||||
@@ -237,16 +244,25 @@ class TensorCoreDecoupled(
|
|||||||
when (state === AccessorState.finish) {
|
when (state === AccessorState.finish) {
|
||||||
doneReqA := false.B
|
doneReqA := false.B
|
||||||
doneReqB := false.B
|
doneReqB := false.B
|
||||||
tagA.set := 0.U
|
stateA.set := 0.U
|
||||||
tagA.index := 0.U
|
stateA.index := 0.U
|
||||||
tagB.set := 0.U
|
stateB.set := 0.U
|
||||||
tagB.index := 0.U
|
stateB.index := 0.U
|
||||||
}
|
}
|
||||||
|
|
||||||
allReqsDone := doneReqA && doneReqB
|
allReqsDone := doneReqA && doneReqB
|
||||||
|
|
||||||
// Request generation
|
// Request generation
|
||||||
//
|
//
|
||||||
|
val tagA = Wire(new TensorMemTag)
|
||||||
|
tagA.warp := warpAccess
|
||||||
|
tagA.set := stateA.set
|
||||||
|
tagA.index := stateA.index
|
||||||
|
val tagB = Wire(new TensorMemTag)
|
||||||
|
tagB.warp := warpAccess
|
||||||
|
tagB.set := stateB.set
|
||||||
|
tagB.index := stateB.index
|
||||||
|
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
Seq((io.reqA, (io.respA, respATagged)),
|
Seq((io.reqA, (io.respA, respATagged)),
|
||||||
@@ -422,9 +438,12 @@ class TensorCoreDecoupled(
|
|||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
val stepMask = (1 << numTilesMBits).U
|
val stepMask = (1 << numTilesMBits).U
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
|
assert(operandATag.warp === operandBTag.warp &&
|
||||||
"A and B operands are pointing to different sets. " ++
|
operandATag.set === operandBTag.set,
|
||||||
|
"A and B operands are pointing to different warps and sets. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
|
assert(operandATag.set === setCompute,
|
||||||
|
"Operand arrived from memory is pointing at a different set than the FSM.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertAligned
|
assertAligned
|
||||||
@@ -492,6 +511,7 @@ class TensorCoreDecoupled(
|
|||||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||||
|
|
||||||
class TensorComputeTag extends Bundle {
|
class TensorComputeTag extends Bundle {
|
||||||
|
val warp = UInt(numWarpBits.W)
|
||||||
val set = UInt(setBits.W)
|
val set = UInt(setBits.W)
|
||||||
val step = UInt(stepBits.W)
|
val step = UInt(stepBits.W)
|
||||||
val substep = UInt(1.W)
|
val substep = UInt(1.W)
|
||||||
@@ -500,6 +520,7 @@ class TensorCoreDecoupled(
|
|||||||
val queueDepth = 5 // needs to be at least the DPU latency
|
val queueDepth = 5 // needs to be at least the DPU latency
|
||||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||||
tagQueue.io.enq.valid := dpuFire
|
tagQueue.io.enq.valid := dpuFire
|
||||||
|
tagQueue.io.enq.bits.warp := operandATag.warp
|
||||||
tagQueue.io.enq.bits.set := setCompute
|
tagQueue.io.enq.bits.set := setCompute
|
||||||
tagQueue.io.enq.bits.step := stepCompute
|
tagQueue.io.enq.bits.step := stepCompute
|
||||||
tagQueue.io.enq.bits.substep := substepCompute
|
tagQueue.io.enq.bits.substep := substepCompute
|
||||||
@@ -518,12 +539,12 @@ class TensorCoreDecoupled(
|
|||||||
(step << 1/*2 substeps*/) + substep
|
(step << 1/*2 substeps*/) + substep
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val warpWriteback = tagQueue.io.deq.bits.warp
|
||||||
val setWriteback = tagQueue.io.deq.bits.set
|
val setWriteback = tagQueue.io.deq.bits.set
|
||||||
val stepWriteback = tagQueue.io.deq.bits.step
|
val stepWriteback = tagQueue.io.deq.bits.step
|
||||||
val substepWriteback = tagQueue.io.deq.bits.substep
|
val substepWriteback = tagQueue.io.deq.bits.substep
|
||||||
io.writeback.valid := dpuValid
|
io.writeback.valid := dpuValid
|
||||||
// TODO: decouple wid from frontend
|
io.writeback.bits.wid := warpWriteback
|
||||||
io.writeback.bits.wid := warpReg
|
|
||||||
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
|
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
|
||||||
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
|
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
|
||||||
(substepWriteback === 1.U)
|
(substepWriteback === 1.U)
|
||||||
@@ -685,7 +706,7 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
|||||||
ramA.node := stutter := xbar.node
|
ramA.node := stutter := xbar.node
|
||||||
ramB.node := xbar.node
|
ramB.node := xbar.node
|
||||||
|
|
||||||
val fuzz = true
|
val fuzz = false
|
||||||
|
|
||||||
lazy val module = new Impl
|
lazy val module = new Impl
|
||||||
class Impl extends LazyModuleImp(this) with UnitTestModule {
|
class Impl extends LazyModuleImp(this) with UnitTestModule {
|
||||||
|
|||||||
Reference in New Issue
Block a user