diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
index 3d00c35..f7c8547 100644
--- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala
+++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
@@ -159,6 +159,62 @@ class TensorCoreDecoupled(
   tag.step := stepAccess
   tag.substep := substepAccess
 
+  // @cleanup: generalize in terms of M/N/K-majorness?
+  /** Computes the addresses of the A and B operand accesses for one
+    * (set, step, substep) point of the tile iteration.
+    *
+    * @param baseA   base address added to the A tile offset
+    * @param baseB   base address added to the B tile offset
+    * @param set     compute-tile index along K (column coordinate of both tiles)
+    * @param step    linear compute-tile index, split into tileM/tileN below
+    * @param substep selects an mc/2-row (A) or nc/2-row (B) slice of the tile
+    * @return (address for A, address for B)
+    */
+  def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
+  : (UInt/*A*/, UInt/*B*/) = {
+    // note that step iterates along N first, then M
+    // NOTE(review): the modulo below makes tileM the fastest-varying index;
+    // confirm this agrees with the iteration order stated above.
+    val numComputeTilesM = tilingParams.m / tilingParams.mc
+    val numComputeTilesN = tilingParams.n / tilingParams.nc
+    val tileM = step % numComputeTilesM.U
+    val tileN = step / numComputeTilesM.U
+    val mcSubstep = tilingParams.mc / 2
+    val ncSubstep = tilingParams.nc / 2
+
+    // note that both A and B are K-major to facilitate bank conflict-free SMEM
+    // accesses
+    //
+    // (row,col) coordinate of the compute tile
+    val tileRowA = tileM // M
+    val tileColA = set   // K
+    val tileRowB = tileN // N
+    val tileColB = set   // K
+    // (row,col) coordinate of the starting element of the compute tile
+    // The trailing '+' is required: a continuation line that starts with '('
+    // is otherwise parsed as a separate statement, dropping the substep term.
+    val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
+                   (substep << log2Ceil(mcSubstep))
+    val elemColA = tileColA << log2Ceil(tilingParams.kc)
+    val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
+                   (substep << log2Ceil(ncSubstep))
+    val elemColB = tileColB << log2Ceil(tilingParams.kc)
+    val rowStrideA = wordSize * tilingParams.k
+    val rowStrideABits = log2Ceil(rowStrideA)
+    val rowStrideB = wordSize * tilingParams.k
+    val rowStrideBBits = log2Ceil(rowStrideB)
+    val wordStrideBits = log2Ceil(wordSize)
+
+    val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
+    val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
+
+    (baseA + tileOffsetA, baseB + tileOffsetB)
+  }
+
+  // FIXME: bogus base address
+  val (addressA, addressB) =
+    addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
+
   val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
   val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
   Seq((io.reqA, (io.respA, respATagged)),
@@ -172,9 +228,7 @@ class TensorCoreDecoupled(
     sourceGen.io.gen := req.fire
     sourceGen.io.meta := tag
     req.valid := genReq
-    // FIXME: bogus address
-    // req.bits.address := (if (i == 0) 0.U else 0x100.U) // avoids bank conflict for A and B
-    req.bits.address := 0.U
+    req.bits.address := (if (i == 0) addressA else addressB)
     req.bits.source := sourceGen.io.id.bits
 
     sourceGen.io.reclaim.valid := resp.fire
@@ -366,7 +420,7 @@ class TensorCoreDecoupled(
   // ----------------
 
   // These queues hold metadata needed for writeback in sync with the DPU.
-  val queueDepth = 4 // needs to be at least the DPU latency
+  val queueDepth = 6 // needs to be at least the DPU latency
   val tagQueue = Module(new Queue(
     chiselTypeOf(operandATag), queueDepth
   ))
@@ -397,7 +451,8 @@ class TensorCoreDecoupled(
   // TODO: decouple wid from frontend
   io.writeback.bits.wid := warpReg
   io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
-  io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback)
+  io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
+                            (substepWriteback === 1.U)
 
   // State transition
   // ----------------