tensor: Keep set/step in the tag writeback queue

This commit is contained in:
Hansung Kim
2024-10-17 15:43:44 -07:00
parent 7de8e86d4f
commit 2741af0b2b

View File

@@ -267,6 +267,7 @@ class TensorCoreDecoupled(
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid
val operandA = fullAQueue.io.deq.bits.data val operandA = fullAQueue.io.deq.bits.data
val operandATag = fullAQueue.io.deq.bits.tag
val operandB = respQueueB.bits.data val operandB = respQueueB.bits.data
val dpuReady = Wire(Bool()) val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady val dpuFire = operandsValid && dpuReady
@@ -314,8 +315,6 @@ class TensorCoreDecoupled(
val operandADimensional = val operandADimensional =
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq .grouped(4).toSeq
println(s"operandA: ${fullAQueue.io.deq.bits.data.widthOption.get} bits")
println(s"A: ${operandADimensional.length}, ${operandADimensional(0).length}")
assert(operandADimensional.length == tilingParams.mc && assert(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc, operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
@@ -323,7 +322,6 @@ class TensorCoreDecoupled(
val operandBDimensional = val operandBDimensional =
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq .grouped(4).toSeq
println(s"B: ${operandBDimensional.length}, ${operandBDimensional(0).length}")
val ncSubstep = tilingParams.nc / 2 val ncSubstep = tilingParams.nc / 2
assert(tilingParams.mc * ncSubstep == numLanes, assert(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput") "substep tile size doesn't match writeback throughput")
@@ -369,18 +367,20 @@ class TensorCoreDecoupled(
// These queues hold metadata needed for writeback in sync with the DPU. // These queues hold metadata needed for writeback in sync with the DPU.
val queueDepth = 4 // needs to be at least the DPU latency val queueDepth = 4 // needs to be at least the DPU latency
val rdQueue = Module(new Queue( val tagQueue = Module(new Queue(
chiselTypeOf(io.writeback.bits.rd), queueDepth chiselTypeOf(operandATag), queueDepth
)) ))
rdQueue.io.enq.valid := dpuFire tagQueue.io.enq.valid := dpuFire
rdQueue.io.enq.bits := rdGen(stepCompute, substepCompute) // A and B should have the same tags
rdQueue.io.deq.ready := io.writeback.fire tagQueue.io.enq.bits := operandATag
assert(rdQueue.io.enq.ready === true.B, // @cleanup: awkward
"rd queue full, throttling DPU operation") tagQueue.io.enq.bits.substep := substepCompute
assert(!dpuValid || rdQueue.io.deq.valid, tagQueue.io.deq.ready := io.writeback.fire
"rd queue and DPU went out of sync") assert(tagQueue.io.enq.ready === true.B,
"tag queue full, DPU operation might be throttled")
assert(!dpuValid || tagQueue.io.deq.valid,
"tag queue and DPU went out of sync")
// TODO: decouple wid from frontend
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) // val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
// note rd is independent to sets // note rd is independent to sets
@@ -390,11 +390,14 @@ class TensorCoreDecoupled(
(step << 1/*2 substeps*/) + substep (step << 1/*2 substeps*/) + substep
} }
val setWriteback = tagQueue.io.deq.bits.set
val stepWriteback = tagQueue.io.deq.bits.step
val substepWriteback = tagQueue.io.deq.bits.substep
io.writeback.valid := dpuValid io.writeback.valid := dpuValid
// TODO: decouple wid from frontend
io.writeback.bits.wid := warpReg io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdQueue.io.deq.bits io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
// FIXME: look at set/step of dpu output not setExecute io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback)
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
// State transition // State transition
// ---------------- // ----------------
@@ -500,6 +503,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.writeback.ready := true.B tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
when (io.finished) {
// might be too strong
assert(tensor.io.writeback.bits.rd === 31.U)
}
} }
// a minimal Diplomacy graph with a tensor core and a TLRAM // a minimal Diplomacy graph with a tensor core and a TLRAM