tensor: Stage B as well for full throughput
This commit is contained in:
@@ -300,10 +300,13 @@ class TensorCoreDecoupled(
|
|||||||
// background while the tile is being used for compute. This does come with
|
// background while the tile is being used for compute. This does come with
|
||||||
// capacity overhead.
|
// capacity overhead.
|
||||||
val fullABuf = Module(new Queue(
|
val fullABuf = Module(new Queue(
|
||||||
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
new Bundle {
|
||||||
|
val data = chiselTypeOf(fullA.io.deq.bits)
|
||||||
|
val tag = new TensorMemTag
|
||||||
|
}, entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
fullABuf.io.enq.valid := fullA.io.deq.valid
|
fullABuf.io.enq.valid := fullA.io.deq.valid
|
||||||
fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
|
fullABuf.io.enq.bits.data := fullA.io.deq.bits
|
||||||
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
|
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
|
||||||
fullA.io.deq.ready := fullABuf.io.enq.ready
|
fullA.io.deq.ready := fullABuf.io.enq.ready
|
||||||
fullATag.io.deq.ready := fullABuf.io.enq.ready
|
fullATag.io.deq.ready := fullABuf.io.enq.ready
|
||||||
@@ -322,10 +325,22 @@ class TensorCoreDecoupled(
|
|||||||
fullBTag.io.enq.valid := respQueueB.valid
|
fullBTag.io.enq.valid := respQueueB.valid
|
||||||
fullBTag.io.enq.bits := respQueueB.bits.tag
|
fullBTag.io.enq.bits := respQueueB.bits.tag
|
||||||
|
|
||||||
val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
|
val fullBBuf = Module(new Queue(
|
||||||
|
new Bundle {
|
||||||
|
val data = chiselTypeOf(fullB.io.deq.bits)
|
||||||
|
val tag = new TensorMemTag
|
||||||
|
}, entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullBBuf.io.enq.valid := fullB.io.deq.valid
|
||||||
|
fullBBuf.io.enq.bits.data := fullB.io.deq.bits
|
||||||
|
fullBBuf.io.enq.bits.tag := fullBTag.io.deq.bits
|
||||||
|
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
|
||||||
|
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||||
val operandA = fullABuf.io.deq.bits.data
|
val operandA = fullABuf.io.deq.bits.data
|
||||||
val operandATag = fullABuf.io.deq.bits.tag
|
val operandATag = fullABuf.io.deq.bits.tag
|
||||||
val operandB = fullB.io.deq.bits
|
val operandB = fullBBuf.io.deq.bits.data
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
val setCompute = fullABuf.io.deq.bits.tag.set
|
val setCompute = fullABuf.io.deq.bits.tag.set
|
||||||
@@ -335,20 +350,19 @@ class TensorCoreDecoupled(
|
|||||||
substepCompute := substepCompute + 1.U
|
substepCompute := substepCompute + 1.U
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hold full A until two-cycle compute is done
|
||||||
|
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||||
// we fully iterated a column (M-dimension).
|
// we fully iterated a column (M-dimension).
|
||||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||||
val shouldDequeueB =
|
val shouldDequeueB =
|
||||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||||
(substepCompute === 1.U)
|
(substepCompute === 1.U)
|
||||||
fullB.io.deq.ready := dpuFire && shouldDequeueB
|
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
|
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
dontTouch(shouldDequeueB)
|
dontTouch(shouldDequeueB)
|
||||||
|
|
||||||
// hold full A until two-cycle compute is done
|
|
||||||
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
|
||||||
// FIXME: this should be nextStepCompute
|
// FIXME: this should be nextStepCompute
|
||||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||||
|
|
||||||
@@ -360,9 +374,9 @@ class TensorCoreDecoupled(
|
|||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
val stepMask = (1 << numTilesMBits).U
|
val stepMask = (1 << numTilesMBits).U
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) &&
|
||||||
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
||||||
(fullBTag.io.deq.bits.step & stepMask)),
|
(fullBBuf.io.deq.bits.tag.step & stepMask)),
|
||||||
"A and B operands are pointing to different set/steps. " ++
|
"A and B operands are pointing to different set/steps. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
}
|
}
|
||||||
@@ -378,15 +392,12 @@ class TensorCoreDecoupled(
|
|||||||
))
|
))
|
||||||
// operandA is 4x4 in K-major
|
// operandA is 4x4 in K-major
|
||||||
val operandADimensional =
|
val operandADimensional =
|
||||||
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
.grouped(4/*k-dim*/).toSeq
|
.grouped(4/*k-dim*/).toSeq
|
||||||
require(operandADimensional.length == tilingParams.mc &&
|
require(operandADimensional.length == tilingParams.mc &&
|
||||||
operandADimensional(0).length == tilingParams.kc,
|
operandADimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
// operandB is 2x4 in K-major
|
// select 2x4 subtile out of operandB that is 4x4 in K-major
|
||||||
// val operandBDimensional =
|
|
||||||
// operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
|
||||||
// .grouped(4/*k-dim*/).toSeq
|
|
||||||
val operandBDimensional =
|
val operandBDimensional =
|
||||||
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
.grouped(4/*k-dim*/).toSeq
|
.grouped(4/*k-dim*/).toSeq
|
||||||
|
|||||||
Reference in New Issue
Block a user