tensor: Sequence set/steps in the execute-side
This commit is contained in:
@@ -97,9 +97,16 @@ class TensorCoreDecoupled(
|
|||||||
// steps: i-j iteration
|
// steps: i-j iteration
|
||||||
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||||
val stepBits = log2Ceil(numSteps)
|
val stepBits = log2Ceil(numSteps)
|
||||||
|
val lastSet = ((1 << setBits) - 1)
|
||||||
|
val lastStep = ((1 << stepBits) - 1)
|
||||||
|
def setDone(set: UInt) = (set === lastSet.U)
|
||||||
|
def stepDone(step: UInt) = (step === lastStep.U)
|
||||||
|
|
||||||
// set and step being currently accessed in the acc/ex frontend
|
// set and step being currently accessed in the acc/ex frontend
|
||||||
val setAccess = RegInit(0.U(setBits.W))
|
val setAccess = RegInit(0.U(setBits.W))
|
||||||
val stepAccess = RegInit(0.U(stepBits.W))
|
val stepAccess = RegInit(0.U(stepBits.W))
|
||||||
|
dontTouch(setAccess)
|
||||||
|
dontTouch(stepAccess)
|
||||||
|
|
||||||
when(io.initiate.fire) {
|
when(io.initiate.fire) {
|
||||||
val wid = io.initiate.bits.wid
|
val wid = io.initiate.bits.wid
|
||||||
@@ -118,6 +125,9 @@ class TensorCoreDecoupled(
|
|||||||
busy := false.B
|
busy := false.B
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serialize every HGMMA request
|
||||||
|
io.initiate.ready := !busy
|
||||||
|
|
||||||
// Memory traffic generation
|
// Memory traffic generation
|
||||||
// -------------------------
|
// -------------------------
|
||||||
//
|
//
|
||||||
@@ -166,10 +176,10 @@ class TensorCoreDecoupled(
|
|||||||
req.fire
|
req.fire
|
||||||
})
|
})
|
||||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||||
val nextStep = firedAB.andR
|
val nextStepAccess = firedAB.andR
|
||||||
// clear out firedABReg every step. this will overwrite the previous fired
|
// clear out firedABReg every step. this will overwrite the previous fired
|
||||||
// write upon the last fire out of A and B
|
// write upon the last fire out of A and B
|
||||||
when (nextStep) {
|
when (nextStepAccess) {
|
||||||
firedABReg := Seq(false.B, false.B)
|
firedABReg := Seq(false.B, false.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +190,8 @@ class TensorCoreDecoupled(
|
|||||||
// set and step being currently executed in the acc/ex backend
|
// set and step being currently executed in the acc/ex backend
|
||||||
val setExecute = RegInit(0.U(setBits.W))
|
val setExecute = RegInit(0.U(setBits.W))
|
||||||
val stepExecute = RegInit(0.U(stepBits.W))
|
val stepExecute = RegInit(0.U(stepBits.W))
|
||||||
|
dontTouch(setExecute)
|
||||||
|
dontTouch(stepExecute)
|
||||||
|
|
||||||
val respQueueDepth = 4 // FIXME: parameterize
|
val respQueueDepth = 4 // FIXME: parameterize
|
||||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||||
@@ -198,13 +210,19 @@ class TensorCoreDecoupled(
|
|||||||
"A and B response queue pointing to different set/steps. " ++
|
"A and B response queue pointing to different set/steps. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
}
|
}
|
||||||
// synchronized dequeue
|
// dequeue is synchronized between A and B
|
||||||
|
// FIXME: this needs to change to dpu_ready
|
||||||
val deqResp = bothQueueValid && io.writeback.ready
|
val deqResp = bothQueueValid && io.writeback.ready
|
||||||
respQueueA.ready := deqResp
|
respQueueA.ready := deqResp
|
||||||
respQueueB.ready := deqResp
|
respQueueB.ready := deqResp
|
||||||
|
// FIXME: this needs to change to dpu_fire
|
||||||
|
val nextStepExecute = io.writeback.fire
|
||||||
|
|
||||||
|
io.writeback.valid := bothQueueValid
|
||||||
|
io.writeback.bits.wid := warpReg
|
||||||
|
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||||
|
|
||||||
// FIXME: debug dummy: pipe A directly to writeback
|
// FIXME: debug dummy: pipe A directly to writeback
|
||||||
io.writeback.valid := respQueueA.valid
|
|
||||||
val groupedRespA = respQueueA.bits.data
|
val groupedRespA = respQueueA.bits.data
|
||||||
.asBools.grouped(wordSize * 8/*bits*/)
|
.asBools.grouped(wordSize * 8/*bits*/)
|
||||||
.map(VecInit(_).asUInt)
|
.map(VecInit(_).asUInt)
|
||||||
@@ -216,16 +234,17 @@ class TensorCoreDecoupled(
|
|||||||
// ----------------
|
// ----------------
|
||||||
//
|
//
|
||||||
// set/step sequencing logic
|
// set/step sequencing logic
|
||||||
val lastSet = ((1 << setBits) - 1)
|
|
||||||
val lastStep = ((1 << stepBits) - 1)
|
def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = {
|
||||||
val setDone = (setAccess === lastSet.U)
|
when (nextStep) {
|
||||||
val stepDone = (stepAccess === lastStep.U)
|
step := (step + 1.U) & lastStep.U
|
||||||
when (nextStep) {
|
when (stepDone(step)) {
|
||||||
stepAccess := (stepAccess + 1.U) & lastStep.U
|
set := (set + 1.U) & lastSet.U
|
||||||
when (stepDone) {
|
}
|
||||||
setAccess := (setAccess + 1.U) & lastSet.U
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
||||||
|
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
|
||||||
|
|
||||||
switch(state) {
|
switch(state) {
|
||||||
is(TensorState.idle) {
|
is(TensorState.idle) {
|
||||||
@@ -234,7 +253,7 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
is(TensorState.run) {
|
is(TensorState.run) {
|
||||||
when (setDone && stepDone && nextStep) {
|
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
|
||||||
when (state === TensorState.run) {
|
when (state === TensorState.run) {
|
||||||
state := TensorState.finish
|
state := TensorState.finish
|
||||||
}
|
}
|
||||||
@@ -247,11 +266,6 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
io.initiate.ready := !busy
|
|
||||||
io.writeback.valid := (state === TensorState.finish)
|
|
||||||
io.writeback.bits.wid := warpReg
|
|
||||||
io.writeback.bits.last := false.B // TODO
|
|
||||||
|
|
||||||
// Writeback queues
|
// Writeback queues
|
||||||
// ----------------
|
// ----------------
|
||||||
// These queues hold the metadata necessary for register
|
// These queues hold the metadata necessary for register
|
||||||
@@ -328,7 +342,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
tensor.io.initiate.bits.wid := 0.U // FIXME
|
||||||
tensor.io.writeback.ready := true.B
|
tensor.io.writeback.ready := true.B
|
||||||
|
|
||||||
io.finished := tensor.io.writeback.valid
|
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||||
}
|
}
|
||||||
|
|
||||||
// a minimal Diplomacy graph with a tensor core and a TLRAM
|
// a minimal Diplomacy graph with a tensor core and a TLRAM
|
||||||
|
|||||||
Reference in New Issue
Block a user