tensor: Decouple A and B access states

Get rid of set/stepAccess states and let A and B access progress
independently.
This commit is contained in:
Hansung Kim
2024-10-18 22:42:41 -07:00
parent c0292dd0aa
commit 0aadc6074a

View File

@@ -82,15 +82,6 @@ class TensorCoreDecoupled(
// This drives the overall pipeline of memory requests, dot-product unit // This drives the overall pipeline of memory requests, dot-product unit
// operations and regfile writeback. // operations and regfile writeback.
object TensorState extends ChiselEnum {
val idle = Value(0.U)
val run = Value(1.U)
// All set/step sequencing is complete and the tensor core is holding the
// result data until downstream writeback is ready.
// FIXME: is this necessary if writeback is decoupled with queues?
val finish = Value(2.U)
}
val state = RegInit(TensorState.idle)
val busy = RegInit(false.B) val busy = RegInit(false.B)
// Holds the warp id the core is currently working on. Note that we only // Holds the warp id the core is currently working on. Note that we only
// support one outstanding warp request // support one outstanding warp request
@@ -107,22 +98,10 @@ class TensorCoreDecoupled(
def setDone(set: UInt) = (set === lastSet.U) def setDone(set: UInt) = (set === lastSet.U)
def stepDone(step: UInt) = (step === lastStep.U) def stepDone(step: UInt) = (step === lastStep.U)
// set and step being currently accessed in the acc/ex frontend when (io.initiate.fire) {
val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W))
// we need full 4x4 A tile to fire DPU, but since the memory width is 8
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
val substepAccess = RegInit(0.U(1.W))
dontTouch(setAccess)
dontTouch(stepAccess)
dontTouch(substepAccess)
when(io.initiate.fire) {
val wid = io.initiate.bits.wid val wid = io.initiate.bits.wid
busy := true.B busy := true.B
warpReg := wid warpReg := wid
setAccess := 0.U
stepAccess := 0.U
when(io.writeback.fire) { when(io.writeback.fire) {
assert( assert(
io.writeback.bits.wid =/= wid, io.writeback.bits.wid =/= wid,
@@ -143,55 +122,51 @@ class TensorCoreDecoupled(
// serialize every HGMMA request // serialize every HGMMA request
io.initiate.ready := !busy io.initiate.ready := !busy
// Memory traffic generation // ===========================================================================
// ------------------------- // Access stage
// ===========================================================================
// //
val numTilesM = tilingParams.m / tilingParams.mc // Frontend of the decoupled access/execute pipeline.
val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness?
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = {
// note that step iterates along N first, then M
val tileM = step % numTilesM.U
val tileN = step / numTilesM.U
// note that both A and B are K-major to facilitate bank conflict-free SMEM // States
// accesses //
// object AccessorState extends ChiselEnum {
// (row,col) coordinate of the compute tile val idle = Value(0.U)
val tileRowA = tileM // M val access = Value(1.U)
val tileColA = set // K // All set/step sequencing is complete and the tensor core is holding the
val tileRowB = tileN // N // result data until downstream writeback is ready.
val tileColB = set // K // FIXME: is this necessary if writeback is decoupled with queues?
// (row,col) coordinate of the starting element of the compute tile val finish = Value(2.U)
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
(substep << log2Ceil(tilingParams.mc / 2))
val elemColA = tileColA << log2Ceil(tilingParams.kc)
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
(substep << log2Ceil(tilingParams.nc / 2))
val elemColB = tileColB << log2Ceil(tilingParams.kc)
val rowStrideA = wordSize * tilingParams.k
val rowStrideABits = log2Ceil(rowStrideA)
val rowStrideB = wordSize * tilingParams.k
val rowStrideBBits = log2Ceil(rowStrideB)
val wordStrideBits = log2Ceil(wordSize)
val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
(baseA + tileOffsetA, baseB + tileOffsetB)
} }
val state = RegInit(AccessorState.idle)
val allReqsDone = WireInit(false.B)
dontTouch(allReqsDone)
// FIXME: bogus base address switch(state) {
val (addressA, addressB) = is(AccessorState.idle) {
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) when(io.initiate.fire) {
state := AccessorState.access
}
}
is(AccessorState.access) {
when (allReqsDone) {
state := AccessorState.finish
}
}
is(AccessorState.finish) {
// FIXME: decouple writeback
when(io.writeback.fire) {
state := AccessorState.idle
}
}
}
// 'index' is the index of a memory request among the sequence of requests // 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2) // needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride can be read in a single request size. // or [0,n/2), where 2 is the stride can be read in a single request size.
require(tilingParams.m == tilingParams.n, require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile") "currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2 val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
val indexBits = log2Ceil(numIndices) val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1 val lastIndex = (1 << indexBits) - 1
@@ -219,9 +194,51 @@ class TensorCoreDecoupled(
tagB.index := tagB.index + 1.U tagB.index := tagB.index + 1.U
} }
val genReqA = (state === TensorState.run) // Address generation
val genReqB = (state === TensorState.run) //
def addressGen(base: UInt, set: UInt, index: UInt): UInt = {
// note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses, so that below code applies to both.
//
// (row,col) coordinate of the compute tile
val tileRow = index
val tileCol = set
// (row,col) coordinate of the starting element of the compute tile
val elemRow = index << 1
val elemCol = tileCol << log2Ceil(tilingParams.kc)
val rowStride = tilingParams.k * wordSize
val rowStrideBits = log2Ceil(rowStride)
val wordStrideBits = log2Ceil(wordSize)
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
base + tileOffset
}
// FIXME: bogus base address
val addressA = addressGen(0.U, tagA.set, tagA.index)
val addressB = addressGen(0.U, tagB.set, tagB.index)
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
val genReqA = (state === AccessorState.access) && !doneReqA
val genReqB = (state === AccessorState.access) && !doneReqA
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
tagA.set := 0.U
tagA.index := 0.U
tagB.set := 0.U
tagB.index := 0.U
}
allReqsDone := doneReqA && doneReqB
// Request generation
//
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)), Seq((io.reqA, (io.respA, respATagged)),
@@ -249,34 +266,13 @@ class TensorCoreDecoupled(
} }
} }
// only advance to the next step if we fired mem requests for both A and B. // ===========================================================================
// also consider that B doesn't have to be fired every time due to reuse.
// @perf: too strict? should be able to have A and B progress separately
val firedAReg = RegInit(false.B)
val firedBReg = RegInit(false.B)
when (io.reqA.fire) { firedAReg := true.B }
when (io.reqB.fire) { firedBReg := true.B }
val firedANow = io.reqA.fire
val firedBNow = io.reqB.fire
val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && firedB
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep
when (nextSubstepAccess) {
firedAReg := false.B
firedBReg := false.B
substepAccess := substepAccess + 1.U
}
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
// Execute stage // Execute stage
// ------------- // ===========================================================================
//
// Backend of the decoupled access/execute pipeline. // Backend of the decoupled access/execute pipeline.
// //
// set and step being currently executed in the acc/ex backend val respQueueDepth = 8 // FIXME: parameterize
val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth) val respQueueA = Queue(respATagged, respQueueDepth)
val respQueueB = Queue(respBTagged, respQueueDepth) val respQueueB = Queue(respBTagged, respQueueDepth)
@@ -369,6 +365,7 @@ class TensorCoreDecoupled(
// Operand selection // Operand selection
// //
// select the correct 4x4 tile from A operand buffer // select the correct 4x4 tile from A operand buffer
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesMBits = log2Ceil(numTilesM) val numTilesMBits = log2Ceil(numTilesM)
def selectOperandA(buf: Vec[UInt]): UInt = { def selectOperandA(buf: Vec[UInt]): UInt = {
require(buf.length == numIndices) require(buf.length == numIndices)
@@ -383,7 +380,7 @@ class TensorCoreDecoupled(
dontTouch(operandATag) dontTouch(operandATag)
dontTouch(operandBTag) dontTouch(operandBTag)
// Operand buffer dequeue logic // Operand buffer logic
// //
// hold A data until the entire set is done // hold A data until the entire set is done
val shouldDequeueAMask = ((1 << stepBits) - 1).U val shouldDequeueAMask = ((1 << stepBits) - 1).U
@@ -476,8 +473,8 @@ class TensorCoreDecoupled(
} }
io.writeback.bits.data := flattenedDPUOut io.writeback.bits.data := flattenedDPUOut
// Writeback queues // Writeback logic
// ---------------- //
// These queues hold metadata needed for writeback in sync with the DPU. // These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle { class TensorComputeTag extends Bundle {
@@ -530,28 +527,7 @@ class TensorCoreDecoupled(
} }
} }
} }
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setCompute, stepCompute, nextStepCompute) sequenceSetStep(setCompute, stepCompute, nextStepCompute)
switch(state) {
is(TensorState.idle) {
when(io.initiate.fire) {
state := TensorState.run
}
}
is(TensorState.run) {
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
when (state === TensorState.run) {
state := TensorState.finish
}
}
}
is(TensorState.finish) {
when(io.writeback.fire) {
state := TensorState.idle
}
}
}
} }
// A buffer that collects multiple entries of input data and exposes the // A buffer that collects multiple entries of input data and exposes the