tensor: Sequence set/steps in the execute-side

This commit is contained in:
Hansung Kim
2024-10-15 19:12:15 -07:00
parent efaf599fbe
commit e2abe1cffd

View File

@@ -97,9 +97,16 @@ class TensorCoreDecoupled(
// steps: i-j iteration // steps: i-j iteration
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc) val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
val stepBits = log2Ceil(numSteps) val stepBits = log2Ceil(numSteps)
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
def setDone(set: UInt) = (set === lastSet.U)
def stepDone(step: UInt) = (step === lastStep.U)
// set and step being currently accessed in the acc/ex frontend // set and step being currently accessed in the acc/ex frontend
val setAccess = RegInit(0.U(setBits.W)) val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W)) val stepAccess = RegInit(0.U(stepBits.W))
dontTouch(setAccess)
dontTouch(stepAccess)
when(io.initiate.fire) { when(io.initiate.fire) {
val wid = io.initiate.bits.wid val wid = io.initiate.bits.wid
@@ -118,6 +125,9 @@ class TensorCoreDecoupled(
busy := false.B busy := false.B
} }
// serialize every HGMMA request
io.initiate.ready := !busy
// Memory traffic generation // Memory traffic generation
// ------------------------- // -------------------------
// //
@@ -166,10 +176,10 @@ class TensorCoreDecoupled(
req.fire req.fire
}) })
val firedAB = (firedABNow.asUInt | firedABReg.asUInt) val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
val nextStep = firedAB.andR val nextStepAccess = firedAB.andR
// clear out firedABReg every step. this will overwrite the previous fired // clear out firedABReg every step. this will overwrite the previous fired
// write upon the last fire out of A and B // write upon the last fire out of A and B
when (nextStep) { when (nextStepAccess) {
firedABReg := Seq(false.B, false.B) firedABReg := Seq(false.B, false.B)
} }
@@ -180,6 +190,8 @@ class TensorCoreDecoupled(
// set and step being currently executed in the acc/ex backend // set and step being currently executed in the acc/ex backend
val setExecute = RegInit(0.U(setBits.W)) val setExecute = RegInit(0.U(setBits.W))
val stepExecute = RegInit(0.U(stepBits.W)) val stepExecute = RegInit(0.U(stepBits.W))
dontTouch(setExecute)
dontTouch(stepExecute)
val respQueueDepth = 4 // FIXME: parameterize val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth) val respQueueA = Queue(respATagged, respQueueDepth)
@@ -198,13 +210,19 @@ class TensorCoreDecoupled(
"A and B response queue pointing to different set/steps. " ++ "A and B response queue pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")
} }
// synchronized dequeue // dequeue is synchronized between A and B
// FIXME: this need to change to dpu_ready
val deqResp = bothQueueValid && io.writeback.ready val deqResp = bothQueueValid && io.writeback.ready
respQueueA.ready := deqResp respQueueA.ready := deqResp
respQueueB.ready := deqResp respQueueB.ready := deqResp
// FIXME: this need to change to dpu_fire
val nextStepExecute = io.writeback.fire
io.writeback.valid := bothQueueValid
io.writeback.bits.wid := warpReg
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
// FIXME: debug dummy: pipe A directly to writeback // FIXME: debug dummy: pipe A directly to writeback
io.writeback.valid := respQueueA.valid
val groupedRespA = respQueueA.bits.data val groupedRespA = respQueueA.bits.data
.asBools.grouped(wordSize * 8/*bits*/) .asBools.grouped(wordSize * 8/*bits*/)
.map(VecInit(_).asUInt) .map(VecInit(_).asUInt)
@@ -216,16 +234,17 @@ class TensorCoreDecoupled(
// ---------------- // ----------------
// //
// set/step sequencing logic // set/step sequencing logic
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1) def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = {
val setDone = (setAccess === lastSet.U) when (nextStep) {
val stepDone = (stepAccess === lastStep.U) step := (step + 1.U) & lastStep.U
when (nextStep) { when (stepDone(step)) {
stepAccess := (stepAccess + 1.U) & lastStep.U set := (set + 1.U) & lastSet.U
when (stepDone) { }
setAccess := (setAccess + 1.U) & lastSet.U
} }
} }
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
switch(state) { switch(state) {
is(TensorState.idle) { is(TensorState.idle) {
@@ -234,7 +253,7 @@ class TensorCoreDecoupled(
} }
} }
is(TensorState.run) { is(TensorState.run) {
when (setDone && stepDone && nextStep) { when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
when (state === TensorState.run) { when (state === TensorState.run) {
state := TensorState.finish state := TensorState.finish
} }
@@ -247,11 +266,6 @@ class TensorCoreDecoupled(
} }
} }
io.initiate.ready := !busy
io.writeback.valid := (state === TensorState.finish)
io.writeback.bits.wid := warpReg
io.writeback.bits.last := false.B // TODO
// Writeback queues // Writeback queues
// ---------------- // ----------------
// These queues hold the metadata necessary for register // These queues hold the metadata necessary for register
@@ -328,7 +342,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.initiate.bits.wid := 0.U // FIXME tensor.io.initiate.bits.wid := 0.U // FIXME
tensor.io.writeback.ready := true.B tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
} }
// a minimal Diplomacy graph with a tensor core and a TLRAM // a minimal Diplomacy graph with a tensor core and a TLRAM