tensor: Sequence through set/steps
This commit is contained in:
@@ -42,6 +42,7 @@ class TensorCoreDecoupled(
|
|||||||
val reqA = Decoupled(new TensorMemReq)
|
val reqA = Decoupled(new TensorMemReq)
|
||||||
val reqB = Decoupled(new TensorMemReq)
|
val reqB = Decoupled(new TensorMemReq)
|
||||||
})
|
})
|
||||||
|
dontTouch(io)
|
||||||
|
|
||||||
// FSM
|
// FSM
|
||||||
// ---
|
// ---
|
||||||
@@ -62,48 +63,70 @@ class TensorCoreDecoupled(
|
|||||||
// support one outstanding warp request
|
// support one outstanding warp request
|
||||||
val warpReg = RegInit(0.U(numWarpBits.W))
|
val warpReg = RegInit(0.U(numWarpBits.W))
|
||||||
|
|
||||||
// TODO: just transition every cycle for now
|
|
||||||
def nextState(state: TensorState.Type) = state match {
|
|
||||||
case TensorState.idle => Mux(io.initiate.fire, TensorState.run, state)
|
|
||||||
case TensorState.run => TensorState.finish
|
|
||||||
case TensorState.finish => {
|
|
||||||
// hold until writeback is cleared
|
|
||||||
Mux(io.writeback.ready, TensorState.idle, state)
|
|
||||||
}
|
|
||||||
case _ => TensorState.idle
|
|
||||||
}
|
|
||||||
state := nextState(state)
|
|
||||||
|
|
||||||
// state table for every warp id
|
|
||||||
// sets: k iteration
|
// sets: k iteration
|
||||||
val numSets = (tilingParams.k / tilingParams.kc)
|
val numSets = (tilingParams.k / tilingParams.kc)
|
||||||
val setBits = log2Ceil(numSets)
|
val setBits = log2Ceil(numSets)
|
||||||
// steps: i-j iteration
|
// steps: i-j iteration
|
||||||
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||||
val stepBits = log2Ceil(numSteps)
|
val stepBits = log2Ceil(numSteps)
|
||||||
val setReg = RegInit(0.U(setBits.W))
|
val set = RegInit(0.U(setBits.W))
|
||||||
val stepReg = RegInit(0.U(setBits.W))
|
val step = RegInit(0.U(stepBits.W))
|
||||||
// val tableRow = Valid(new Bundle {
|
|
||||||
// val set = UInt(setBits.W)
|
|
||||||
// val step = UInt(stepBits.W)
|
|
||||||
// })
|
|
||||||
|
|
||||||
when(io.initiate.fire) {
|
when(io.initiate.fire) {
|
||||||
val wid = io.initiate.bits.wid
|
val wid = io.initiate.bits.wid
|
||||||
busy := true.B
|
busy := true.B
|
||||||
warpReg := wid
|
warpReg := wid
|
||||||
setReg := 0.U
|
set := 0.U
|
||||||
stepReg := 0.U
|
step := 0.U
|
||||||
when(io.writeback.fire) {
|
when(io.writeback.fire) {
|
||||||
assert(io.writeback.bits.wid =/= wid,
|
assert(
|
||||||
"unsupported concurrent initiate and writeback to the same warp")
|
io.writeback.bits.wid =/= wid,
|
||||||
|
"unsupported concurrent initiate and writeback to the same warp"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (io.writeback.fire) {
|
when(io.writeback.fire) {
|
||||||
busy := false.B
|
busy := false.B
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set/step sequencing logic
|
||||||
|
val nextStep = true.B // TODO
|
||||||
|
val lastSet = ((1 << setBits) - 1)
|
||||||
|
val lastStep = ((1 << stepBits) - 1)
|
||||||
|
val setDone = (set === lastSet.U)
|
||||||
|
val stepDone = (step === lastStep.U)
|
||||||
|
when (nextStep) {
|
||||||
|
step := (step + 1.U) & lastStep.U
|
||||||
|
when (stepDone) {
|
||||||
|
set := (set + 1.U) & lastSet.U
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// state transition logic
|
||||||
|
switch(state) {
|
||||||
|
is(TensorState.idle) {
|
||||||
|
when(io.initiate.fire) {
|
||||||
|
state := TensorState.run
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(TensorState.run) {
|
||||||
|
when (setDone && stepDone && nextStep) {
|
||||||
|
when (state === TensorState.run) {
|
||||||
|
state := TensorState.finish
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(TensorState.finish) {
|
||||||
|
when(io.writeback.fire) {
|
||||||
|
state := TensorState.idle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
io.initiate.ready := !busy
|
io.initiate.ready := !busy
|
||||||
|
io.writeback.valid := (state === TensorState.finish)
|
||||||
|
io.writeback.bits.wid := warpReg
|
||||||
|
io.writeback.bits.last := false.B // TODO
|
||||||
|
|
||||||
// Writeback queues
|
// Writeback queues
|
||||||
// ----------------
|
// ----------------
|
||||||
@@ -114,13 +137,6 @@ class TensorCoreDecoupled(
|
|||||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
|
|
||||||
// Output logic
|
|
||||||
// ------------
|
|
||||||
|
|
||||||
io.writeback.valid := (state === TensorState.finish)
|
|
||||||
io.writeback.bits.wid := warpReg
|
|
||||||
io.writeback.bits.last := false.B // TODO
|
|
||||||
|
|
||||||
// FIXME
|
// FIXME
|
||||||
io.respA.ready := true.B
|
io.respA.ready := true.B
|
||||||
io.respB.ready := true.B
|
io.respB.ready := true.B
|
||||||
|
|||||||
Reference in New Issue
Block a user