diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index ff7f94c..ae763c6 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -85,7 +85,7 @@ class TensorCoreDecoupled( val busy = RegInit(false.B) // Holds the warp id the core is currently working on. Note that we only // support one outstanding warp request - val warpReg = RegInit(0.U(numWarpBits.W)) + val warpAccess = RegInit(0.U(numWarpBits.W)) // sets: k iteration val numSets = (tilingParams.k / tilingParams.kc) @@ -101,7 +101,7 @@ class TensorCoreDecoupled( when (io.initiate.fire) { val wid = io.initiate.bits.wid busy := true.B - warpReg := wid + warpAccess := wid when(io.writeback.fire) { assert( io.writeback.bits.wid =/= wid, @@ -170,28 +170,35 @@ class TensorCoreDecoupled( val indexBits = log2Ceil(numIndices) val lastIndex = (1 << indexBits) - 1 + class State extends Bundle { + val set = UInt(setBits.W) + val index = UInt(indexBits.W) + } class TensorMemTag extends Bundle { + val warp = UInt(numWarpBits.W) val set = UInt(setBits.W) val index = UInt(indexBits.W) } - val tagInit = Wire(new TensorMemTag) - tagInit.set := 0.U - tagInit.index := 0.U - val tagA = RegInit(tagInit) - val tagB = RegInit(tagInit) + val stateInit = Wire(new State) + stateInit.set := 0.U + stateInit.index := 0.U + val stateA = RegInit(stateInit) + val stateB = RegInit(stateInit) + dontTouch(stateA) + dontTouch(stateB) when (io.reqA.fire) { - when (tagA.index === lastIndex.U) { - tagA.set := tagA.set + 1.U + when (stateA.index === lastIndex.U) { + stateA.set := stateA.set + 1.U } - tagA.index := tagA.index + 1.U + stateA.index := stateA.index + 1.U } when (io.reqB.fire) { - when (tagB.index === lastIndex.U) { - tagB.set := tagB.set + 1.U + when (stateB.index === lastIndex.U) { + stateB.set := stateB.set + 1.U } - tagB.index := tagB.index + 1.U + stateB.index := stateB.index + 1.U } // Address generation @@ -222,12 +229,12 @@ class TensorCoreDecoupled( } // FIXME: bogus base address - val addressA = addressGen(0.U, tagA.set, tagA.index) + val addressA = addressGen(0.U, stateA.set, stateA.index) // SMEM 256KB, 8 banks: 0x8000B(32KB) per bank - val addressB = addressGen(0x400.U, tagB.set, tagB.index) + val addressB = addressGen(0x8000.U, stateB.set, stateB.index) - val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U) - val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U) + val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) + val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) val doneReqA = RegInit(false.B) val doneReqB = RegInit(false.B) when (lastReqA && io.reqA.fire) { doneReqA := true.B } @@ -237,16 +244,25 @@ class TensorCoreDecoupled( when (state === AccessorState.finish) { doneReqA := false.B doneReqB := false.B - tagA.set := 0.U - tagA.index := 0.U - tagB.set := 0.U - tagB.index := 0.U + stateA.set := 0.U + stateA.index := 0.U + stateB.set := 0.U + stateB.index := 0.U } allReqsDone := doneReqA && doneReqB // Request generation // + val tagA = Wire(new TensorMemTag) + tagA.warp := warpAccess + tagA.set := stateA.set + tagA.index := stateA.index + val tagB = Wire(new TensorMemTag) + tagB.warp := warpAccess + tagB.set := stateB.set + tagB.index := stateB.index + val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) Seq((io.reqA, (io.respA, respATagged)), @@ -422,9 +438,12 @@ class TensorCoreDecoupled( def assertAligned = { val stepMask = (1 << numTilesMBits).U when (dpuFire) { - assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set, - "A and B operands are pointing to different sets. " ++ + assert(operandATag.warp === operandBTag.warp && + operandATag.set === operandBTag.set, + "A and B operands are pointing to different warps and sets. " ++ "This might indicate memory response coming back out-of-order.") + assert(operandATag.set === setCompute, + "Operand arrived from memory is pointing at a different set than the FSM.") } } assertAligned @@ -492,6 +511,7 @@ class TensorCoreDecoupled( // These queues hold metadata needed for writeback in sync with the DPU. class TensorComputeTag extends Bundle { + val warp = UInt(numWarpBits.W) val set = UInt(setBits.W) val step = UInt(stepBits.W) val substep = UInt(1.W) @@ -500,6 +520,7 @@ class TensorCoreDecoupled( val queueDepth = 5 // needs to be at least the DPU latency val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth)) tagQueue.io.enq.valid := dpuFire + tagQueue.io.enq.bits.warp := operandATag.warp tagQueue.io.enq.bits.set := setCompute tagQueue.io.enq.bits.step := stepCompute tagQueue.io.enq.bits.substep := substepCompute @@ -518,12 +539,12 @@ class TensorCoreDecoupled( (step << 1/*2 substeps*/) + substep } + val warpWriteback = tagQueue.io.deq.bits.warp val setWriteback = tagQueue.io.deq.bits.set val stepWriteback = tagQueue.io.deq.bits.step val substepWriteback = tagQueue.io.deq.bits.substep io.writeback.valid := dpuValid - // TODO: decouple wid from frontend - io.writeback.bits.wid := warpReg + io.writeback.bits.wid := warpWriteback io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback) io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) && (substepWriteback === 1.U) @@ -685,7 +706,7 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule { ramA.node := stutter := xbar.node ramB.node := xbar.node - val fuzz = true + val fuzz = false lazy val module = new Impl class Impl extends LazyModuleImp(this) with UnitTestModule {