tensor: Address gen for block-wise contiguous layout

Necessary to meet 32B-alignment requirement for SMEM.
This commit is contained in:
Hansung Kim
2024-10-22 17:09:21 -07:00
parent 54ce0f7c34
commit b566748bcb

View File

@@ -200,22 +200,30 @@ class TensorCoreDecoupled(
// note that both A and B are K-major to facilitate bank conflict-free SMEM // note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses, so that below code applies to both. // accesses, so that below code applies to both.
// //
// (row,col) coordinate of the compute tile // a "block" is the 4*8 byte-sized contiguous memory that can be read in
val tileRow = index // one SMEM request. The A and B matrices are assumed to be stored in
val tileCol = set // block-wise "index"-major order (M-major for A, N-major for B)
// (row,col) coordinate of the starting element of the compute tile val blockRow = set
val elemRow = index << 1 val blockCol = index
val elemCol = tileCol << log2Ceil(tilingParams.kc) val blockIndex = (blockRow << indexBits) + blockCol
val rowStride = tilingParams.k * wordSize val blockSize = numLanes * wordSize
val rowStrideBits = log2Ceil(rowStride) val blockSizeBits = log2Ceil(blockSize)
val wordStrideBits = log2Ceil(wordSize) val byteOffset = blockIndex << blockSizeBits
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits) base + byteOffset
base + tileOffset // address generation for byte-wise K-major A and B layout
// val elemRow = blockRow << 1
// val elemCol = blockCol << log2Ceil(tilingParams.kc)
// val rowStride = tilingParams.k * wordSize
// val rowStrideBits = log2Ceil(rowStride)
// val wordStrideBits = log2Ceil(wordSize)
// val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
// base + tileOffset
} }
// FIXME: bogus base address // FIXME: bogus base address
val addressA = addressGen(0.U, tagA.set, tagA.index) val addressA = addressGen(0.U, tagA.set, tagA.index)
// SMEM 256KB, 8 banks: 0x8000 bytes (32KB) per bank
val addressB = addressGen(0x400.U, tagB.set, tagB.index) val addressB = addressGen(0x400.U, tagB.set, tagB.index)
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U) val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)