tensor: Sync rd with DPU using a queue
This commit is contained in:
@@ -270,6 +270,8 @@ class TensorCoreDecoupled(
|
|||||||
val operandB = respQueueB.bits.data
|
val operandB = respQueueB.bits.data
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
|
val setCompute = fullAQueue.io.deq.bits.tag.set
|
||||||
|
val stepCompute = fullAQueue.io.deq.bits.tag.step
|
||||||
val substepCompute = RegInit(0.U(1.W))
|
val substepCompute = RegInit(0.U(1.W))
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
substepCompute := substepCompute + 1.U
|
substepCompute := substepCompute + 1.U
|
||||||
@@ -348,9 +350,9 @@ class TensorCoreDecoupled(
|
|||||||
def assertDPU = {
|
def assertDPU = {
|
||||||
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
||||||
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
||||||
"stall signals of DPUs went unaligned")
|
"stall signals of DPUs went out of sync")
|
||||||
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
||||||
"valid signals of DPUs went unaligned")
|
"valid signals of DPUs went out of sync")
|
||||||
}
|
}
|
||||||
assertDPU
|
assertDPU
|
||||||
|
|
||||||
@@ -362,17 +364,36 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
io.writeback.bits.data := flattenedDPUOut
|
io.writeback.bits.data := flattenedDPUOut
|
||||||
|
|
||||||
def rdGen(set: UInt, step: UInt): UInt = {
|
// Writeback queues
|
||||||
|
// ----------------
|
||||||
|
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||||
|
|
||||||
|
val queueDepth = 4 // needs to be at least the DPU latency
|
||||||
|
val rdQueue = Module(new Queue(
|
||||||
|
chiselTypeOf(io.writeback.bits.rd), queueDepth
|
||||||
|
))
|
||||||
|
rdQueue.io.enq.valid := dpuFire
|
||||||
|
rdQueue.io.enq.bits := rdGen(stepCompute, substepCompute)
|
||||||
|
rdQueue.io.deq.ready := io.writeback.fire
|
||||||
|
assert(rdQueue.io.enq.ready === true.B,
|
||||||
|
"rd queue full, throttling DPU operation")
|
||||||
|
assert(!dpuValid || rdQueue.io.deq.valid,
|
||||||
|
"rd queue and DPU went out of sync")
|
||||||
|
|
||||||
|
// TODO: decouple wid from frontend
|
||||||
|
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
|
|
||||||
|
// note rd is independent to sets
|
||||||
|
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
// thread
|
// thread
|
||||||
require(numLanes == 8, "currently assumes 8-wide warps")
|
(step << 1/*2 substeps*/) + substep
|
||||||
(Cat(set, step) >> 1/*2 regs/thread*/)
|
|
||||||
// FIXME: add substep here
|
|
||||||
}
|
}
|
||||||
|
|
||||||
io.writeback.valid := dpuValid
|
io.writeback.valid := dpuValid
|
||||||
io.writeback.bits.wid := warpReg
|
io.writeback.bits.wid := warpReg
|
||||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
io.writeback.bits.rd := rdQueue.io.deq.bits
|
||||||
|
// FIXME: look at set/step of dpu output not setExecute
|
||||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||||
|
|
||||||
// State transition
|
// State transition
|
||||||
@@ -410,15 +431,6 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Writeback queues
|
|
||||||
// ----------------
|
|
||||||
// These queues hold the metadata necessary for register
|
|
||||||
// writeback.
|
|
||||||
|
|
||||||
// val queueDepth = 2
|
|
||||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
|
||||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// synthesizable unit tests
|
// synthesizable unit tests
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
|||||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copied from chisel3.util.Pipe.
|
// An implementation of chisel3.util.Pipe that supports stalls.
|
||||||
class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module {
|
class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module {
|
||||||
/** A non-ambiguous name of this `StallingPipe` for use in generated Verilog
|
/** A non-ambiguous name of this `StallingPipe` for use in generated Verilog
|
||||||
* names. Includes the latency cycle count in the name as well as the
|
* names. Includes the latency cycle count in the name as well as the
|
||||||
|
|||||||
Reference in New Issue
Block a user