tensor: Fix metadata of C req; fix dequeue / req gen timing
This commit is contained in:
@@ -252,15 +252,15 @@ class TensorCoreDecoupled(
|
||||
// C access from regfile
|
||||
//
|
||||
|
||||
// since regfile is fixed-latency, respC valid should be determined at the
|
||||
// request sending side.
|
||||
val respCValid = RegInit(false.B)
|
||||
|
||||
// regfile latency is 1 cycle; don't need a deep response queue
|
||||
// regfile is fast; don't need a deep response queue
|
||||
val respQueueCDepth = 2
|
||||
val respQueueC = Module(new Queue(
|
||||
new Bundle {
|
||||
val tag = new TensorMemTag
|
||||
val tag = new Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
}
|
||||
val data = UInt(io.respC.widthOption.get.W)
|
||||
},
|
||||
respQueueCDepth
|
||||
@@ -276,8 +276,25 @@ class TensorCoreDecoupled(
|
||||
(respQueueC.entries - 1).U,
|
||||
(respQueueC.entries - 2).U)
|
||||
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
||||
// 1-cycle delay
|
||||
respCValid := genReqC
|
||||
|
||||
// set/step state of the C accumulator value that will be latched ath the
|
||||
// next cycle.
|
||||
val setAccessC = RegInit(0.U(setBits.W))
|
||||
val stepAccessC = RegInit(0.U(stepBits.W))
|
||||
val substepAccessC = RegInit(0.U(1.W))
|
||||
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
||||
when (genReqC) {
|
||||
substepAccessC := substepAccessC + 1.U
|
||||
}
|
||||
dontTouch(stepAccessC)
|
||||
dontTouch(substepAccessC)
|
||||
dontTouch(nextStepAccessC)
|
||||
|
||||
// give 1-cycle delay to sync valid/metadata with C regfile response
|
||||
val respCValid = RegNext(genReqC)
|
||||
val warpAccessCDelayed = RegNext(warpAccess)
|
||||
val setAccessCDelayed = RegNext(setAccessC)
|
||||
val stepAccessCDelayed = RegNext(stepAccessC)
|
||||
|
||||
// note rd is independent to sets
|
||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||
@@ -286,25 +303,14 @@ class TensorCoreDecoupled(
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
io.reqC.valid := genReqC
|
||||
io.reqC.bits := 5.U // FIXME
|
||||
|
||||
// set/index state of the C accumulator value that will be latched ath the
|
||||
// next cycle.
|
||||
val stateC = RegInit(stateInit)
|
||||
when (genReqC) {
|
||||
when (stateC.index === lastIndex.U) {
|
||||
stateC.set := stateC.set + 1.U
|
||||
}
|
||||
stateC.index := stateC.index + 1.U
|
||||
}
|
||||
dontTouch(stateC)
|
||||
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
||||
|
||||
// queue the regfile response to buffers
|
||||
// these strictly belong to the execute stage
|
||||
respQueueC.io.enq.valid := respCValid
|
||||
respQueueC.io.enq.bits.tag.warp := warpAccess
|
||||
respQueueC.io.enq.bits.tag.set := stateC.set
|
||||
respQueueC.io.enq.bits.tag.index := stateC.index
|
||||
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
||||
respQueueC.io.enq.bits.data := io.respC
|
||||
|
||||
// serialize every two C responses into one full 4x4 C tile
|
||||
@@ -315,7 +321,8 @@ class TensorCoreDecoupled(
|
||||
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||
val fullCTag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
||||
entries = 1, pipe = true
|
||||
))
|
||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||
@@ -323,10 +330,10 @@ class TensorCoreDecoupled(
|
||||
// finalize state when everything has been accessed
|
||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
|
||||
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
when (lastReqC && genReqC) { doneReqC := true.B }
|
||||
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
@@ -335,8 +342,9 @@ class TensorCoreDecoupled(
|
||||
stateA.index := 0.U
|
||||
stateB.set := 0.U
|
||||
stateB.index := 0.U
|
||||
stateC.set := 0.U
|
||||
stateC.index := 0.U
|
||||
setAccessC := 0.U
|
||||
stepAccessC := 0.U
|
||||
substepAccessC := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB && doneReqC
|
||||
@@ -432,7 +440,7 @@ class TensorCoreDecoupled(
|
||||
|
||||
val fullCBuf = Module(new Queue(
|
||||
new Bundle {
|
||||
val tag = new TensorMemTag
|
||||
val tag = chiselTypeOf(fullCTag.io.deq.bits)
|
||||
val data = chiselTypeOf(fullC.io.deq.bits)
|
||||
}, entries = 1, pipe = true
|
||||
))
|
||||
@@ -496,9 +504,9 @@ class TensorCoreDecoupled(
|
||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
|
||||
// C deq should be synced with B deq
|
||||
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
// C should be dequeued everytime a 4x4 tile is done computing
|
||||
val shouldDequeueC = (substepCompute === 1.U)
|
||||
fullCBuf.io.deq.ready := dpuFire && shouldDequeueC
|
||||
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
@@ -639,6 +647,7 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
}
|
||||
sequenceSetStep(setAccessC, stepAccessC, nextStepAccessC)
|
||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||
}
|
||||
|
||||
@@ -739,7 +748,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
tensor.io.respC := 42.U // FIXME bogus
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
|
||||
tensor.io.initiate.bits.wid := 3.U // bogus, static value
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
|
||||
Reference in New Issue
Block a user