tensor: Address gen for block-wise contiguous layout
Necessary to meet 32B-alignment requirement for SMEM.
This commit is contained in:
@@ -200,22 +200,30 @@ class TensorCoreDecoupled(
|
|||||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||||
// accesses, so that below code applies to both.
|
// accesses, so that below code applies to both.
|
||||||
//
|
//
|
||||||
// (row,col) coordinate of the compute tile
|
// a "block" is the 4*8 byte-sized contiguous memory that can be read in
|
||||||
val tileRow = index
|
// one SMEM request. The A and B matrix is assumed to be stored in
|
||||||
val tileCol = set
|
// block-wise "index"-major order (M-major for A, N-major for B)
|
||||||
// (row,col) coordinate of the starting element of the compute tile
|
val blockRow = set
|
||||||
val elemRow = index << 1
|
val blockCol = index
|
||||||
val elemCol = tileCol << log2Ceil(tilingParams.kc)
|
val blockIndex = (blockRow << indexBits) + blockCol
|
||||||
val rowStride = tilingParams.k * wordSize
|
val blockSize = numLanes * wordSize
|
||||||
val rowStrideBits = log2Ceil(rowStride)
|
val blockSizeBits = log2Ceil(blockSize)
|
||||||
val wordStrideBits = log2Ceil(wordSize)
|
val byteOffset = blockIndex << blockSizeBits
|
||||||
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
|
base + byteOffset
|
||||||
|
|
||||||
base + tileOffset
|
// address generation for byte-wise K-major A and B layout
|
||||||
|
// val elemRow = blockRow << 1
|
||||||
|
// val elemCol = blockCol << log2Ceil(tilingParams.kc)
|
||||||
|
// val rowStride = tilingParams.k * wordSize
|
||||||
|
// val rowStrideBits = log2Ceil(rowStride)
|
||||||
|
// val wordStrideBits = log2Ceil(wordSize)
|
||||||
|
// val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
|
||||||
|
// base + tileOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: bogus base address
|
// FIXME: bogus base address
|
||||||
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
||||||
|
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
||||||
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
|
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
|
||||||
|
|
||||||
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
||||||
|
|||||||
Reference in New Issue
Block a user