diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 4f5ecb3..7c07564 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -97,9 +97,16 @@ class TensorCoreDecoupled( // steps: i-j iteration val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc) val stepBits = log2Ceil(numSteps) + val lastSet = ((1 << setBits) - 1) + val lastStep = ((1 << stepBits) - 1) + def setDone(set: UInt) = (set === lastSet.U) + def stepDone(step: UInt) = (step === lastStep.U) + // set and step being currently accessed in the acc/ex frontend val setAccess = RegInit(0.U(setBits.W)) val stepAccess = RegInit(0.U(stepBits.W)) + dontTouch(setAccess) + dontTouch(stepAccess) when(io.initiate.fire) { val wid = io.initiate.bits.wid @@ -118,6 +125,9 @@ class TensorCoreDecoupled( busy := false.B } + // serialize every HGMMA request + io.initiate.ready := !busy + // Memory traffic generation // ------------------------- // @@ -166,10 +176,10 @@ class TensorCoreDecoupled( req.fire }) val firedAB = (firedABNow.asUInt | firedABReg.asUInt) - val nextStep = firedAB.andR + val nextStepAccess = firedAB.andR // clear out firedABReg every step. this will overwrite the previous fired // write upon the last fire out of A and B - when (nextStep) { + when (nextStepAccess) { firedABReg := Seq(false.B, false.B) } @@ -180,6 +190,8 @@ class TensorCoreDecoupled( // set and step being currently executed in the acc/ex backend val setExecute = RegInit(0.U(setBits.W)) val stepExecute = RegInit(0.U(stepBits.W)) + dontTouch(setExecute) + dontTouch(stepExecute) val respQueueDepth = 4 // FIXME: parameterize val respQueueA = Queue(respATagged, respQueueDepth) @@ -198,13 +210,19 @@ class TensorCoreDecoupled( "A and B response queue pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.") } - // synchronized dequeue + // dequeue is synchronized between A and B + // FIXME: this need to change to dpu_ready val deqResp = bothQueueValid && io.writeback.ready respQueueA.ready := deqResp respQueueB.ready := deqResp + // FIXME: this need to change to dpu_fire + val nextStepExecute = io.writeback.fire + + io.writeback.valid := bothQueueValid + io.writeback.bits.wid := warpReg + io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute) // FIXME: debug dummy: pipe A directly to writeback - io.writeback.valid := respQueueA.valid val groupedRespA = respQueueA.bits.data .asBools.grouped(wordSize * 8/*bits*/) .map(VecInit(_).asUInt) @@ -216,16 +234,17 @@ class TensorCoreDecoupled( // ---------------- // // set/step sequencing logic - val lastSet = ((1 << setBits) - 1) - val lastStep = ((1 << stepBits) - 1) - val setDone = (setAccess === lastSet.U) - val stepDone = (stepAccess === lastStep.U) - when (nextStep) { - stepAccess := (stepAccess + 1.U) & lastStep.U - when (stepDone) { - setAccess := (setAccess + 1.U) & lastSet.U + + def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = { + when (nextStep) { + step := (step + 1.U) & lastStep.U + when (stepDone(step)) { + set := (set + 1.U) & lastSet.U + } } } + sequenceSetStep(setAccess, stepAccess, nextStepAccess) + sequenceSetStep(setExecute, stepExecute, nextStepExecute) switch(state) { is(TensorState.idle) { @@ -234,7 +253,7 @@ class TensorCoreDecoupled( } } is(TensorState.run) { - when (setDone && stepDone && nextStep) { + when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) { when (state === TensorState.run) { state := TensorState.finish } @@ -247,11 +266,6 @@ class TensorCoreDecoupled( } } - io.initiate.ready := !busy - io.writeback.valid := (state === TensorState.finish) - io.writeback.bits.wid := warpReg - io.writeback.bits.last := false.B // TODO - // Writeback queues // ---------------- // These queues hold the metadata necessary for register @@ -328,7 +342,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) tensor.io.initiate.bits.wid := 0.U // FIXME tensor.io.writeback.ready := true.B - io.finished := tensor.io.writeback.valid + io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last } // a minimal Diplomacy graph with a tensor core and a TLRAM