tensor: Sequence set/steps in the execute-side

This commit is contained in:
Hansung Kim
2024-10-15 19:12:15 -07:00
parent efaf599fbe
commit e2abe1cffd

View File

@@ -97,9 +97,16 @@ class TensorCoreDecoupled(
// steps: i-j iteration
// numSteps is (m * n) / (mc * nc) from the tiling parameters, using integer
// division; stepBits is the counter width needed to index that many steps.
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
val stepBits = log2Ceil(numSteps)
// All-ones values of the setBits-wide / stepBits-wide counters. They serve both
// as the "final value" compared against in setDone/stepDone and as the
// wrap-around mask when the counters are incremented.
// NOTE(review): lastStep equals numSteps - 1 only when numSteps is a power of
// two -- confirm that the tiling parameters guarantee this.
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
// Predicates that are true while the given counter holds its final value.
// They take the counter as an argument so that more than one counter pair can
// be checked with the same logic.
def setDone(set: UInt) = (set === lastSet.U)
def stepDone(step: UInt) = (step === lastStep.U)
// set and step being currently accessed in the acc/ex frontend
val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W))
// dontTouch keeps these counters from being optimized away during elaboration
dontTouch(setAccess)
dontTouch(stepAccess)
when(io.initiate.fire) {
val wid = io.initiate.bits.wid
@@ -118,6 +125,9 @@ class TensorCoreDecoupled(
busy := false.B
}
// serialize every HGMMA request
io.initiate.ready := !busy
// Memory traffic generation
// -------------------------
//
@@ -166,10 +176,10 @@ class TensorCoreDecoupled(
req.fire
})
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
val nextStep = firedAB.andR
val nextStepAccess = firedAB.andR
// Clear firedABReg at every step. On the last fire out of A and B, this
// assignment overrides the earlier write that records the fire.
when (nextStep) {
when (nextStepAccess) {
firedABReg := Seq(false.B, false.B)
}
@@ -180,6 +190,8 @@ class TensorCoreDecoupled(
// set and step being currently executed in the acc/ex backend
// Both counters reset to zero.
val setExecute = RegInit(0.U(setBits.W))
val stepExecute = RegInit(0.U(stepBits.W))
// dontTouch keeps these counters from being optimized away during elaboration
dontTouch(setExecute)
dontTouch(stepExecute)
// Number of entries in the response queue created below.
val respQueueDepth = 4 // FIXME: parameterize
// Buffers the tagged A responses; respQueueA is the dequeue side of the queue.
val respQueueA = Queue(respATagged, respQueueDepth)
@@ -198,13 +210,19 @@ class TensorCoreDecoupled(
"A and B response queue pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
// synchronized dequeue
// dequeue is synchronized between A and B
// FIXME: this needs to change to dpu_ready
// Both queues are driven with the same ready signal, so they dequeue in the
// same cycle or not at all.
val deqResp = bothQueueValid && io.writeback.ready
respQueueA.ready := deqResp
respQueueB.ready := deqResp
// FIXME: this needs to change to dpu_fire
// Pulse that advances the execute-side set/step counters; for now it is the
// writeback handshake itself.
val nextStepExecute = io.writeback.fire
io.writeback.valid := bothQueueValid
io.writeback.bits.wid := warpReg
// `last` is asserted while both execute-side counters hold their final values.
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
// FIXME: debug dummy: pipe A directly to writeback
// NOTE(review): this later connection takes precedence over the
// `io.writeback.valid := bothQueueValid` connection above (last-connect
// semantics), so writeback.valid -- and therefore nextStepExecute -- no longer
// follows the bothQueueValid condition that gates deqResp. Confirm this is
// intended for debugging only.
io.writeback.valid := respQueueA.valid
val groupedRespA = respQueueA.bits.data
.asBools.grouped(wordSize * 8/*bits*/)
.map(VecInit(_).asUInt)
@@ -216,16 +234,17 @@ class TensorCoreDecoupled(
// ----------------
//
// set/step sequencing logic
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
val setDone = (setAccess === lastSet.U)
val stepDone = (stepAccess === lastStep.U)
when (nextStep) {
stepAccess := (stepAccess + 1.U) & lastStep.U
when (stepDone) {
setAccess := (setAccess + 1.U) & lastSet.U
// Advances a (set, step) counter pair by one position whenever `nextStep` is
// asserted. `step` counts up on every advance and wraps via the lastStep mask;
// `set` counts up (wrapping via the lastSet mask) only on the advance where
// `step` is at its final value. Both `set` and `step` are assigned to, so the
// caller must pass assignable hardware (e.g. registers).
def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = {
  when (nextStep) {
    // step counter: increment, masked to its valid range
    val stepPlusOne = step + 1.U
    step := stepPlusOne & lastStep.U
    // set counter: only moves on when the current step is the last one
    val stepAtEnd = stepDone(step)
    when (stepAtEnd) {
      val setPlusOne = set + 1.U
      set := setPlusOne & lastSet.U
    }
  }
}
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
switch(state) {
is(TensorState.idle) {
@@ -234,7 +253,7 @@ class TensorCoreDecoupled(
}
}
is(TensorState.run) {
when (setDone && stepDone && nextStep) {
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
when (state === TensorState.run) {
state := TensorState.finish
}
@@ -247,11 +266,6 @@ class TensorCoreDecoupled(
}
}
io.initiate.ready := !busy
io.writeback.valid := (state === TensorState.finish)
io.writeback.bits.wid := warpReg
io.writeback.bits.last := false.B // TODO
// Writeback queues
// ----------------
// These queues hold the metadata necessary for register
@@ -328,7 +342,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.initiate.bits.wid := 0.U // FIXME
tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
}
// a minimal Diplomacy graph with a tensor core and a TLRAM