tensor: Sync rd with DPU using a queue
This commit is contained in:
@@ -270,6 +270,8 @@ class TensorCoreDecoupled(
|
||||
val operandB = respQueueB.bits.data
|
||||
val dpuReady = Wire(Bool())
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
val setCompute = fullAQueue.io.deq.bits.tag.set
|
||||
val stepCompute = fullAQueue.io.deq.bits.tag.step
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
when (dpuFire) {
|
||||
substepCompute := substepCompute + 1.U
|
||||
@@ -348,9 +350,9 @@ class TensorCoreDecoupled(
|
||||
def assertDPU = {
|
||||
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
||||
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
||||
"stall signals of DPUs went unaligned")
|
||||
"stall signals of DPUs went out of sync")
|
||||
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
||||
"valid signals of DPUs went unaligned")
|
||||
"valid signals of DPUs went out of sync")
|
||||
}
|
||||
assertDPU
|
||||
|
||||
@@ -362,17 +364,36 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
io.writeback.bits.data := flattenedDPUOut
|
||||
|
||||
def rdGen(set: UInt, step: UInt): UInt = {
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
val queueDepth = 4 // needs to be at least the DPU latency
|
||||
val rdQueue = Module(new Queue(
|
||||
chiselTypeOf(io.writeback.bits.rd), queueDepth
|
||||
))
|
||||
rdQueue.io.enq.valid := dpuFire
|
||||
rdQueue.io.enq.bits := rdGen(stepCompute, substepCompute)
|
||||
rdQueue.io.deq.ready := io.writeback.fire
|
||||
assert(rdQueue.io.enq.ready === true.B,
|
||||
"rd queue full, throttling DPU operation")
|
||||
assert(!dpuValid || rdQueue.io.deq.valid,
|
||||
"rd queue and DPU went out of sync")
|
||||
|
||||
// TODO: decouple wid from frontend
|
||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
|
||||
// note rd is independent to sets
|
||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
require(numLanes == 8, "currently assumes 8-wide warps")
|
||||
(Cat(set, step) >> 1/*2 regs/thread*/)
|
||||
// FIXME: add substep here
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
|
||||
io.writeback.valid := dpuValid
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||
io.writeback.bits.rd := rdQueue.io.deq.bits
|
||||
// FIXME: look at set/step of dpu output not setExecute
|
||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||
|
||||
// State transition
|
||||
@@ -410,15 +431,6 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
// These queues hold the metadata necessary for register
|
||||
// writeback.
|
||||
|
||||
// val queueDepth = 2
|
||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
}
|
||||
|
||||
// synthesizable unit tests
|
||||
|
||||
@@ -53,7 +53,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
// Copied from chisel3.util.Pipe.
|
||||
// An implementation of chisel3.util.Pipe that supports stalls.
|
||||
class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module {
|
||||
/** A non-ambiguous name of this `StallingPipe` for use in generated Verilog
|
||||
* names. Includes the latency cycle count in the name as well as the
|
||||
|
||||
Reference in New Issue
Block a user