diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 040106c..d91f49f 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -205,24 +205,11 @@ class TensorCoreDecoupled( // SMEM 256KB, 8 banks: 0x8000B(32KB) per bank val addressB = addressGen(0x8000.U, stateB.set, stateB.index) - val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) - val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) val doneReqA = RegInit(false.B) val doneReqB = RegInit(false.B) - when (lastReqA && io.reqA.fire) { doneReqA := true.B } - when (lastReqB && io.reqB.fire) { doneReqB := true.B } + val doneReqC = RegInit(false.B) val genReqA = (state === AccessorState.access) && !doneReqA val genReqB = (state === AccessorState.access) && !doneReqB - when (state === AccessorState.finish) { - doneReqA := false.B - doneReqB := false.B - stateA.set := 0.U - stateA.index := 0.U - stateB.set := 0.U - stateB.index := 0.U - } - - allReqsDone := doneReqA && doneReqB // Request generation // @@ -288,7 +275,7 @@ class TensorCoreDecoupled( respQueueC.io.count <= Mux(respQueueC.io.deq.fire, (respQueueC.entries - 1).U, (respQueueC.entries - 2).U) - val genReqC = (state === AccessorState.access) && hasSpace + val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC // 1-cycle delay respCValid := genReqC @@ -298,27 +285,26 @@ class TensorCoreDecoupled( // thread (step << 1/*2 substeps*/) + substep } - io.reqC.valid := genReqC io.reqC.bits := 5.U // FIXME // set/index state of the C accumulator value that will be latched ath the // next cycle. - val stateRegC = RegInit(stateInit) + val stateC = RegInit(stateInit) when (genReqC) { - when (stateRegC.index === lastIndex.U) { - stateRegC.set := stateRegC.set + 1.U + when (stateC.index === lastIndex.U) { + stateC.set := stateC.set + 1.U } - stateRegC.index := stateRegC.index + 1.U + stateC.index := stateC.index + 1.U } - dontTouch(stateRegC) + dontTouch(stateC) // queue the regfile response to buffers // these strictly belong to the execute stage respQueueC.io.enq.valid := respCValid respQueueC.io.enq.bits.tag.warp := warpAccess - respQueueC.io.enq.bits.tag.set := stateRegC.set - respQueueC.io.enq.bits.tag.index := stateRegC.index + respQueueC.io.enq.bits.tag.set := stateC.set + respQueueC.io.enq.bits.tag.index := stateC.index respQueueC.io.enq.bits.data := io.respC // serialize every two C responses into one full 4x4 C tile @@ -334,6 +320,27 @@ class TensorCoreDecoupled( fullCTag.io.enq.valid := respQueueC.io.deq.valid fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag + // finalize state when everything has been accessed + val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) + val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) + val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U) + when (lastReqA && io.reqA.fire) { doneReqA := true.B } + when (lastReqB && io.reqB.fire) { doneReqB := true.B } + when (lastReqC && genReqC) { doneReqC := true.B } + when (state === AccessorState.finish) { + doneReqA := false.B + doneReqB := false.B + doneReqC := false.B + stateA.set := 0.U + stateA.index := 0.U + stateB.set := 0.U + stateB.index := 0.U + stateC.set := 0.U + stateC.index := 0.U + } + + allReqsDone := doneReqA && doneReqB && doneReqC + // =========================================================================== // Execute stage // ===========================================================================