From 7de8e86d4f04712f90c4457940c02a341b721f76 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 17 Oct 2024 15:18:47 -0700 Subject: [PATCH] tensor: Sync rd with DPU using a queue --- .../radiance/core/TensorCoreDecoupled.scala | 44 ++++++++++++------- src/main/scala/radiance/core/TensorDPU.scala | 2 +- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index b9695ad..92a6596 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -270,6 +270,8 @@ class TensorCoreDecoupled( val operandB = respQueueB.bits.data val dpuReady = Wire(Bool()) val dpuFire = operandsValid && dpuReady + val setCompute = fullAQueue.io.deq.bits.tag.set + val stepCompute = fullAQueue.io.deq.bits.tag.step val substepCompute = RegInit(0.U(1.W)) when (dpuFire) { substepCompute := substepCompute + 1.U @@ -348,9 +350,9 @@ class TensorCoreDecoupled( def assertDPU = { val dpuStalls = dpus.flatMap(_.map(_.io.stall)) assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _), - "stall signals of DPUs went unaligned") + "stall signals of DPUs went out of sync") assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _), - "valid signals of DPUs went unaligned") + "valid signals of DPUs went out of sync") } assertDPU @@ -362,17 +364,36 @@ class TensorCoreDecoupled( } io.writeback.bits.data := flattenedDPUOut - def rdGen(set: UInt, step: UInt): UInt = { + // Writeback queues + // ---------------- + // These queues hold metadata needed for writeback in sync with the DPU. + + val queueDepth = 4 // needs to be at least the DPU latency + val rdQueue = Module(new Queue( + chiselTypeOf(io.writeback.bits.rd), queueDepth + )) + rdQueue.io.enq.valid := dpuFire + rdQueue.io.enq.bits := rdGen(stepCompute, substepCompute) + rdQueue.io.deq.ready := io.writeback.fire + assert(rdQueue.io.enq.ready === true.B, + "rd queue full, throttling DPU operation") + assert(!dpuValid || rdQueue.io.deq.valid, + "rd queue and DPU went out of sync") + + // TODO: decouple wid from frontend + // val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) + + // note rd is independent to sets + def rdGen(step: UInt, substep: UInt): UInt = { // each step produces 4x4 output tile, written by 8 threads with 2 regs per // thread - require(numLanes == 8, "currently assumes 8-wide warps") - (Cat(set, step) >> 1/*2 regs/thread*/) - // FIXME: add substep here + (step << 1/*2 substeps*/) + substep } io.writeback.valid := dpuValid io.writeback.bits.wid := warpReg - io.writeback.bits.rd := rdGen(setExecute, stepExecute) + io.writeback.bits.rd := rdQueue.io.deq.bits + // FIXME: look at set/step of dpu output not setExecute io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute) // State transition @@ -410,15 +431,6 @@ class TensorCoreDecoupled( } } } - - // Writeback queues - // ---------------- - // These queues hold the metadata necessary for register - // writeback. - - // val queueDepth = 2 - // val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) - // val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) } // synthesizable unit tests diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index a82bed7..515b1bf 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -53,7 +53,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } -// Copied from chisel3.util.Pipe. +// An implementation of chisel3.util.Pipe that supports stalls. class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module { /** A non-ambiguous name of this `StallingPipe` for use in generated Verilog * names. Includes the latency cycle count in the name as well as the