diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index d91f49f..398740e 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -252,15 +252,15 @@ class TensorCoreDecoupled( // C access from regfile // - // since regfile is fixed-latency, respC valid should be determined at the - // request sending side. - val respCValid = RegInit(false.B) - - // regfile latency is 1 cycle; don't need a deep response queue + // regfile is fast; don't need a deep response queue val respQueueCDepth = 2 val respQueueC = Module(new Queue( new Bundle { - val tag = new TensorMemTag + val tag = new Bundle { + val warp = UInt(numWarpBits.W) + val set = UInt(setBits.W) + val step = UInt(stepBits.W) + } val data = UInt(io.respC.widthOption.get.W) }, respQueueCDepth @@ -276,8 +276,25 @@ class TensorCoreDecoupled( (respQueueC.entries - 1).U, (respQueueC.entries - 2).U) val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC - // 1-cycle delay - respCValid := genReqC + + // set/step state of the C accumulator value that will be latched ath the + // next cycle. + val setAccessC = RegInit(0.U(setBits.W)) + val stepAccessC = RegInit(0.U(stepBits.W)) + val substepAccessC = RegInit(0.U(1.W)) + val nextStepAccessC = genReqC && (substepAccessC === 1.U) + when (genReqC) { + substepAccessC := substepAccessC + 1.U + } + dontTouch(stepAccessC) + dontTouch(substepAccessC) + dontTouch(nextStepAccessC) + + // give 1-cycle delay to sync valid/metadata with C regfile response + val respCValid = RegNext(genReqC) + val warpAccessCDelayed = RegNext(warpAccess) + val setAccessCDelayed = RegNext(setAccessC) + val stepAccessCDelayed = RegNext(stepAccessC) // note rd is independent to sets def rdGen(step: UInt, substep: UInt): UInt = { @@ -286,25 +303,14 @@ class TensorCoreDecoupled( (step << 1/*2 substeps*/) + substep } io.reqC.valid := genReqC - io.reqC.bits := 5.U // FIXME - - // set/index state of the C accumulator value that will be latched ath the - // next cycle. - val stateC = RegInit(stateInit) - when (genReqC) { - when (stateC.index === lastIndex.U) { - stateC.set := stateC.set + 1.U - } - stateC.index := stateC.index + 1.U - } - dontTouch(stateC) + io.reqC.bits := rdGen(stepAccessC, substepAccessC) // queue the regfile response to buffers // these strictly belong to the execute stage respQueueC.io.enq.valid := respCValid - respQueueC.io.enq.bits.tag.warp := warpAccess - respQueueC.io.enq.bits.tag.set := stateC.set - respQueueC.io.enq.bits.tag.index := stateC.index + respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed + respQueueC.io.enq.bits.tag.set := setAccessCDelayed + respQueueC.io.enq.bits.tag.step := stepAccessCDelayed respQueueC.io.enq.bits.data := io.respC // serialize every two C responses into one full 4x4 C tile @@ -315,7 +321,8 @@ class TensorCoreDecoupled( fullC.io.enq.bits := respQueueC.io.deq.bits.data respQueueC.io.deq.ready := fullC.io.enq.ready val fullCTag = Module(new Queue( - new TensorMemTag, entries = 1, pipe = true + chiselTypeOf(respQueueC.io.deq.bits.tag), + entries = 1, pipe = true )) fullCTag.io.enq.valid := respQueueC.io.deq.valid fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag @@ -323,10 +330,10 @@ class TensorCoreDecoupled( // finalize state when everything has been accessed val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) - val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U) + val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U) when (lastReqA && io.reqA.fire) { doneReqA := true.B } when (lastReqB && io.reqB.fire) { doneReqB := true.B } - when (lastReqC && genReqC) { doneReqC := true.B } + when (lastReqC && nextStepAccessC) { doneReqC := true.B } when (state === AccessorState.finish) { doneReqA := false.B doneReqB := false.B @@ -335,8 +342,9 @@ class TensorCoreDecoupled( stateA.index := 0.U stateB.set := 0.U stateB.index := 0.U - stateC.set := 0.U - stateC.index := 0.U + setAccessC := 0.U + stepAccessC := 0.U + substepAccessC := 0.U } allReqsDone := doneReqA && doneReqB && doneReqC @@ -432,7 +440,7 @@ class TensorCoreDecoupled( val fullCBuf = Module(new Queue( new Bundle { - val tag = new TensorMemTag + val tag = chiselTypeOf(fullCTag.io.deq.bits) val data = chiselTypeOf(fullC.io.deq.bits) }, entries = 1, pipe = true )) @@ -496,9 +504,9 @@ class TensorCoreDecoupled( ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && (substepCompute === 1.U) fullBBuf.io.deq.ready := dpuFire && shouldDequeueB - - // C deq should be synced with B deq - fullCBuf.io.deq.ready := dpuFire && shouldDequeueB + // C should be dequeued everytime a 4x4 tile is done computing + val shouldDequeueC = (substepCompute === 1.U) + fullCBuf.io.deq.ready := dpuFire && shouldDequeueC dontTouch(respQueueA) dontTouch(respQueueB) @@ -639,6 +647,7 @@ class TensorCoreDecoupled( } } } + sequenceSetStep(setAccessC, stepAccessC, nextStepAccessC) sequenceSetStep(setCompute, stepCompute, nextStepCompute) } @@ -739,7 +748,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) tensor.io.respC := 42.U // FIXME bogus tensor.io.initiate.valid := io.start - tensor.io.initiate.bits.wid := 0.U // FIXME bogus + tensor.io.initiate.bits.wid := 3.U // bogus, static value tensor.io.writeback.ready := true.B io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last