From 77dae3e1f9941d15c213b19a43cd82bd0e00c81c Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 16 Oct 2024 21:21:48 -0700 Subject: [PATCH] tensor: Write staging pipeline for A tile --- .../radiance/core/TensorCoreDecoupled.scala | 103 ++++++++++++++---- src/main/scala/radiance/core/TensorDPU.scala | 1 + 2 files changed, 83 insertions(+), 21 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 92f98b7..69b84f9 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -108,8 +108,12 @@ class TensorCoreDecoupled( // set and step being currently accessed in the acc/ex frontend val setAccess = RegInit(0.U(setBits.W)) val stepAccess = RegInit(0.U(stepBits.W)) + // we need full 4x4 A tile to fire DPU, but since the memory width is 8 + // words, we need 2 cycles to read A. `substep` tells which cycle we're at. + val substepAccess = RegInit(0.U(1.W)) dontTouch(setAccess) dontTouch(stepAccess) + dontTouch(substepAccess) when(io.initiate.fire) { val wid = io.initiate.bits.wid @@ -139,16 +143,19 @@ class TensorCoreDecoupled( class TensorMemTag extends Bundle { val set = UInt(setBits.W) val step = UInt(stepBits.W) + val substep = UInt(1.W) } // use concatenation of set/step as the memory request source. This will get // translated to the actual TL sourcewidth in sourceGen. val tag = Wire(new TensorMemTag) tag.set := setAccess tag.step := stepAccess + tag.substep := substepAccess val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) - Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach { + Seq((io.reqA, (io.respA, respATagged)), + (io.reqB, (io.respB, respBTagged))).foreach { case (req, (resp, respTagged)) => { val sourceGen = Module(new SourceGenerator( log2Ceil(numSourceIds), @@ -173,18 +180,22 @@ class TensorCoreDecoupled( } // only advance to the next step if we fired mem requests for both A and B + // TODO: @perf: too strict? should be able to have A and B progress + // separately val firedABReg = RegInit(VecInit(false.B, false.B)) val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map { case (req, fired) => { when (req.fire) { fired := true.B } } req.fire }) val firedAB = (firedABNow.asUInt | firedABReg.asUInt) - val nextStepAccess = firedAB.andR - // clear out firedABReg every step. this will overwrite the previous fired - // write upon the last fire out of A and B - when (nextStepAccess) { + val nextSubstepAccess = firedAB.andR + val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) + // clear out firedABReg every substep + when (nextSubstepAccess) { firedABReg := Seq(false.B, false.B) + substepAccess := substepAccess + 1.U } + require(substepAccess.widthOption.get == 1, "there should be only two substeps") // Execute stage // ------------- @@ -204,22 +215,72 @@ class TensorCoreDecoupled( io.writeback.bits.data.widthOption.get, "response data width does not match the writeback data width") - val bothQueueValid = (respQueueA.valid && respQueueB.valid) - // assume in-order response and that A/B responses are always aligned; this - // might be too strong an assumption depending on the backing memory - when (bothQueueValid) { - assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) && - (respQueueA.bits.tag.step === respQueueB.bits.tag.step), - "A and B response queue pointing to different set/steps. " ++ - "This might indicate memory response coming back out-of-order.") - } - // dequeue is synchronized between A and B // FIXME: this need to change to dpu_ready - val deqResp = bothQueueValid && io.writeback.ready - respQueueA.ready := deqResp - respQueueB.ready := deqResp - // FIXME: this need to change to dpu_fire - val nextStepExecute = io.writeback.fire + val dpuReady = io.writeback.ready // FIXME: this need be actual dpu + + val substepExecute = RegInit(0.U(1.W)) + when (respQueueA.fire) { + substepExecute := substepExecute + 1.U + } + dontTouch(substepExecute) + + // note combinationally coupled ready with `pipe` + val halfAQueue = Module(new Queue( + chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true + )) + halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U) + halfAQueue.io.enq.bits := respQueueA.bits.data + + // we need the full data for A because we divide the D tile by half along N; + // for B, the DPU can immediately start computing with a 4x2 tile. + // + // substep == 0 data goes to the LSB + val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits) + val fullAQueue = Module(new Queue( + chiselTypeOf(fullAEnqData), entries = 1, pipe = true + )) + // hold first half A data for the first substep + halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) && + fullAQueue.io.enq.ready + + require(fullAEnqData.widthOption.get == dataWidth * 2, + "assumes 2-cycle read for a full compute tile of A") + fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) && + halfAQueue.io.deq.valid + fullAQueue.io.enq.bits := fullAEnqData + + val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME? + val dpuFire = operandsValid && dpuReady + fullAQueue.io.deq.ready := dpuFire + val nextStepExecute = dpuFire + + // FIXME: need to hold A for two cycles!! + + // make sure to dequeue from response queues only when both A and B valid + respQueueA.ready := MuxCase(false.B, + Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready, + (substepExecute === 1.U) -> fullAQueue.io.enq.ready)) + respQueueB.ready := dpuFire + dontTouch(respQueueA) + dontTouch(respQueueB) + + // assert that the A and B response queue heads always point to the same + // set/step/substep + // + // this assumes that memory responses come back in-order. this might be too + // strong an assumption depending on the backing memory + def assertAligned = { + val bothQueueValid = (respQueueA.valid && respQueueB.valid) + when (bothQueueValid && (substepExecute === 0.U)) { + assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) && + (respQueueA.bits.tag.step === respQueueB.bits.tag.step), + "A and B response queue pointing to different set/steps. " ++ + "This might indicate memory response coming back out-of-order.") + } + dontTouch(respQueueA.bits.tag) + dontTouch(respQueueB.bits.tag) + } + assertAligned def rdGen(set: UInt, step: UInt): UInt = { // each step produces 4x4 output tile, written by 8 threads with 2 regs per @@ -229,7 +290,7 @@ class TensorCoreDecoupled( // FIXME: add substep here } - io.writeback.valid := bothQueueValid + io.writeback.valid := operandsValid // FIXME: bypass logic io.writeback.bits.wid := warpReg io.writeback.bits.rd := rdGen(setExecute, stepExecute) io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute) diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 4e6cee7..a82bed7 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -27,6 +27,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar val b = Vec(dotProductDim, Bits((inFLen).W)) val c = Bits((outFLen).W) // note C has the out length for accumulation })) + // 'stall' is effectively out.ready, combinationally coupled to in.ready val stall = Input(Bool()) val out = Valid(new Bundle { val data = Bits((outFLen).W)