tensor: Replace staging logic for A with FillBuffer

This commit is contained in:
Hansung Kim
2024-10-18 19:54:20 -07:00
parent 7fab6f89ad
commit c4b5a11fde

View File

@@ -272,46 +272,41 @@ class TensorCoreDecoupled(
io.writeback.bits.data.widthOption.get, io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width") "response data width does not match the writeback data width")
// FIXME: unnecessary
val substepDeqA = RegInit(0.U(1.W)) val substepDeqA = RegInit(0.U(1.W))
when (respQueueA.fire) { when (respQueueA.fire) {
substepDeqA := substepDeqA + 1.U substepDeqA := substepDeqA + 1.U
} }
dontTouch(substepDeqA) dontTouch(substepDeqA)
// Do pipelining for the A operand so that we obtain the full 4x4 A tile // Stage the operands in a pipeline so that we obtain the full 4x4 tiles
// ready for compute. The pipeline is two-stage: // ready for compute. Also send the set/step tag along the pipe for
// - stage one (halfAQueue) for assembling the full A tile from half-tiles // alignment check.
// coming from the resp queue, and
// - stage two (fullAQueue) for holding the full A tile until it gets
// matched with two 4x2 B tiles, and compute is complete.
//
// Note that the half-tile assembly is unnecessary for B since the B tile is
// only 4x2.
// Also send the set/step tag along the pipe for alignment check.
// note combinationally coupled ready with `pipe` val fullA = Module(new FillBuffer(
val halfAQueue = Module(new Queue( chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
)) ))
halfAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 0.U) fullA.io.enq.valid := respQueueA.valid
halfAQueue.io.enq.bits := respQueueA.bits fullA.io.enq.bits := respQueueA.bits.data
respQueueA.ready := fullA.io.enq.ready
// `pipe` combinationally couples enq-deq ready
val fullATag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullATag.io.enq.valid := respQueueA.valid
fullATag.io.enq.bits := respQueueA.bits.tag
// substep == 0 data goes to the LSB // stage the full A tile once more so that FillBuffer can be filled up in the
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data) // background while the tile is being used for compute. This does come with
require(fullAEnqData.widthOption.get == dataWidth * 2, // capacity overhead.
"assumes 2-cycle read for a full compute tile of A") val fullABuf = Module(new Queue(
// only use the lower halfA's tag. substep will be incorrect.
val fullAEnqTag = halfAQueue.io.deq.bits.tag
val fullAQueue = Module(new Queue(
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
)) ))
// hold first half A data for the first substep fullABuf.io.enq.valid := fullA.io.deq.valid
halfAQueue.io.deq.ready := respQueueA.valid && (substepDeqA === 1.U) && fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
fullAQueue.io.enq.ready fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
fullAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 1.U) && fullA.io.deq.ready := fullABuf.io.enq.ready
halfAQueue.io.deq.valid fullATag.io.deq.ready := fullABuf.io.enq.ready
fullAQueue.io.enq.bits.data := fullAEnqData
fullAQueue.io.enq.bits.tag := fullAEnqTag
// serialize every two B responses into one full 4x4 B tile // serialize every two B responses into one full 4x4 B tile
// FIXME: do the same for A // FIXME: do the same for A
@@ -327,29 +322,24 @@ class TensorCoreDecoupled(
fullBTag.io.enq.valid := respQueueB.valid fullBTag.io.enq.valid := respQueueB.valid
fullBTag.io.enq.bits := respQueueB.bits.tag fullBTag.io.enq.bits := respQueueB.bits.tag
val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
val operandA = fullAQueue.io.deq.bits.data val operandA = fullABuf.io.deq.bits.data
val operandATag = fullAQueue.io.deq.bits.tag val operandATag = fullABuf.io.deq.bits.tag
val operandB = fullB.io.deq.bits val operandB = fullB.io.deq.bits
val dpuReady = Wire(Bool()) val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady val dpuFire = operandsValid && dpuReady
val setCompute = fullAQueue.io.deq.bits.tag.set val setCompute = fullABuf.io.deq.bits.tag.set
val stepCompute = fullAQueue.io.deq.bits.tag.step val stepCompute = fullABuf.io.deq.bits.tag.step
val substepCompute = RegInit(0.U(1.W)) val substepCompute = RegInit(0.U(1.W))
when (dpuFire) { when (dpuFire) {
substepCompute := substepCompute + 1.U substepCompute := substepCompute + 1.U
} }
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending
// on the substep
respQueueA.ready := MuxCase(false.B,
Seq((substepDeqA === 0.U) -> halfAQueue.io.enq.ready,
(substepDeqA === 1.U) -> fullAQueue.io.enq.ready))
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension). // we fully iterated a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB = val shouldDequeueB =
((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) && ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U) (substepCompute === 1.U)
fullB.io.deq.ready := dpuFire && shouldDequeueB fullB.io.deq.ready := dpuFire && shouldDequeueB
fullBTag.io.deq.ready := dpuFire && shouldDequeueB fullBTag.io.deq.ready := dpuFire && shouldDequeueB
@@ -358,7 +348,8 @@ class TensorCoreDecoupled(
dontTouch(shouldDequeueB) dontTouch(shouldDequeueB)
// hold full A until two-cycle compute is done // hold full A until two-cycle compute is done
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U) fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
// FIXME: this should be nextStepCompute
val nextStepExecute = dpuFire && (substepCompute === 1.U) val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Assert that the DPU is computing with operands of the same set/step. Note // Assert that the DPU is computing with operands of the same set/step. Note
@@ -369,8 +360,8 @@ class TensorCoreDecoupled(
def assertAligned = { def assertAligned = {
val stepMask = (1 << numTilesMBits).U val stepMask = (1 << numTilesMBits).U
when (dpuFire) { when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) && assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
((fullAQueue.io.deq.bits.tag.step & stepMask) === ((fullABuf.io.deq.bits.tag.step & stepMask) ===
(fullBTag.io.deq.bits.step & stepMask)), (fullBTag.io.deq.bits.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++ "A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")