diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index f7c8547..897edb2 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -145,8 +145,6 @@ class TensorCoreDecoupled( // Memory traffic generation // ------------------------- // - val genReq = (state === TensorState.run) - class TensorMemTag extends Bundle { val set = UInt(setBits.W) val step = UInt(stepBits.W) @@ -159,16 +157,14 @@ class TensorCoreDecoupled( tag.step := stepAccess tag.substep := substepAccess + val numTilesM = tilingParams.m / tilingParams.mc + val numTilesN = tilingParams.n / tilingParams.nc // @cleanup: generalize in terms of M/N/K-majorness? def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt) : (UInt/*A*/, UInt/*B*/) = { // note that step iterates along N first, then M - val numComputeTilesM = tilingParams.m / tilingParams.mc - val numComputeTilesN = tilingParams.n / tilingParams.nc - val tileM = step % numComputeTilesM.U - val tileN = step / numComputeTilesM.U - val mcSubstep = tilingParams.mc / 2 - val ncSubstep = tilingParams.nc / 2 + val tileM = step % numTilesM.U + val tileN = step / numTilesM.U // note that both A and B are K-major to facilitate bank conflict-free SMEM // accesses @@ -180,11 +176,11 @@ class TensorCoreDecoupled( val tileColB = set // K // (row,col) coordinate of the starting element of the compute tile val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) + - (substep << log2Ceil(mcSubstep)) - val elemColA = tileColA << log2Ceil(tilingParams.kc) - val elemRowB = tileRowB << log2Ceil(tilingParams.nc) - (substep << log2Ceil(ncSubstep)) - val elemColB = tileColB << log2Ceil(tilingParams.kc) + (substep << log2Ceil(tilingParams.mc / 2)) + val elemColA = tileColA << log2Ceil(tilingParams.kc) + val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) + + (substep << log2Ceil(tilingParams.nc / 2)) + val elemColB = tileColB << log2Ceil(tilingParams.kc) val rowStrideA = wordSize * tilingParams.k val rowStrideABits = log2Ceil(rowStrideA) val rowStrideB = wordSize * tilingParams.k @@ -201,6 +197,13 @@ class TensorCoreDecoupled( val (addressA, addressB) = addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) + val genReqA = (state === TensorState.run) + val numTilesMBits = log2Ceil(numTilesM) + // generate B request at every 4 steps. B achieves reuse through outer + // product so it doesn't require access at every step + val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U + val genReqB = (state === TensorState.run) && shouldFireB + val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) Seq((io.reqA, (io.respA, respATagged)), @@ -213,7 +216,7 @@ class TensorCoreDecoupled( sourceGen.io.gen := req.fire sourceGen.io.meta := tag - req.valid := genReq + req.valid := (if (i == 0) genReqA else genReqB) req.bits.address := (if (i == 0) addressA else addressB) req.bits.source := sourceGen.io.id.bits @@ -228,23 +231,27 @@ class TensorCoreDecoupled( } } - // only advance to the next step if we fired mem requests for both A and B - // TODO: @perf: too strict? should be able to have A and B progress - // separately - val firedABReg = RegInit(VecInit(false.B, false.B)) - val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map { - case (req, fired) => { when (req.fire) { fired := true.B } } - req.fire - }) - val firedAB = (firedABNow.asUInt | firedABReg.asUInt) - val nextSubstepAccess = firedAB.andR + // only advance to the next step if we fired mem requests for both A and B. + // also consider that B doesn't have to be fired every time due to reuse. + // @perf: too strict? should be able to have A and B progress separately + val firedAReg = RegInit(false.B) + val firedBReg = RegInit(false.B) + when (io.reqA.fire) { firedAReg := true.B } + when (io.reqB.fire) { firedBReg := true.B } + val firedANow = io.reqA.fire + val firedBNow = io.reqB.fire + val firedA = firedAReg || firedANow + val firedB = firedBReg || firedBNow + val nextSubstepAccess = firedA && (!shouldFireB || firedB) val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) // clear out firedABReg every substep when (nextSubstepAccess) { - firedABReg := Seq(false.B, false.B) + firedAReg := false.B + firedBReg := false.B substepAccess := substepAccess + 1.U } require(substepAccess.widthOption.get == 1, "there should be only two substeps") + dontTouch(shouldFireB) // Execute stage // ------------- @@ -327,18 +334,26 @@ class TensorCoreDecoupled( respQueueA.ready := MuxCase(false.B, Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready, (substepExecute === 1.U) -> fullAQueue.io.enq.ready)) - respQueueB.ready := dpuFire + // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when + // we fully iterated a column (M-dimension). + val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U + val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask + respQueueB.ready := dpuFire && shouldDequeueB dontTouch(respQueueA) dontTouch(respQueueB) + dontTouch(shouldDequeueB) - // assert that the DPU is computing with operands of the same set/step + // Assert that the DPU is computing with operands of the same set/step. Note + // that the B resp will only have step values multiple of 4 due to reuse. // - // this assumes that memory responses come back in-order. this might be too - // strong an assumption depending on the backing memory + // This check assumes that memory responses come back in-order. Might be too + // strong of an assumption depending on the backing memory. def assertAligned = { + val stepMask = (1 << numTilesMBits).U when (dpuFire) { assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) && - (fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step), + ((fullAQueue.io.deq.bits.tag.step & stepMask) === + (respQueueB.bits.tag.step & stepMask)), "A and B operands are pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.") } @@ -348,26 +363,26 @@ class TensorCoreDecoupled( // Dot-product unit // // 4x2 four-element DPUs summing up to 32 MACs in total - val dpus = Seq.fill(4)(Seq.fill(2)( + val ncSubstep = tilingParams.nc / 2 + val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( Module(new TensorDotProductUnit(half = false)) )) // operandA is 4x4 in K-major val operandADimensional = operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - .grouped(4).toSeq - assert(operandADimensional.length == tilingParams.mc && - operandADimensional(0).length == tilingParams.kc, - "operand width doesn't agree with tiling parameter") - // operandB is 2x4, i.e. 4x2 in N-major + .grouped(4/*k-dim*/).toSeq + require(operandADimensional.length == tilingParams.mc && + operandADimensional(0).length == tilingParams.kc, + "operand width doesn't agree with tiling parameter") + // operandB is 2x4 in K-major val operandBDimensional = operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - .grouped(4).toSeq - val ncSubstep = tilingParams.nc / 2 - assert(tilingParams.mc * ncSubstep == numLanes, - "substep tile size doesn't match writeback throughput") - assert(operandBDimensional.length == ncSubstep && - operandBDimensional(0).length == tilingParams.kc, - "operand width doesn't agree with tiling parameter") + .grouped(4/*k-dim*/).toSeq + require(tilingParams.mc * ncSubstep == numLanes, + "substep tile size doesn't match writeback throughput") + require(operandBDimensional.length == ncSubstep && + operandBDimensional(0).length == tilingParams.kc, + "operand width doesn't agree with tiling parameter") for (m <- 0 until tilingParams.mc) { for (n <- 0 until ncSubstep) { @@ -406,10 +421,8 @@ class TensorCoreDecoupled( // ---------------- // These queues hold metadata needed for writeback in sync with the DPU. - val queueDepth = 6 // needs to be at least the DPU latency - val tagQueue = Module(new Queue( - chiselTypeOf(operandATag), queueDepth - )) + val queueDepth = 5 // needs to be at least the DPU latency + val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth)) tagQueue.io.enq.valid := dpuFire // A and B should have the same tags tagQueue.io.enq.bits := operandATag @@ -573,11 +586,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule { val tensor = LazyModule(new TensorCoreDecoupledTL) val xbar = LazyModule(new TLXbar) val ramA = LazyModule(new TLRAM( - address = AddressSet(0x000, 0xfffeff), + address = AddressSet(0x000, 0xfffbff), beatBytes = 32 // @cleanup: hardcoded )) val ramB = LazyModule(new TLRAM( - address = AddressSet(0x100, 0xfffeff), + address = AddressSet(0x400, 0xfffbff), beatBytes = 32 // @cleanup: hardcoded ))