tensor: Fix operand alignment in pipelining
This commit is contained in:
@@ -224,37 +224,51 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
dontTouch(substepExecute)
|
dontTouch(substepExecute)
|
||||||
|
|
||||||
|
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
|
||||||
|
// ready for compute. The pipeline is two-stage:
|
||||||
|
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
|
||||||
|
// coming from the resp queue, and
|
||||||
|
// - stage two (fullAQueue) for holding the full A tile until it gets
|
||||||
|
// matched with two 4x2 B tiles, and compute is complete.
|
||||||
|
//
|
||||||
|
// Note that the half-tile assembly is unnecessary for B since the B tile is
|
||||||
|
// only 4x2.
|
||||||
|
// Also send the set/step tag along the pipe for alignment check.
|
||||||
|
|
||||||
// note: with `pipe = true`, enq.ready is combinationally coupled to deq.ready
|
// note: with `pipe = true`, enq.ready is combinationally coupled to deq.ready
|
||||||
val halfAQueue = Module(new Queue(
|
val halfAQueue = Module(new Queue(
|
||||||
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
|
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
|
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
|
||||||
halfAQueue.io.enq.bits := respQueueA.bits.data
|
halfAQueue.io.enq.bits := respQueueA.bits
|
||||||
|
|
||||||
// we need the full data for A because we divide the D tile by half along N;
|
|
||||||
// for B, the DPU can immediately start computing with a 4x2 tile.
|
|
||||||
//
|
|
||||||
// substep == 0 data goes to the LSB
|
// substep == 0 data goes to the LSB
|
||||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
|
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
|
||||||
|
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||||
|
"assumes 2-cycle read for a full compute tile of A")
|
||||||
|
// only the lower half-A's tag (captured at substep 0) is used, so the substep field of the full-tile tag will be incorrect.
|
||||||
|
val fullAEnqTag = halfAQueue.io.deq.bits.tag
|
||||||
val fullAQueue = Module(new Queue(
|
val fullAQueue = Module(new Queue(
|
||||||
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
|
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
// hold first half A data for the first substep
|
// hold first half A data for the first substep
|
||||||
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
|
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
|
||||||
fullAQueue.io.enq.ready
|
fullAQueue.io.enq.ready
|
||||||
|
|
||||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
|
||||||
"assumes 2-cycle read for a full compute tile of A")
|
|
||||||
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
|
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
|
||||||
halfAQueue.io.deq.valid
|
halfAQueue.io.deq.valid
|
||||||
fullAQueue.io.enq.bits := fullAEnqData
|
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||||
|
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||||
|
|
||||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
fullAQueue.io.deq.ready := dpuFire
|
val substepCompute = RegInit(0.U(1.W))
|
||||||
val nextStepExecute = dpuFire
|
when (dpuFire) {
|
||||||
|
substepCompute := substepCompute + 1.U
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME: need to hold A for two cycles!!
|
// hold full A until two-cycle compute is done
|
||||||
|
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||||
|
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||||
|
|
||||||
// make sure to dequeue from response queues only when both A and B valid
|
// make sure to dequeue from response queues only when both A and B valid
|
||||||
respQueueA.ready := MuxCase(false.B,
|
respQueueA.ready := MuxCase(false.B,
|
||||||
@@ -264,21 +278,17 @@ class TensorCoreDecoupled(
|
|||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
|
|
||||||
// assert that the A and B response queue heads always point to the same
|
// assert that the DPU is computing with operands of the same set/step
|
||||||
// set/step/substep
|
|
||||||
//
|
//
|
||||||
// this assumes that memory responses come back in-order. this might be too
|
// this assumes that memory responses come back in-order. this might be too
|
||||||
// strong an assumption depending on the backing memory
|
// strong an assumption depending on the backing memory
|
||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
when (dpuFire) {
|
||||||
when (bothQueueValid && (substepExecute === 0.U)) {
|
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||||
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
|
||||||
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
"A and B operands are pointing to different set/steps. " ++
|
||||||
"A and B response queue pointing to different set/steps. " ++
|
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
}
|
}
|
||||||
dontTouch(respQueueA.bits.tag)
|
|
||||||
dontTouch(respQueueB.bits.tag)
|
|
||||||
}
|
}
|
||||||
assertAligned
|
assertAligned
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user