tensor: Fix metadata of C req; fix dequeue / req gen timing

This commit is contained in:
Hansung Kim
2024-10-25 19:05:47 -07:00
parent 1a1a4a088d
commit d46a343239

View File

@@ -252,15 +252,15 @@ class TensorCoreDecoupled(
// C access from regfile
//
// since regfile is fixed-latency, respC valid should be determined at the
// request sending side.
val respCValid = RegInit(false.B)
// regfile latency is 1 cycle; don't need a deep response queue
// regfile is fast; don't need a deep response queue
val respQueueCDepth = 2
val respQueueC = Module(new Queue(
new Bundle {
val tag = new TensorMemTag
val tag = new Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
}
val data = UInt(io.respC.widthOption.get.W)
},
respQueueCDepth
@@ -276,8 +276,25 @@ class TensorCoreDecoupled(
(respQueueC.entries - 1).U,
(respQueueC.entries - 2).U)
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
// 1-cycle delay
respCValid := genReqC
// set/step state of the C accumulator value that will be latched ath the
// next cycle.
val setAccessC = RegInit(0.U(setBits.W))
val stepAccessC = RegInit(0.U(stepBits.W))
val substepAccessC = RegInit(0.U(1.W))
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
when (genReqC) {
substepAccessC := substepAccessC + 1.U
}
dontTouch(stepAccessC)
dontTouch(substepAccessC)
dontTouch(nextStepAccessC)
// give 1-cycle delay to sync valid/metadata with C regfile response
val respCValid = RegNext(genReqC)
val warpAccessCDelayed = RegNext(warpAccess)
val setAccessCDelayed = RegNext(setAccessC)
val stepAccessCDelayed = RegNext(stepAccessC)
// note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = {
@@ -286,25 +303,14 @@ class TensorCoreDecoupled(
(step << 1/*2 substeps*/) + substep
}
io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME
// set/index state of the C accumulator value that will be latched ath the
// next cycle.
val stateC = RegInit(stateInit)
when (genReqC) {
when (stateC.index === lastIndex.U) {
stateC.set := stateC.set + 1.U
}
stateC.index := stateC.index + 1.U
}
dontTouch(stateC)
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
// queue the regfile response to buffers
// these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccess
respQueueC.io.enq.bits.tag.set := stateC.set
respQueueC.io.enq.bits.tag.index := stateC.index
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile
@@ -315,7 +321,8 @@ class TensorCoreDecoupled(
fullC.io.enq.bits := respQueueC.io.deq.bits.data
respQueueC.io.deq.ready := fullC.io.enq.ready
val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
chiselTypeOf(respQueueC.io.deq.bits.tag),
entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
@@ -323,10 +330,10 @@ class TensorCoreDecoupled(
// finalize state when everything has been accessed
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
when (lastReqC && genReqC) { doneReqC := true.B }
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
@@ -335,8 +342,9 @@ class TensorCoreDecoupled(
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
stateC.set := 0.U
stateC.index := 0.U
setAccessC := 0.U
stepAccessC := 0.U
substepAccessC := 0.U
}
allReqsDone := doneReqA && doneReqB && doneReqC
@@ -432,7 +440,7 @@ class TensorCoreDecoupled(
val fullCBuf = Module(new Queue(
new Bundle {
val tag = new TensorMemTag
val tag = chiselTypeOf(fullCTag.io.deq.bits)
val data = chiselTypeOf(fullC.io.deq.bits)
}, entries = 1, pipe = true
))
@@ -496,9 +504,9 @@ class TensorCoreDecoupled(
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U)
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
// C deq should be synced with B deq
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
// C should be dequeued everytime a 4x4 tile is done computing
val shouldDequeueC = (substepCompute === 1.U)
fullCBuf.io.deq.ready := dpuFire && shouldDequeueC
dontTouch(respQueueA)
dontTouch(respQueueB)
@@ -639,6 +647,7 @@ class TensorCoreDecoupled(
}
}
}
sequenceSetStep(setAccessC, stepAccessC, nextStepAccessC)
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
}
@@ -739,7 +748,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.respC := 42.U // FIXME bogus
tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
tensor.io.initiate.bits.wid := 3.U // bogus, static value
tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last