tensor: Add access logic for C from regfile

This commit is contained in:
Hansung Kim
2024-10-25 15:22:52 -07:00
parent fc5b864b86
commit 43e064fe82

View File

@@ -51,8 +51,10 @@ class TensorCoreDecoupled(
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respC = Input(UInt(dataWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
})
dontTouch(io)
@@ -131,9 +133,7 @@ class TensorCoreDecoupled(
val stateA = RegInit(stateInit)
val stateB = RegInit(stateInit)
dontTouch(stateA)
dontTouch(stateA.index)
dontTouch(stateB)
dontTouch(stateB.index)
io.initiate.ready := (state === AccessorState.idle)
when (io.initiate.fire) {
@@ -262,6 +262,48 @@ class TensorCoreDecoupled(
}
}
// C access from regfile
//
// since regfile is fixed-latency, respC valid should be determined at the
// request sending side.
val respCValid = RegInit(false.B)
// regfile latency is 1 cycle; don't need a deep response queue
val respQueueCDepth = 1
val respQueueC = Module(new Queue(
chiselTypeOf(io.respC), respQueueCDepth
))
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(io.respC), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits
respQueueC.io.deq.ready := fullC.io.enq.ready
// make sure there's space at the response queue to be latched at the next
// cycle
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
// 1-cycle delay
respCValid := genReqC
io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME
// set/index state of the C accumulator value that will be latched ath the
// next cycle.
val stateRegC = RegInit(stateInit)
when (genReqC) {
when (stateRegC.index === lastIndex.U) {
stateRegC.set := stateRegC.set + 1.U
}
stateRegC.index := stateRegC.index + 1.U
}
// ===========================================================================
// Execute stage
// ===========================================================================
@@ -349,9 +391,31 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
// fullC is instiated at the access stage
val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueB.valid
fullCTag.io.enq.bits := respQueueB.bits.tag
val fullCBuf = Module(new Queue(
new Bundle {
val data = chiselTypeOf(fullC.io.deq.bits)
val tag = new TensorMemTag
}, entries = 1, pipe = true
))
fullCBuf.io.enq.valid := fullC.io.deq.valid
fullCBuf.io.enq.bits.data := fullC.io.deq.bits
fullCBuf.io.enq.bits.tag := fullCTag.io.deq.bits
fullC.io.deq.ready := fullCBuf.io.enq.ready
fullCTag.io.deq.ready := fullCBuf.io.enq.ready
val dpuReady = Wire(Bool())
val dpuFire = Wire(Bool())
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val dpuFire = operandsValid && dpuReady
dpuFire := operandsValid && dpuReady
dontTouch(dpuFire)
val setCompute = RegInit(0.U(setBits.W))
val stepCompute = RegInit(0.U(stepBits.W))
@@ -376,11 +440,14 @@ class TensorCoreDecoupled(
}
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
val operandATag = fullABuf.io.deq.bits.tag
// select the correct 2x4 tile from B operand buffer
// select the correct 2x4 tile from B/C operand buffer
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
val operandBTag = fullBBuf.io.deq.bits.tag
val operandC = fullCBuf.io.deq.bits.data(substepCompute)
val operandCTag = fullCBuf.io.deq.bits.tag
dontTouch(operandATag)
dontTouch(operandBTag)
dontTouch(operandCTag)
// Operand buffer logic
//
@@ -397,6 +464,10 @@ class TensorCoreDecoupled(
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U)
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
// C buf should be synced with B buf
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA)
dontTouch(respQueueB)
dontTouch(shouldDequeueA)
@@ -414,6 +485,10 @@ class TensorCoreDecoupled(
operandATag.set === operandBTag.set,
"A and B operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.")
assert(operandATag.warp === operandCTag.warp &&
operandATag.set === operandCTag.set,
"A and C operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.")
assert(operandATag.set === setCompute,
"Operand arrived from memory is pointing at a different set than the FSM.")
}
@@ -422,7 +497,7 @@ class TensorCoreDecoupled(
// Dot-product unit
//
// 4x2 four-element DPUs summing up to 32 MACs in total
// 4x2 four-element DPUs summing up to 32 FP32 MACs in total
//
val ncSubstep = tilingParams.nc / 2
require(tilingParams.mc * ncSubstep == numLanes,
@@ -444,13 +519,18 @@ class TensorCoreDecoupled(
require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// note operand C is M-major
val operandCDimensional = reshapeByFourWords(operandC)
require(operandCDimensional.length == ncSubstep &&
operandCDimensional(0).length == tilingParams.mc,
"operand width doesn't agree with tiling parameter")
for (m <- 0 until tilingParams.mc) {
for (n <- 0 until ncSubstep) {
dpus(m)(n).io.in.valid := dpuFire
dpus(m)(n).io.in.bits.a := operandADimensional(m)
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
dpus(m)(n).io.in.bits.c := operandCDimensional(n)(m)
// dpu ready couples with writeback backpressure
dpus(m)(n).io.stall := !io.writeback.ready
}
@@ -631,8 +711,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.respB.bits.source := tlOutB.d.bits.source
tlOutB.d.ready := tensor.io.respB.ready
tensor.io.respC := 42.U // FIXME bogus
tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U // TODO
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last