tensor: SMEM address generation

This commit is contained in:
Hansung Kim
2024-10-17 16:36:18 -07:00
parent 2741af0b2b
commit a2519da58f

View File

@@ -159,6 +159,48 @@ class TensorCoreDecoupled(
tag.step := stepAccess
tag.substep := substepAccess
// @cleanup: generalize in terms of M/N/K-majorness?
/** Computes the memory addresses of the A and B operand tiles for one access.
  *
  * Both operands are addressed as K-major matrices: the row index runs along
  * M (for A) or N (for B), and the column index runs along K.  Each address
  * is `base + row * rowStride + col * wordSize`, with every multiplication
  * implemented as a left shift by `log2Ceil` of the factor.
  *
  * NOTE(review): the shifts equal the intended multiplications only if
  * `mc`, `nc`, `kc`, `mc/2`, `nc/2` and `wordSize * k` are all powers of two;
  * confirm against the allowed tiling parameters.
  *
  * @param baseA   base address of the A operand
  * @param baseB   base address of the B operand
  * @param set     index of the compute tile along K
  * @param step    index selecting the (M, N) compute tile
  * @param substep index of the half-tile (`mc/2` rows of A, `nc/2` rows of B)
  *                within the compute tile
  * @return        `(addressA, addressB)`
  */
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = {
  // note that step iterates along N first, then M
  // NOTE(review): the modulo/divide below make tileM the fastest-varying
  // index of `step`, which reads as M-first rather than N-first, and
  // numComputeTilesN ends up unused.  Confirm which order the step sequencing
  // actually uses before relying on the comment above.
  val numComputeTilesM = tilingParams.m / tilingParams.mc
  val numComputeTilesN = tilingParams.n / tilingParams.nc
  val tileM = step % numComputeTilesM.U
  val tileN = step / numComputeTilesM.U
  // each substep covers half of the compute tile's rows
  val mcSubstep = tilingParams.mc / 2
  val ncSubstep = tilingParams.nc / 2
  // note that both A and B are K-major to facilitate bank conflict-free SMEM
  // accesses
  //
  // (row,col) coordinate of the compute tile
  val tileRowA = tileM // M
  val tileColA = set   // K
  val tileRowB = tileN // N
  val tileColB = set   // K
  // (row,col) coordinate of the starting element of the compute tile
  val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
                 (substep << log2Ceil(mcSubstep))
  val elemColA = tileColA << log2Ceil(tilingParams.kc)
  // The trailing `+` is required: without it Scala ends the statement at the
  // line break and the substep term becomes a discarded expression, so B
  // would never advance by substep.
  val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
                 (substep << log2Ceil(ncSubstep))
  val elemColB = tileColB << log2Ceil(tilingParams.kc)
  // a full row of either operand spans all of K
  val rowStrideA = wordSize * tilingParams.k
  val rowStrideABits = log2Ceil(rowStrideA)
  val rowStrideB = wordSize * tilingParams.k
  val rowStrideBBits = log2Ceil(rowStrideB)
  val wordStrideBits = log2Ceil(wordSize)
  val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
  val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
  (baseA + tileOffsetA, baseB + tileOffsetB)
}
// FIXME: bogus base address
// Addresses for the A and B operand requests, computed by addressGen from the
// current setAccess/stepAccess/substepAccess values.  Both base addresses are
// hard-wired to 0 for now, so these are offsets from address 0.
val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
// Decoupled wires carrying the A and B responses as TensorMemRespWithTag
// payloads of dataWidth; they are driven outside the lines visible here.
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)),
@@ -172,9 +214,7 @@ class TensorCoreDecoupled(
sourceGen.io.gen := req.fire
sourceGen.io.meta := tag
req.valid := genReq
// FIXME: bogus address
// req.bits.address := (if (i == 0) 0.U else 0x100.U) // avoids bank conflict for A and B
req.bits.address := 0.U
req.bits.address := (if (i == 0) addressA else addressB)
req.bits.source := sourceGen.io.id.bits
sourceGen.io.reclaim.valid := resp.fire
@@ -366,7 +406,7 @@ class TensorCoreDecoupled(
// ----------------
// These queues hold metadata needed for writeback in sync with the DPU.
val queueDepth = 4 // needs to be at least the DPU latency
val queueDepth = 6 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(
chiselTypeOf(operandATag), queueDepth
))
@@ -397,7 +437,8 @@ class TensorCoreDecoupled(
// TODO: decouple wid from frontend
io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback)
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
(substepWriteback === 1.U)
// State transition
// ----------------