tensor: SMEM address generation
This commit is contained in:
@@ -159,6 +159,48 @@ class TensorCoreDecoupled(
|
||||
tag.step := stepAccess
|
||||
tag.substep := substepAccess
|
||||
|
||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||
// Computes the addresses of the A and B operand tiles for a given
// (set, step, substep) position, as offsets added to the given bases.
//
//   baseA, baseB: base addresses of the A and B matrices
//   set:          compute tile index along K (shared by A and B)
//   step:         linearized compute tile index over (M, N)
//   substep:      half-tile index inside the compute tile; advances the
//                 starting row by mc/2 rows for A and by nc/2 rows for B
//   returns:      (address for A, address for B)
//
// All multiplications are done as left shifts by log2Ceil(size), so the
// tile sizes, wordSize and wordSize * k are assumed to be powers of two.
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = {
  // note that step iterates along N first, then M
  // NOTE(review): as written, tileM is the fast-varying index of step
  // (step % numComputeTilesM) and tileN the slow one; confirm this matches
  // the iteration order intended by the note above.
  val numComputeTilesM = tilingParams.m / tilingParams.mc
  val numComputeTilesN = tilingParams.n / tilingParams.nc
  val tileM = step % numComputeTilesM.U
  val tileN = step / numComputeTilesM.U
  // each substep covers half of the compute tile along M (A) or N (B)
  val mcSubstep = tilingParams.mc / 2
  val ncSubstep = tilingParams.nc / 2

  // note that both A and B are K-major to facilitate bank conflict-free SMEM
  // accesses
  //
  // (row,col) coordinate of the compute tile
  val tileRowA = tileM // M
  val tileColA = set   // K
  val tileRowB = tileN // N
  val tileColB = set   // K

  // (row,col) coordinate of the starting element of the compute tile
  val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
                 (substep << log2Ceil(mcSubstep))
  val elemColA = tileColA << log2Ceil(tilingParams.kc)
  // The substep term must be joined with an explicit `+` and both shifts
  // must be parenthesized: a line starting with `(` begins a new statement
  // in Scala (the offset would be silently dropped), and `<<` binds looser
  // than `+`.
  val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
                 (substep << log2Ceil(ncSubstep))
  val elemColB = tileColB << log2Ceil(tilingParams.kc)

  // distance between consecutive rows; rows are tilingParams.k words long
  // for both A and B since both are K-major
  val rowStrideA = wordSize * tilingParams.k
  val rowStrideABits = log2Ceil(rowStrideA)
  val rowStrideB = wordSize * tilingParams.k
  val rowStrideBBits = log2Ceil(rowStrideB)
  val wordStrideBits = log2Ceil(wordSize)

  // offset = row * rowStride + col * wordSize
  val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
  val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)

  (baseA + tileOffsetA, baseB + tileOffsetB)
}
|
||||
|
||||
// FIXME: bogus base address
|
||||
val (addressA, addressB) =
|
||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
@@ -172,9 +214,7 @@ class TensorCoreDecoupled(
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := tag
|
||||
req.valid := genReq
|
||||
// FIXME: bogus address
|
||||
// req.bits.address := (if (i == 0) 0.U else 0x100.U) // avoids bank conflict for A and B
|
||||
req.bits.address := 0.U
|
||||
req.bits.address := (if (i == 0) addressA else addressB)
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
|
||||
sourceGen.io.reclaim.valid := resp.fire
|
||||
@@ -366,7 +406,7 @@ class TensorCoreDecoupled(
|
||||
// ----------------
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
val queueDepth = 4 // needs to be at least the DPU latency
|
||||
val queueDepth = 6 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(
|
||||
chiselTypeOf(operandATag), queueDepth
|
||||
))
|
||||
@@ -397,7 +437,8 @@ class TensorCoreDecoupled(
|
||||
// TODO: decouple wid from frontend
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
|
||||
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback)
|
||||
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
|
||||
(substepWriteback === 1.U)
|
||||
|
||||
// State transition
|
||||
// ----------------
|
||||
|
||||
Reference in New Issue
Block a user