From 2a8c488d282ebc118bf1476c597dd6e8640d100a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 22 Oct 2024 23:10:11 -0700 Subject: [PATCH] tensor: Reassert initiate.ready as soon as access ready --- .../radiance/core/TensorCoreDecoupled.scala | 116 +++++++----------- 1 file changed, 44 insertions(+), 72 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index ae763c6..c42dc29 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -69,6 +69,11 @@ class TensorCoreDecoupled( val source = UInt(sourceWidth.W) val data = UInt(dataWidth.W) } + class TensorMemTag extends Bundle { + val warp = UInt(numWarpBits.W) + val set = UInt(setBits.W) + val index = UInt(indexBits.W) + } // mem response after translation from TL source to set/step tag class TensorMemRespWithTag( dataWidth: Int @@ -77,15 +82,11 @@ class TensorCoreDecoupled( val data = UInt(dataWidth.W) } - // FSM - // --- - // This drives the overall pipeline of memory requests, dot-product unit - // operations and regfile writeback. - - val busy = RegInit(false.B) - // Holds the warp id the core is currently working on. Note that we only - // support one outstanding warp request - val warpAccess = RegInit(0.U(numWarpBits.W)) + // =========================================================================== + // Access stage + // =========================================================================== + // + // Frontend of the decoupled access/execute pipeline. // sets: k iteration val numSets = (tilingParams.k / tilingParams.kc) @@ -97,39 +98,15 @@ class TensorCoreDecoupled( val lastStep = ((1 << stepBits) - 1) def setDone(set: UInt) = (set === lastSet.U) def stepDone(step: UInt) = (step === lastStep.U) + // 'index' is the index of a memory request among the sequence of requests + // needed to read a full M-column of A or N-row of B. Its range is [0,m/2) + // or [0,n/2), where 2 is the stride can be read in a single request size. + require(tilingParams.m == tilingParams.n, + "currently only supports square SMEM tile") + val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/ + val indexBits = log2Ceil(numIndices) + val lastIndex = (1 << indexBits) - 1 - when (io.initiate.fire) { - val wid = io.initiate.bits.wid - busy := true.B - warpAccess := wid - when(io.writeback.fire) { - assert( - io.writeback.bits.wid =/= wid, - "unsupported concurrent initiate and writeback to the same warp" - ) - } - } - - // TODO: @perf: Instead of waiting until the last writeback, release busy as - // soon as the access frontend is complete so that there's a better chance to - // saturate the backend with back-to-back HGMMAs. This would require sending - // the 'wid' register to backend instead of having it shared with the - // frontend. - when(io.writeback.fire && io.writeback.bits.last) { - busy := false.B - } - - // serialize every HGMMA request - io.initiate.ready := !busy - - // =========================================================================== - // Access stage - // =========================================================================== - // - // Frontend of the decoupled access/execute pipeline. - - // States - // object AccessorState extends ChiselEnum { val idle = Value(0.U) val access = Value(1.U) @@ -142,6 +119,30 @@ class TensorCoreDecoupled( val allReqsDone = WireInit(false.B) dontTouch(allReqsDone) + val warpAccess = RegInit(0.U(numWarpBits.W)) + + class BlockState extends Bundle { + val set = UInt(setBits.W) + val index = UInt(indexBits.W) + } + val stateInit = Wire(new BlockState) + stateInit.set := 0.U + stateInit.index := 0.U + val stateA = RegInit(stateInit) + val stateB = RegInit(stateInit) + dontTouch(stateA) + dontTouch(stateA.index) + dontTouch(stateB) + dontTouch(stateB.index) + + io.initiate.ready := (state === AccessorState.idle) + when (io.initiate.fire) { + warpAccess := io.initiate.bits.wid + assert(stateA.set === 0.U && stateA.index === 0.U && + stateB.set === 0.U && stateB.index === 0.U, + "stateA and stateB not initialized to zero") + } + switch(state) { is(AccessorState.idle) { when(io.initiate.fire) { @@ -154,40 +155,11 @@ class TensorCoreDecoupled( } } is(AccessorState.finish) { - // FIXME: decouple writeback - when(io.writeback.fire) { - state := AccessorState.idle - } + // FIXME: is finish state needed? + state := AccessorState.idle } } - // 'index' is the index of a memory request among the sequence of requests - // needed to read a full M-column of A or N-row of B. Its range is [0,m/2) - // or [0,n/2), where 2 is the stride can be read in a single request size. - require(tilingParams.m == tilingParams.n, - "currently only supports square SMEM tile") - val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/ - val indexBits = log2Ceil(numIndices) - val lastIndex = (1 << indexBits) - 1 - - class State extends Bundle { - val set = UInt(setBits.W) - val index = UInt(indexBits.W) - } - class TensorMemTag extends Bundle { - val warp = UInt(numWarpBits.W) - val set = UInt(setBits.W) - val index = UInt(indexBits.W) - } - - val stateInit = Wire(new State) - stateInit.set := 0.U - stateInit.index := 0.U - val stateA = RegInit(stateInit) - val stateB = RegInit(stateInit) - dontTouch(stateA) - dontTouch(stateB) - when (io.reqA.fire) { when (stateA.index === lastIndex.U) { stateA.set := stateA.set + 1.U