tensor: Replace staging logic for A with FillBuffer
This commit is contained in:
@@ -272,46 +272,41 @@ class TensorCoreDecoupled(
|
|||||||
io.writeback.bits.data.widthOption.get,
|
io.writeback.bits.data.widthOption.get,
|
||||||
"response data width does not match the writeback data width")
|
"response data width does not match the writeback data width")
|
||||||
|
|
||||||
|
// FIXME: substepDeqA is likely unnecessary now that FillBuffer assembles the A tile
|
||||||
val substepDeqA = RegInit(0.U(1.W))
|
val substepDeqA = RegInit(0.U(1.W))
|
||||||
when (respQueueA.fire) {
|
when (respQueueA.fire) {
|
||||||
substepDeqA := substepDeqA + 1.U
|
substepDeqA := substepDeqA + 1.U
|
||||||
}
|
}
|
||||||
dontTouch(substepDeqA)
|
dontTouch(substepDeqA)
|
||||||
|
|
||||||
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
|
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
||||||
// ready for compute. The pipeline is two-stage:
|
// ready for compute. Also send the set/step tag along the pipe for
|
||||||
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
|
// alignment check.
|
||||||
// coming from the resp queue, and
|
|
||||||
// - stage two (fullAQueue) for holding the full A tile until it gets
|
|
||||||
// matched with two 4x2 B tiles, and compute is complete.
|
|
||||||
//
|
|
||||||
// Note that the half-tile assembly is unnecessary for B since the B tile is
|
|
||||||
// only 4x2.
|
|
||||||
// Also send the set/step tag along the pipe for alignment check.
|
|
||||||
|
|
||||||
// note: `pipe` combinationally couples the enq and deq ready signals
|
val fullA = Module(new FillBuffer(
|
||||||
val halfAQueue = Module(new Queue(
|
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
||||||
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
|
|
||||||
))
|
))
|
||||||
halfAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 0.U)
|
fullA.io.enq.valid := respQueueA.valid
|
||||||
halfAQueue.io.enq.bits := respQueueA.bits
|
fullA.io.enq.bits := respQueueA.bits.data
|
||||||
|
respQueueA.ready := fullA.io.enq.ready
|
||||||
|
// `pipe` combinationally couples enq-deq ready
|
||||||
|
val fullATag = Module(new Queue(
|
||||||
|
new TensorMemTag, entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullATag.io.enq.valid := respQueueA.valid
|
||||||
|
fullATag.io.enq.bits := respQueueA.bits.tag
|
||||||
|
|
||||||
// substep == 0 data goes to the LSB
|
// stage the full A tile once more so that FillBuffer can be filled up in the
|
||||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
|
// background while the tile is being used for compute. This does come with
|
||||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
// capacity overhead.
|
||||||
"assumes 2-cycle read for a full compute tile of A")
|
val fullABuf = Module(new Queue(
|
||||||
// only use the lower halfA's tag. substep will be incorrect.
|
|
||||||
val fullAEnqTag = halfAQueue.io.deq.bits.tag
|
|
||||||
val fullAQueue = Module(new Queue(
|
|
||||||
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
// hold first half A data for the first substep
|
fullABuf.io.enq.valid := fullA.io.deq.valid
|
||||||
halfAQueue.io.deq.ready := respQueueA.valid && (substepDeqA === 1.U) &&
|
fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
|
||||||
fullAQueue.io.enq.ready
|
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
|
||||||
fullAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 1.U) &&
|
fullA.io.deq.ready := fullABuf.io.enq.ready
|
||||||
halfAQueue.io.deq.valid
|
fullATag.io.deq.ready := fullABuf.io.enq.ready
|
||||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
|
||||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
|
||||||
|
|
||||||
// serialize every two B responses into one full 4x4 B tile
|
// serialize every two B responses into one full 4x4 B tile
|
||||||
// FIXME: do the same for A
|
// FIXME: do the same for A
|
||||||
@@ -327,29 +322,24 @@ class TensorCoreDecoupled(
|
|||||||
fullBTag.io.enq.valid := respQueueB.valid
|
fullBTag.io.enq.valid := respQueueB.valid
|
||||||
fullBTag.io.enq.bits := respQueueB.bits.tag
|
fullBTag.io.enq.bits := respQueueB.bits.tag
|
||||||
|
|
||||||
val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid
|
val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
|
||||||
val operandA = fullAQueue.io.deq.bits.data
|
val operandA = fullABuf.io.deq.bits.data
|
||||||
val operandATag = fullAQueue.io.deq.bits.tag
|
val operandATag = fullABuf.io.deq.bits.tag
|
||||||
val operandB = fullB.io.deq.bits
|
val operandB = fullB.io.deq.bits
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
val setCompute = fullAQueue.io.deq.bits.tag.set
|
val setCompute = fullABuf.io.deq.bits.tag.set
|
||||||
val stepCompute = fullAQueue.io.deq.bits.tag.step
|
val stepCompute = fullABuf.io.deq.bits.tag.step
|
||||||
val substepCompute = RegInit(0.U(1.W))
|
val substepCompute = RegInit(0.U(1.W))
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
substepCompute := substepCompute + 1.U
|
substepCompute := substepCompute + 1.U
|
||||||
}
|
}
|
||||||
|
|
||||||
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending
|
|
||||||
// on the substep
|
|
||||||
respQueueA.ready := MuxCase(false.B,
|
|
||||||
Seq((substepDeqA === 0.U) -> halfAQueue.io.enq.ready,
|
|
||||||
(substepDeqA === 1.U) -> fullAQueue.io.enq.ready))
|
|
||||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||||
// we fully iterated a column (M-dimension).
|
// we fully iterated a column (M-dimension).
|
||||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||||
val shouldDequeueB =
|
val shouldDequeueB =
|
||||||
((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||||
(substepCompute === 1.U)
|
(substepCompute === 1.U)
|
||||||
fullB.io.deq.ready := dpuFire && shouldDequeueB
|
fullB.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
|
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
@@ -358,7 +348,8 @@ class TensorCoreDecoupled(
|
|||||||
dontTouch(shouldDequeueB)
|
dontTouch(shouldDequeueB)
|
||||||
|
|
||||||
// hold full A until two-cycle compute is done
|
// hold full A until two-cycle compute is done
|
||||||
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||||
|
// FIXME: this should be nextStepCompute (naming is inconsistent with stepCompute/substepCompute)
|
||||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||||
|
|
||||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||||
@@ -369,8 +360,8 @@ class TensorCoreDecoupled(
|
|||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
val stepMask = (1 << numTilesMBits).U
|
val stepMask = (1 << numTilesMBits).U
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
||||||
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
|
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
||||||
(fullBTag.io.deq.bits.step & stepMask)),
|
(fullBTag.io.deq.bits.step & stepMask)),
|
||||||
"A and B operands are pointing to different set/steps. " ++
|
"A and B operands are pointing to different set/steps. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
|
|||||||
Reference in New Issue
Block a user