tensor: Sequence set/steps in the execute-side
This commit is contained in:
@@ -97,9 +97,16 @@ class TensorCoreDecoupled(
|
||||
// steps: i-j iteration
|
||||
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||
val stepBits = log2Ceil(numSteps)
|
||||
val lastSet = ((1 << setBits) - 1)
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
def setDone(set: UInt) = (set === lastSet.U)
|
||||
def stepDone(step: UInt) = (step === lastStep.U)
|
||||
|
||||
// set and step being currently accessed in the acc/ex frontend
|
||||
val setAccess = RegInit(0.U(setBits.W))
|
||||
val stepAccess = RegInit(0.U(stepBits.W))
|
||||
dontTouch(setAccess)
|
||||
dontTouch(stepAccess)
|
||||
|
||||
when(io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
@@ -118,6 +125,9 @@ class TensorCoreDecoupled(
|
||||
busy := false.B
|
||||
}
|
||||
|
||||
// serialize every HGMMA request
|
||||
io.initiate.ready := !busy
|
||||
|
||||
// Memory traffic generation
|
||||
// -------------------------
|
||||
//
|
||||
@@ -166,10 +176,10 @@ class TensorCoreDecoupled(
|
||||
req.fire
|
||||
})
|
||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||
val nextStep = firedAB.andR
|
||||
val nextStepAccess = firedAB.andR
|
||||
// clear out firedABReg every step. this will overwrite the previous fired
|
||||
// write upon the last fire out of A and B
|
||||
when (nextStep) {
|
||||
when (nextStepAccess) {
|
||||
firedABReg := Seq(false.B, false.B)
|
||||
}
|
||||
|
||||
@@ -180,6 +190,8 @@ class TensorCoreDecoupled(
|
||||
// set and step being currently executed in the acc/ex backend
|
||||
val setExecute = RegInit(0.U(setBits.W))
|
||||
val stepExecute = RegInit(0.U(stepBits.W))
|
||||
dontTouch(setExecute)
|
||||
dontTouch(stepExecute)
|
||||
|
||||
val respQueueDepth = 4 // FIXME: parameterize
|
||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||
@@ -198,13 +210,19 @@ class TensorCoreDecoupled(
|
||||
"A and B response queue pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
// synchronized dequeue
|
||||
// dequeue is synchronized between A and B
|
||||
// FIXME: this need to change to dpu_ready
|
||||
val deqResp = bothQueueValid && io.writeback.ready
|
||||
respQueueA.ready := deqResp
|
||||
respQueueB.ready := deqResp
|
||||
// FIXME: this need to change to dpu_fire
|
||||
val nextStepExecute = io.writeback.fire
|
||||
|
||||
io.writeback.valid := bothQueueValid
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||
|
||||
// FIXME: debug dummy: pipe A directly to writeback
|
||||
io.writeback.valid := respQueueA.valid
|
||||
val groupedRespA = respQueueA.bits.data
|
||||
.asBools.grouped(wordSize * 8/*bits*/)
|
||||
.map(VecInit(_).asUInt)
|
||||
@@ -216,16 +234,17 @@ class TensorCoreDecoupled(
|
||||
// ----------------
|
||||
//
|
||||
// set/step sequencing logic
|
||||
val lastSet = ((1 << setBits) - 1)
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
val setDone = (setAccess === lastSet.U)
|
||||
val stepDone = (stepAccess === lastStep.U)
|
||||
when (nextStep) {
|
||||
stepAccess := (stepAccess + 1.U) & lastStep.U
|
||||
when (stepDone) {
|
||||
setAccess := (setAccess + 1.U) & lastSet.U
|
||||
|
||||
def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = {
|
||||
when (nextStep) {
|
||||
step := (step + 1.U) & lastStep.U
|
||||
when (stepDone(step)) {
|
||||
set := (set + 1.U) & lastSet.U
|
||||
}
|
||||
}
|
||||
}
|
||||
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
||||
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
|
||||
|
||||
switch(state) {
|
||||
is(TensorState.idle) {
|
||||
@@ -234,7 +253,7 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
is(TensorState.run) {
|
||||
when (setDone && stepDone && nextStep) {
|
||||
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
|
||||
when (state === TensorState.run) {
|
||||
state := TensorState.finish
|
||||
}
|
||||
@@ -247,11 +266,6 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
|
||||
io.initiate.ready := !busy
|
||||
io.writeback.valid := (state === TensorState.finish)
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.last := false.B // TODO
|
||||
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
// These queues hold the metadata necessary for register
|
||||
@@ -328,7 +342,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
}
|
||||
|
||||
// a minimal Diplomacy graph with a tensor core and a TLRAM
|
||||
|
||||
Reference in New Issue
Block a user