tensor: Decouple A and B access states
Get rid of set/stepAccess states and let A and B access progress independently.
This commit is contained in:
@@ -82,15 +82,6 @@ class TensorCoreDecoupled(
|
|||||||
// This drives the overall pipeline of memory requests, dot-product unit
|
// This drives the overall pipeline of memory requests, dot-product unit
|
||||||
// operations and regfile writeback.
|
// operations and regfile writeback.
|
||||||
|
|
||||||
object TensorState extends ChiselEnum {
|
|
||||||
val idle = Value(0.U)
|
|
||||||
val run = Value(1.U)
|
|
||||||
// All set/step sequencing is complete and the tensor core is holding the
|
|
||||||
// result data until downstream writeback is ready.
|
|
||||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
|
||||||
val finish = Value(2.U)
|
|
||||||
}
|
|
||||||
val state = RegInit(TensorState.idle)
|
|
||||||
val busy = RegInit(false.B)
|
val busy = RegInit(false.B)
|
||||||
// Holds the warp id the core is currently working on. Note that we only
|
// Holds the warp id the core is currently working on. Note that we only
|
||||||
// support one outstanding warp request
|
// support one outstanding warp request
|
||||||
@@ -107,22 +98,10 @@ class TensorCoreDecoupled(
|
|||||||
def setDone(set: UInt) = (set === lastSet.U)
|
def setDone(set: UInt) = (set === lastSet.U)
|
||||||
def stepDone(step: UInt) = (step === lastStep.U)
|
def stepDone(step: UInt) = (step === lastStep.U)
|
||||||
|
|
||||||
// set and step being currently accessed in the acc/ex frontend
|
when (io.initiate.fire) {
|
||||||
val setAccess = RegInit(0.U(setBits.W))
|
|
||||||
val stepAccess = RegInit(0.U(stepBits.W))
|
|
||||||
// we need full 4x4 A tile to fire DPU, but since the memory width is 8
|
|
||||||
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
|
|
||||||
val substepAccess = RegInit(0.U(1.W))
|
|
||||||
dontTouch(setAccess)
|
|
||||||
dontTouch(stepAccess)
|
|
||||||
dontTouch(substepAccess)
|
|
||||||
|
|
||||||
when(io.initiate.fire) {
|
|
||||||
val wid = io.initiate.bits.wid
|
val wid = io.initiate.bits.wid
|
||||||
busy := true.B
|
busy := true.B
|
||||||
warpReg := wid
|
warpReg := wid
|
||||||
setAccess := 0.U
|
|
||||||
stepAccess := 0.U
|
|
||||||
when(io.writeback.fire) {
|
when(io.writeback.fire) {
|
||||||
assert(
|
assert(
|
||||||
io.writeback.bits.wid =/= wid,
|
io.writeback.bits.wid =/= wid,
|
||||||
@@ -143,55 +122,51 @@ class TensorCoreDecoupled(
|
|||||||
// serialize every HGMMA request
|
// serialize every HGMMA request
|
||||||
io.initiate.ready := !busy
|
io.initiate.ready := !busy
|
||||||
|
|
||||||
// Memory traffic generation
|
// ===========================================================================
|
||||||
// -------------------------
|
// Access stage
|
||||||
|
// ===========================================================================
|
||||||
//
|
//
|
||||||
val numTilesM = tilingParams.m / tilingParams.mc
|
// Frontend of the decoupled access/execute pipeline.
|
||||||
val numTilesN = tilingParams.n / tilingParams.nc
|
|
||||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
|
||||||
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
|
|
||||||
: (UInt/*A*/, UInt/*B*/) = {
|
|
||||||
// note that step iterates along N first, then M
|
|
||||||
val tileM = step % numTilesM.U
|
|
||||||
val tileN = step / numTilesM.U
|
|
||||||
|
|
||||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
// States
|
||||||
// accesses
|
//
|
||||||
//
|
object AccessorState extends ChiselEnum {
|
||||||
// (row,col) coordinate of the compute tile
|
val idle = Value(0.U)
|
||||||
val tileRowA = tileM // M
|
val access = Value(1.U)
|
||||||
val tileColA = set // K
|
// All set/step sequencing is complete and the tensor core is holding the
|
||||||
val tileRowB = tileN // N
|
// result data until downstream writeback is ready.
|
||||||
val tileColB = set // K
|
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||||
// (row,col) coordinate of the starting element of the compute tile
|
val finish = Value(2.U)
|
||||||
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
|
|
||||||
(substep << log2Ceil(tilingParams.mc / 2))
|
|
||||||
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
|
||||||
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
|
|
||||||
(substep << log2Ceil(tilingParams.nc / 2))
|
|
||||||
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
|
||||||
val rowStrideA = wordSize * tilingParams.k
|
|
||||||
val rowStrideABits = log2Ceil(rowStrideA)
|
|
||||||
val rowStrideB = wordSize * tilingParams.k
|
|
||||||
val rowStrideBBits = log2Ceil(rowStrideB)
|
|
||||||
val wordStrideBits = log2Ceil(wordSize)
|
|
||||||
|
|
||||||
val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
|
|
||||||
val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
|
|
||||||
|
|
||||||
(baseA + tileOffsetA, baseB + tileOffsetB)
|
|
||||||
}
|
}
|
||||||
|
val state = RegInit(AccessorState.idle)
|
||||||
|
val allReqsDone = WireInit(false.B)
|
||||||
|
dontTouch(allReqsDone)
|
||||||
|
|
||||||
// FIXME: bogus base address
|
switch(state) {
|
||||||
val (addressA, addressB) =
|
is(AccessorState.idle) {
|
||||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
when(io.initiate.fire) {
|
||||||
|
state := AccessorState.access
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(AccessorState.access) {
|
||||||
|
when (allReqsDone) {
|
||||||
|
state := AccessorState.finish
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(AccessorState.finish) {
|
||||||
|
// FIXME: decouple writeback
|
||||||
|
when(io.writeback.fire) {
|
||||||
|
state := AccessorState.idle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 'index' is the index of a memory request among the sequence of requests
|
// 'index' is the index of a memory request among the sequence of requests
|
||||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||||
// or [0,n/2), where 2 is the stride that can be read in a single request.
|
// or [0,n/2), where 2 is the stride that can be read in a single request.
|
||||||
require(tilingParams.m == tilingParams.n,
|
require(tilingParams.m == tilingParams.n,
|
||||||
"currently only supports square SMEM tile")
|
"currently only supports square SMEM tile")
|
||||||
val numIndices = tilingParams.m / 2
|
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||||
val indexBits = log2Ceil(numIndices)
|
val indexBits = log2Ceil(numIndices)
|
||||||
val lastIndex = (1 << indexBits) - 1
|
val lastIndex = (1 << indexBits) - 1
|
||||||
|
|
||||||
@@ -219,9 +194,51 @@ class TensorCoreDecoupled(
|
|||||||
tagB.index := tagB.index + 1.U
|
tagB.index := tagB.index + 1.U
|
||||||
}
|
}
|
||||||
|
|
||||||
val genReqA = (state === TensorState.run)
|
// Address generation
|
||||||
val genReqB = (state === TensorState.run)
|
//
|
||||||
|
def addressGen(base: UInt, set: UInt, index: UInt): UInt = {
|
||||||
|
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||||
|
// accesses, so that below code applies to both.
|
||||||
|
//
|
||||||
|
// (row,col) coordinate of the compute tile
|
||||||
|
val tileRow = index
|
||||||
|
val tileCol = set
|
||||||
|
// (row,col) coordinate of the starting element of the compute tile
|
||||||
|
val elemRow = index << 1
|
||||||
|
val elemCol = tileCol << log2Ceil(tilingParams.kc)
|
||||||
|
val rowStride = tilingParams.k * wordSize
|
||||||
|
val rowStrideBits = log2Ceil(rowStride)
|
||||||
|
val wordStrideBits = log2Ceil(wordSize)
|
||||||
|
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
|
||||||
|
|
||||||
|
base + tileOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: bogus base address
|
||||||
|
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
||||||
|
val addressB = addressGen(0.U, tagB.set, tagB.index)
|
||||||
|
|
||||||
|
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
||||||
|
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
|
||||||
|
val doneReqA = RegInit(false.B)
|
||||||
|
val doneReqB = RegInit(false.B)
|
||||||
|
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||||
|
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||||
|
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||||
|
val genReqB = (state === AccessorState.access) && !doneReqB // gate B on its own done flag (was doneReqA) so A and B progress independently
|
||||||
|
when (state === AccessorState.finish) {
|
||||||
|
doneReqA := false.B
|
||||||
|
doneReqB := false.B
|
||||||
|
tagA.set := 0.U
|
||||||
|
tagA.index := 0.U
|
||||||
|
tagB.set := 0.U
|
||||||
|
tagB.index := 0.U
|
||||||
|
}
|
||||||
|
|
||||||
|
allReqsDone := doneReqA && doneReqB
|
||||||
|
|
||||||
|
// Request generation
|
||||||
|
//
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
Seq((io.reqA, (io.respA, respATagged)),
|
Seq((io.reqA, (io.respA, respATagged)),
|
||||||
@@ -249,34 +266,13 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only advance to the next step if we fired mem requests for both A and B.
|
// ===========================================================================
|
||||||
// also consider that B doesn't have to be fired every time due to reuse.
|
|
||||||
// @perf: too strict? should be able to have A and B progress separately
|
|
||||||
val firedAReg = RegInit(false.B)
|
|
||||||
val firedBReg = RegInit(false.B)
|
|
||||||
when (io.reqA.fire) { firedAReg := true.B }
|
|
||||||
when (io.reqB.fire) { firedBReg := true.B }
|
|
||||||
val firedANow = io.reqA.fire
|
|
||||||
val firedBNow = io.reqB.fire
|
|
||||||
val firedA = firedAReg || firedANow
|
|
||||||
val firedB = firedBReg || firedBNow
|
|
||||||
val nextSubstepAccess = firedA && firedB
|
|
||||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
|
||||||
// clear out firedABReg every substep
|
|
||||||
when (nextSubstepAccess) {
|
|
||||||
firedAReg := false.B
|
|
||||||
firedBReg := false.B
|
|
||||||
substepAccess := substepAccess + 1.U
|
|
||||||
}
|
|
||||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
|
||||||
|
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// -------------
|
// ===========================================================================
|
||||||
|
//
|
||||||
// Backend of the decoupled access/execute pipeline.
|
// Backend of the decoupled access/execute pipeline.
|
||||||
//
|
//
|
||||||
// set and step being currently executed in the acc/ex backend
|
val respQueueDepth = 8 // FIXME: parameterize
|
||||||
|
|
||||||
val respQueueDepth = 4 // FIXME: parameterize
|
|
||||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||||
val respQueueB = Queue(respBTagged, respQueueDepth)
|
val respQueueB = Queue(respBTagged, respQueueDepth)
|
||||||
|
|
||||||
@@ -369,6 +365,7 @@ class TensorCoreDecoupled(
|
|||||||
// Operand selection
|
// Operand selection
|
||||||
//
|
//
|
||||||
// select the correct 4x4 tile from A operand buffer
|
// select the correct 4x4 tile from A operand buffer
|
||||||
|
val numTilesM = tilingParams.m / tilingParams.mc
|
||||||
val numTilesMBits = log2Ceil(numTilesM)
|
val numTilesMBits = log2Ceil(numTilesM)
|
||||||
def selectOperandA(buf: Vec[UInt]): UInt = {
|
def selectOperandA(buf: Vec[UInt]): UInt = {
|
||||||
require(buf.length == numIndices)
|
require(buf.length == numIndices)
|
||||||
@@ -383,7 +380,7 @@ class TensorCoreDecoupled(
|
|||||||
dontTouch(operandATag)
|
dontTouch(operandATag)
|
||||||
dontTouch(operandBTag)
|
dontTouch(operandBTag)
|
||||||
|
|
||||||
// Operand buffer dequeue logic
|
// Operand buffer logic
|
||||||
//
|
//
|
||||||
// hold A data until the entire set is done
|
// hold A data until the entire set is done
|
||||||
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
||||||
@@ -476,8 +473,8 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
io.writeback.bits.data := flattenedDPUOut
|
io.writeback.bits.data := flattenedDPUOut
|
||||||
|
|
||||||
// Writeback queues
|
// Writeback logic
|
||||||
// ----------------
|
//
|
||||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||||
|
|
||||||
class TensorComputeTag extends Bundle {
|
class TensorComputeTag extends Bundle {
|
||||||
@@ -530,28 +527,7 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
|
||||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||||
|
|
||||||
switch(state) {
|
|
||||||
is(TensorState.idle) {
|
|
||||||
when(io.initiate.fire) {
|
|
||||||
state := TensorState.run
|
|
||||||
}
|
|
||||||
}
|
|
||||||
is(TensorState.run) {
|
|
||||||
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
|
|
||||||
when (state === TensorState.run) {
|
|
||||||
state := TensorState.finish
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
is(TensorState.finish) {
|
|
||||||
when(io.writeback.fire) {
|
|
||||||
state := TensorState.idle
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A buffer that collects multiple entries of input data and exposes the
|
// A buffer that collects multiple entries of input data and exposes the
|
||||||
|
|||||||
Reference in New Issue
Block a user