tensor: Fix metadata of C req; fix dequeue / req gen timing
This commit is contained in:
@@ -252,15 +252,15 @@ class TensorCoreDecoupled(
|
|||||||
// C access from regfile
|
// C access from regfile
|
||||||
//
|
//
|
||||||
|
|
||||||
// since regfile is fixed-latency, respC valid should be determined at the
|
// regfile is fast; don't need a deep response queue
|
||||||
// request sending side.
|
|
||||||
val respCValid = RegInit(false.B)
|
|
||||||
|
|
||||||
// regfile latency is 1 cycle; don't need a deep response queue
|
|
||||||
val respQueueCDepth = 2
|
val respQueueCDepth = 2
|
||||||
val respQueueC = Module(new Queue(
|
val respQueueC = Module(new Queue(
|
||||||
new Bundle {
|
new Bundle {
|
||||||
val tag = new TensorMemTag
|
val tag = new Bundle {
|
||||||
|
val warp = UInt(numWarpBits.W)
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val step = UInt(stepBits.W)
|
||||||
|
}
|
||||||
val data = UInt(io.respC.widthOption.get.W)
|
val data = UInt(io.respC.widthOption.get.W)
|
||||||
},
|
},
|
||||||
respQueueCDepth
|
respQueueCDepth
|
||||||
@@ -276,8 +276,25 @@ class TensorCoreDecoupled(
|
|||||||
(respQueueC.entries - 1).U,
|
(respQueueC.entries - 1).U,
|
||||||
(respQueueC.entries - 2).U)
|
(respQueueC.entries - 2).U)
|
||||||
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
||||||
// 1-cycle delay
|
|
||||||
respCValid := genReqC
|
// set/step state of the C accumulator value that will be latched ath the
|
||||||
|
// next cycle.
|
||||||
|
val setAccessC = RegInit(0.U(setBits.W))
|
||||||
|
val stepAccessC = RegInit(0.U(stepBits.W))
|
||||||
|
val substepAccessC = RegInit(0.U(1.W))
|
||||||
|
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
||||||
|
when (genReqC) {
|
||||||
|
substepAccessC := substepAccessC + 1.U
|
||||||
|
}
|
||||||
|
dontTouch(stepAccessC)
|
||||||
|
dontTouch(substepAccessC)
|
||||||
|
dontTouch(nextStepAccessC)
|
||||||
|
|
||||||
|
// give 1-cycle delay to sync valid/metadata with C regfile response
|
||||||
|
val respCValid = RegNext(genReqC)
|
||||||
|
val warpAccessCDelayed = RegNext(warpAccess)
|
||||||
|
val setAccessCDelayed = RegNext(setAccessC)
|
||||||
|
val stepAccessCDelayed = RegNext(stepAccessC)
|
||||||
|
|
||||||
// note rd is independent to sets
|
// note rd is independent to sets
|
||||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||||
@@ -286,25 +303,14 @@ class TensorCoreDecoupled(
|
|||||||
(step << 1/*2 substeps*/) + substep
|
(step << 1/*2 substeps*/) + substep
|
||||||
}
|
}
|
||||||
io.reqC.valid := genReqC
|
io.reqC.valid := genReqC
|
||||||
io.reqC.bits := 5.U // FIXME
|
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
||||||
|
|
||||||
// set/index state of the C accumulator value that will be latched ath the
|
|
||||||
// next cycle.
|
|
||||||
val stateC = RegInit(stateInit)
|
|
||||||
when (genReqC) {
|
|
||||||
when (stateC.index === lastIndex.U) {
|
|
||||||
stateC.set := stateC.set + 1.U
|
|
||||||
}
|
|
||||||
stateC.index := stateC.index + 1.U
|
|
||||||
}
|
|
||||||
dontTouch(stateC)
|
|
||||||
|
|
||||||
// queue the regfile response to buffers
|
// queue the regfile response to buffers
|
||||||
// these strictly belong to the execute stage
|
// these strictly belong to the execute stage
|
||||||
respQueueC.io.enq.valid := respCValid
|
respQueueC.io.enq.valid := respCValid
|
||||||
respQueueC.io.enq.bits.tag.warp := warpAccess
|
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
||||||
respQueueC.io.enq.bits.tag.set := stateC.set
|
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
||||||
respQueueC.io.enq.bits.tag.index := stateC.index
|
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
||||||
respQueueC.io.enq.bits.data := io.respC
|
respQueueC.io.enq.bits.data := io.respC
|
||||||
|
|
||||||
// serialize every two C responses into one full 4x4 C tile
|
// serialize every two C responses into one full 4x4 C tile
|
||||||
@@ -315,7 +321,8 @@ class TensorCoreDecoupled(
|
|||||||
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||||
val fullCTag = Module(new Queue(
|
val fullCTag = Module(new Queue(
|
||||||
new TensorMemTag, entries = 1, pipe = true
|
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
||||||
|
entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||||
@@ -323,10 +330,10 @@ class TensorCoreDecoupled(
|
|||||||
// finalize state when everything has been accessed
|
// finalize state when everything has been accessed
|
||||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||||
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
|
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
||||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||||
when (lastReqC && genReqC) { doneReqC := true.B }
|
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
||||||
when (state === AccessorState.finish) {
|
when (state === AccessorState.finish) {
|
||||||
doneReqA := false.B
|
doneReqA := false.B
|
||||||
doneReqB := false.B
|
doneReqB := false.B
|
||||||
@@ -335,8 +342,9 @@ class TensorCoreDecoupled(
|
|||||||
stateA.index := 0.U
|
stateA.index := 0.U
|
||||||
stateB.set := 0.U
|
stateB.set := 0.U
|
||||||
stateB.index := 0.U
|
stateB.index := 0.U
|
||||||
stateC.set := 0.U
|
setAccessC := 0.U
|
||||||
stateC.index := 0.U
|
stepAccessC := 0.U
|
||||||
|
substepAccessC := 0.U
|
||||||
}
|
}
|
||||||
|
|
||||||
allReqsDone := doneReqA && doneReqB && doneReqC
|
allReqsDone := doneReqA && doneReqB && doneReqC
|
||||||
@@ -432,7 +440,7 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
val fullCBuf = Module(new Queue(
|
val fullCBuf = Module(new Queue(
|
||||||
new Bundle {
|
new Bundle {
|
||||||
val tag = new TensorMemTag
|
val tag = chiselTypeOf(fullCTag.io.deq.bits)
|
||||||
val data = chiselTypeOf(fullC.io.deq.bits)
|
val data = chiselTypeOf(fullC.io.deq.bits)
|
||||||
}, entries = 1, pipe = true
|
}, entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
@@ -496,9 +504,9 @@ class TensorCoreDecoupled(
|
|||||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||||
(substepCompute === 1.U)
|
(substepCompute === 1.U)
|
||||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
|
// C should be dequeued everytime a 4x4 tile is done computing
|
||||||
// C deq should be synced with B deq
|
val shouldDequeueC = (substepCompute === 1.U)
|
||||||
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullCBuf.io.deq.ready := dpuFire && shouldDequeueC
|
||||||
|
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
@@ -639,6 +647,7 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sequenceSetStep(setAccessC, stepAccessC, nextStepAccessC)
|
||||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -739,7 +748,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
tensor.io.respC := 42.U // FIXME bogus
|
tensor.io.respC := 42.U // FIXME bogus
|
||||||
|
|
||||||
tensor.io.initiate.valid := io.start
|
tensor.io.initiate.valid := io.start
|
||||||
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
|
tensor.io.initiate.bits.wid := 3.U // bogus, static value
|
||||||
tensor.io.writeback.ready := true.B
|
tensor.io.writeback.ready := true.B
|
||||||
|
|
||||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||||
|
|||||||
Reference in New Issue
Block a user