tensor: Add access logic for C from regfile
This commit is contained in:
@@ -51,8 +51,10 @@ class TensorCoreDecoupled(
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respC = Input(UInt(dataWidth.W))
|
||||
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
||||
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
||||
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
||||
})
|
||||
dontTouch(io)
|
||||
|
||||
@@ -131,9 +133,7 @@ class TensorCoreDecoupled(
|
||||
val stateA = RegInit(stateInit)
|
||||
val stateB = RegInit(stateInit)
|
||||
dontTouch(stateA)
|
||||
dontTouch(stateA.index)
|
||||
dontTouch(stateB)
|
||||
dontTouch(stateB.index)
|
||||
|
||||
io.initiate.ready := (state === AccessorState.idle)
|
||||
when (io.initiate.fire) {
|
||||
@@ -262,6 +262,48 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
|
||||
// C access from regfile
|
||||
//
|
||||
|
||||
// since regfile is fixed-latency, respC valid should be determined at the
|
||||
// request sending side.
|
||||
val respCValid = RegInit(false.B)
|
||||
|
||||
// regfile latency is 1 cycle; don't need a deep response queue
|
||||
val respQueueCDepth = 1
|
||||
val respQueueC = Module(new Queue(
|
||||
chiselTypeOf(io.respC), respQueueCDepth
|
||||
))
|
||||
respQueueC.io.enq.valid := respCValid
|
||||
respQueueC.io.enq.bits := io.respC
|
||||
|
||||
// serialize every two C responses into one full 4x4 C tile
|
||||
val fullC = Module(new FillBuffer(
|
||||
chiselTypeOf(io.respC), 2/*substeps*/
|
||||
))
|
||||
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullC.io.enq.bits := respQueueC.io.deq.bits
|
||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||
|
||||
// make sure there's space at the response queue to be latched at the next
|
||||
// cycle
|
||||
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
|
||||
// 1-cycle delay
|
||||
respCValid := genReqC
|
||||
|
||||
io.reqC.valid := genReqC
|
||||
io.reqC.bits := 5.U // FIXME
|
||||
|
||||
// set/index state of the C accumulator value that will be latched ath the
|
||||
// next cycle.
|
||||
val stateRegC = RegInit(stateInit)
|
||||
when (genReqC) {
|
||||
when (stateRegC.index === lastIndex.U) {
|
||||
stateRegC.set := stateRegC.set + 1.U
|
||||
}
|
||||
stateRegC.index := stateRegC.index + 1.U
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Execute stage
|
||||
// ===========================================================================
|
||||
@@ -349,9 +391,31 @@ class TensorCoreDecoupled(
|
||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||
|
||||
// fullC is instiated at the access stage
|
||||
|
||||
val fullCTag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
))
|
||||
fullCTag.io.enq.valid := respQueueB.valid
|
||||
fullCTag.io.enq.bits := respQueueB.bits.tag
|
||||
|
||||
val fullCBuf = Module(new Queue(
|
||||
new Bundle {
|
||||
val data = chiselTypeOf(fullC.io.deq.bits)
|
||||
val tag = new TensorMemTag
|
||||
}, entries = 1, pipe = true
|
||||
))
|
||||
fullCBuf.io.enq.valid := fullC.io.deq.valid
|
||||
fullCBuf.io.enq.bits.data := fullC.io.deq.bits
|
||||
fullCBuf.io.enq.bits.tag := fullCTag.io.deq.bits
|
||||
fullC.io.deq.ready := fullCBuf.io.enq.ready
|
||||
fullCTag.io.deq.ready := fullCBuf.io.enq.ready
|
||||
|
||||
val dpuReady = Wire(Bool())
|
||||
val dpuFire = Wire(Bool())
|
||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
dpuFire := operandsValid && dpuReady
|
||||
dontTouch(dpuFire)
|
||||
|
||||
val setCompute = RegInit(0.U(setBits.W))
|
||||
val stepCompute = RegInit(0.U(stepBits.W))
|
||||
@@ -376,11 +440,14 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
||||
val operandATag = fullABuf.io.deq.bits.tag
|
||||
// select the correct 2x4 tile from B operand buffer
|
||||
// select the correct 2x4 tile from B/C operand buffer
|
||||
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
||||
val operandBTag = fullBBuf.io.deq.bits.tag
|
||||
val operandC = fullCBuf.io.deq.bits.data(substepCompute)
|
||||
val operandCTag = fullCBuf.io.deq.bits.tag
|
||||
dontTouch(operandATag)
|
||||
dontTouch(operandBTag)
|
||||
dontTouch(operandCTag)
|
||||
|
||||
// Operand buffer logic
|
||||
//
|
||||
@@ -397,6 +464,10 @@ class TensorCoreDecoupled(
|
||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
|
||||
// C buf should be synced with B buf
|
||||
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
dontTouch(shouldDequeueA)
|
||||
@@ -414,6 +485,10 @@ class TensorCoreDecoupled(
|
||||
operandATag.set === operandBTag.set,
|
||||
"A and B operands are pointing to different warps and sets. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
assert(operandATag.warp === operandCTag.warp &&
|
||||
operandATag.set === operandCTag.set,
|
||||
"A and C operands are pointing to different warps and sets. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
assert(operandATag.set === setCompute,
|
||||
"Operand arrived from memory is pointing at a different set than the FSM.")
|
||||
}
|
||||
@@ -422,7 +497,7 @@ class TensorCoreDecoupled(
|
||||
|
||||
// Dot-product unit
|
||||
//
|
||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||
// 4x2 four-element DPUs summing up to 32 FP32 MACs in total
|
||||
//
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
@@ -444,13 +519,18 @@ class TensorCoreDecoupled(
|
||||
require(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// note operand C is M-major
|
||||
val operandCDimensional = reshapeByFourWords(operandC)
|
||||
require(operandCDimensional.length == ncSubstep &&
|
||||
operandCDimensional(0).length == tilingParams.mc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
|
||||
for (m <- 0 until tilingParams.mc) {
|
||||
for (n <- 0 until ncSubstep) {
|
||||
dpus(m)(n).io.in.valid := dpuFire
|
||||
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
||||
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
||||
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
|
||||
dpus(m)(n).io.in.bits.c := operandCDimensional(n)(m)
|
||||
// dpu ready couples with writeback backpressure
|
||||
dpus(m)(n).io.stall := !io.writeback.ready
|
||||
}
|
||||
@@ -631,8 +711,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
tensor.io.respB.bits.source := tlOutB.d.bits.source
|
||||
tlOutB.d.ready := tensor.io.respB.ready
|
||||
|
||||
tensor.io.respC := 42.U // FIXME bogus
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 0.U // TODO
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
|
||||
Reference in New Issue
Block a user