tensor: Fix metadata of C req; fix dequeue / req gen timing

This commit is contained in:
Hansung Kim
2024-10-25 19:05:47 -07:00
parent 1a1a4a088d
commit d46a343239

View File

@@ -252,15 +252,15 @@ class TensorCoreDecoupled(
// C access from regfile // C access from regfile
// //
// since regfile is fixed-latency, respC valid should be determined at the // regfile is fast; don't need a deep response queue
// request sending side.
val respCValid = RegInit(false.B)
// regfile latency is 1 cycle; don't need a deep response queue
val respQueueCDepth = 2 val respQueueCDepth = 2
val respQueueC = Module(new Queue( val respQueueC = Module(new Queue(
new Bundle { new Bundle {
val tag = new TensorMemTag val tag = new Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
}
val data = UInt(io.respC.widthOption.get.W) val data = UInt(io.respC.widthOption.get.W)
}, },
respQueueCDepth respQueueCDepth
@@ -276,8 +276,25 @@ class TensorCoreDecoupled(
(respQueueC.entries - 1).U, (respQueueC.entries - 1).U,
(respQueueC.entries - 2).U) (respQueueC.entries - 2).U)
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
// 1-cycle delay
respCValid := genReqC // set/step state of the C accumulator value that will be latched ath the
// next cycle.
val setAccessC = RegInit(0.U(setBits.W))
val stepAccessC = RegInit(0.U(stepBits.W))
val substepAccessC = RegInit(0.U(1.W))
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
when (genReqC) {
substepAccessC := substepAccessC + 1.U
}
dontTouch(stepAccessC)
dontTouch(substepAccessC)
dontTouch(nextStepAccessC)
// give 1-cycle delay to sync valid/metadata with C regfile response
val respCValid = RegNext(genReqC)
val warpAccessCDelayed = RegNext(warpAccess)
val setAccessCDelayed = RegNext(setAccessC)
val stepAccessCDelayed = RegNext(stepAccessC)
// note rd is independent to sets // note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = { def rdGen(step: UInt, substep: UInt): UInt = {
@@ -286,25 +303,14 @@ class TensorCoreDecoupled(
(step << 1/*2 substeps*/) + substep (step << 1/*2 substeps*/) + substep
} }
io.reqC.valid := genReqC io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME io.reqC.bits := rdGen(stepAccessC, substepAccessC)
// set/index state of the C accumulator value that will be latched ath the
// next cycle.
val stateC = RegInit(stateInit)
when (genReqC) {
when (stateC.index === lastIndex.U) {
stateC.set := stateC.set + 1.U
}
stateC.index := stateC.index + 1.U
}
dontTouch(stateC)
// queue the regfile response to buffers // queue the regfile response to buffers
// these strictly belong to the execute stage // these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccess respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
respQueueC.io.enq.bits.tag.set := stateC.set respQueueC.io.enq.bits.tag.set := setAccessCDelayed
respQueueC.io.enq.bits.tag.index := stateC.index respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
respQueueC.io.enq.bits.data := io.respC respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile // serialize every two C responses into one full 4x4 C tile
@@ -315,7 +321,8 @@ class TensorCoreDecoupled(
fullC.io.enq.bits := respQueueC.io.deq.bits.data fullC.io.enq.bits := respQueueC.io.deq.bits.data
respQueueC.io.deq.ready := fullC.io.enq.ready respQueueC.io.deq.ready := fullC.io.enq.ready
val fullCTag = Module(new Queue( val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true chiselTypeOf(respQueueC.io.deq.bits.tag),
entries = 1, pipe = true
)) ))
fullCTag.io.enq.valid := respQueueC.io.deq.valid fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
@@ -323,10 +330,10 @@ class TensorCoreDecoupled(
// finalize state when everything has been accessed // finalize state when everything has been accessed
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U) val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
when (lastReqA && io.reqA.fire) { doneReqA := true.B } when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B } when (lastReqB && io.reqB.fire) { doneReqB := true.B }
when (lastReqC && genReqC) { doneReqC := true.B } when (lastReqC && nextStepAccessC) { doneReqC := true.B }
when (state === AccessorState.finish) { when (state === AccessorState.finish) {
doneReqA := false.B doneReqA := false.B
doneReqB := false.B doneReqB := false.B
@@ -335,8 +342,9 @@ class TensorCoreDecoupled(
stateA.index := 0.U stateA.index := 0.U
stateB.set := 0.U stateB.set := 0.U
stateB.index := 0.U stateB.index := 0.U
stateC.set := 0.U setAccessC := 0.U
stateC.index := 0.U stepAccessC := 0.U
substepAccessC := 0.U
} }
allReqsDone := doneReqA && doneReqB && doneReqC allReqsDone := doneReqA && doneReqB && doneReqC
@@ -432,7 +440,7 @@ class TensorCoreDecoupled(
val fullCBuf = Module(new Queue( val fullCBuf = Module(new Queue(
new Bundle { new Bundle {
val tag = new TensorMemTag val tag = chiselTypeOf(fullCTag.io.deq.bits)
val data = chiselTypeOf(fullC.io.deq.bits) val data = chiselTypeOf(fullC.io.deq.bits)
}, entries = 1, pipe = true }, entries = 1, pipe = true
)) ))
@@ -496,9 +504,9 @@ class TensorCoreDecoupled(
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U) (substepCompute === 1.U)
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
// C should be dequeued everytime a 4x4 tile is done computing
// C deq should be synced with B deq val shouldDequeueC = (substepCompute === 1.U)
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB fullCBuf.io.deq.ready := dpuFire && shouldDequeueC
dontTouch(respQueueA) dontTouch(respQueueA)
dontTouch(respQueueB) dontTouch(respQueueB)
@@ -639,6 +647,7 @@ class TensorCoreDecoupled(
} }
} }
} }
sequenceSetStep(setAccessC, stepAccessC, nextStepAccessC)
sequenceSetStep(setCompute, stepCompute, nextStepCompute) sequenceSetStep(setCompute, stepCompute, nextStepCompute)
} }
@@ -739,7 +748,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.respC := 42.U // FIXME bogus tensor.io.respC := 42.U // FIXME bogus
tensor.io.initiate.valid := io.start tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U // FIXME bogus tensor.io.initiate.bits.wid := 3.U // bogus, static value
tensor.io.writeback.ready := true.B tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last