diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 4744266..935ed40 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -42,6 +42,7 @@ class TensorCoreDecoupled( val reqA = Decoupled(new TensorMemReq) val reqB = Decoupled(new TensorMemReq) }) + dontTouch(io) // FSM // --- @@ -62,48 +63,70 @@ class TensorCoreDecoupled( // support one outstanding warp request val warpReg = RegInit(0.U(numWarpBits.W)) - // TODO: just transition every cycle for now - def nextState(state: TensorState.Type) = state match { - case TensorState.idle => Mux(io.initiate.fire, TensorState.run, state) - case TensorState.run => TensorState.finish - case TensorState.finish => { - // hold until writeback is cleared - Mux(io.writeback.ready, TensorState.idle, state) - } - case _ => TensorState.idle - } - state := nextState(state) - - // state table for every warp id // sets: k iteration val numSets = (tilingParams.k / tilingParams.kc) val setBits = log2Ceil(numSets) // steps: i-j iteration val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc) val stepBits = log2Ceil(numSteps) - val setReg = RegInit(0.U(setBits.W)) - val stepReg = RegInit(0.U(setBits.W)) - // val tableRow = Valid(new Bundle { - // val set = UInt(setBits.W) - // val step = UInt(stepBits.W) - // }) + val set = RegInit(0.U(setBits.W)) + val step = RegInit(0.U(stepBits.W)) when(io.initiate.fire) { val wid = io.initiate.bits.wid busy := true.B warpReg := wid - setReg := 0.U - stepReg := 0.U + set := 0.U + step := 0.U when(io.writeback.fire) { - assert(io.writeback.bits.wid =/= wid, - "unsupported concurrent initiate and writeback to the same warp") + assert( + io.writeback.bits.wid =/= wid, + "unsupported concurrent initiate and writeback to the same warp" + ) } } - when (io.writeback.fire) { + when(io.writeback.fire) { busy := false.B } + // set/step sequencing logic + val nextStep = true.B // TODO + val lastSet = ((1 << setBits) - 1) + val lastStep = ((1 << stepBits) - 1) + val setDone = (set === lastSet.U) + val stepDone = (step === lastStep.U) + when (nextStep) { + step := (step + 1.U) & lastStep.U + when (stepDone) { + set := (set + 1.U) & lastSet.U + } + } + + // state transition logic + switch(state) { + is(TensorState.idle) { + when(io.initiate.fire) { + state := TensorState.run + } + } + is(TensorState.run) { + when (setDone && stepDone && nextStep) { + when (state === TensorState.run) { + state := TensorState.finish + } + } + } + is(TensorState.finish) { + when(io.writeback.fire) { + state := TensorState.idle + } + } + } + io.initiate.ready := !busy + io.writeback.valid := (state === TensorState.finish) + io.writeback.bits.wid := warpReg + io.writeback.bits.last := false.B // TODO // Writeback queues // ---------------- @@ -114,13 +137,6 @@ class TensorCoreDecoupled( // val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) // val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) - // Output logic - // ------------ - - io.writeback.valid := (state === TensorState.finish) - io.writeback.bits.wid := warpReg - io.writeback.bits.last := false.B // TODO - // FIXME io.respA.ready := true.B io.respB.ready := true.B