From 6cad8edd1838642cbbb61ef6998c8318d96864e1 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 16 Oct 2024 22:01:02 -0700 Subject: [PATCH] tensor: Fix operand alignment in pipelining --- .../radiance/core/TensorCoreDecoupled.scala | 56 +++++++++++-------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 69b84f9..0654df3 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -224,37 +224,51 @@ class TensorCoreDecoupled( } dontTouch(substepExecute) + // Do pipelining for the A operand so that we obtain the full 4x4 A tile + // ready for compute. The pipeline is two-stage: + // - stage one (halfAQueue) for assembling the full A tile from half-tiles + // coming from the resp queue, and + // - stage two (fullAQueue) for holding the full A tile until it gets + // matched with two 4x2 B tiles, and compute is complete. + // + // Note that the half-tile assembly is unnecessary for B since the B tile is + // only 4x2. + // Also send the set/step tag along the pipe for alignment check. + // note combinationally coupled ready with `pipe` val halfAQueue = Module(new Queue( - chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true + chiselTypeOf(respQueueA.bits), entries = 1, pipe = true )) halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U) - halfAQueue.io.enq.bits := respQueueA.bits.data + halfAQueue.io.enq.bits := respQueueA.bits - // we need the full data for A because we divide the D tile by half along N; - // for B, the DPU can immediately start computing with a 4x2 tile. - // // substep == 0 data goes to the LSB - val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits) + val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data) + require(fullAEnqData.widthOption.get == dataWidth * 2, + "assumes 2-cycle read for a full compute tile of A") + // only use the lower halfA's tag. substep will be incorrect. + val fullAEnqTag = halfAQueue.io.deq.bits.tag val fullAQueue = Module(new Queue( - chiselTypeOf(fullAEnqData), entries = 1, pipe = true + new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true )) // hold first half A data for the first substep halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) && fullAQueue.io.enq.ready - - require(fullAEnqData.widthOption.get == dataWidth * 2, - "assumes 2-cycle read for a full compute tile of A") fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) && halfAQueue.io.deq.valid - fullAQueue.io.enq.bits := fullAEnqData + fullAQueue.io.enq.bits.data := fullAEnqData + fullAQueue.io.enq.bits.tag := fullAEnqTag val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME? val dpuFire = operandsValid && dpuReady - fullAQueue.io.deq.ready := dpuFire - val nextStepExecute = dpuFire + val substepCompute = RegInit(0.U(1.W)) + when (dpuFire) { + substepCompute := substepCompute + 1.U + } - // FIXME: need to hold A for two cycles!! + // hold full A until two-cycle compute is done + fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U) + val nextStepExecute = dpuFire && (substepCompute === 1.U) // make sure to dequeue from response queues only when both A and B valid respQueueA.ready := MuxCase(false.B, @@ -264,21 +278,17 @@ class TensorCoreDecoupled( dontTouch(respQueueA) dontTouch(respQueueB) - // assert that the A and B response queue heads always point to the same - // set/step/substep + // assert that the DPU is computing with operands of the same set/step // // this assumes that memory responses come back in-order. this might be too // strong an assumption depending on the backing memory def assertAligned = { - val bothQueueValid = (respQueueA.valid && respQueueB.valid) - when (bothQueueValid && (substepExecute === 0.U)) { - assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) && - (respQueueA.bits.tag.step === respQueueB.bits.tag.step), - "A and B response queue pointing to different set/steps. " ++ + when (dpuFire) { + assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) && + (fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step), + "A and B operands are pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.") } - dontTouch(respQueueA.bits.tag) - dontTouch(respQueueB.bits.tag) } assertAligned