tensor: Sequence through set/steps
This commit is contained in:
@@ -42,6 +42,7 @@ class TensorCoreDecoupled(
|
||||
val reqA = Decoupled(new TensorMemReq)
|
||||
val reqB = Decoupled(new TensorMemReq)
|
||||
})
|
||||
dontTouch(io)
|
||||
|
||||
// FSM
|
||||
// ---
|
||||
@@ -62,48 +63,70 @@ class TensorCoreDecoupled(
|
||||
// support one outstanding warp request
|
||||
val warpReg = RegInit(0.U(numWarpBits.W))
|
||||
|
||||
// TODO: just transition every cycle for now
|
||||
// Computes the next FSM state from the current one:
//   idle   -> run when io.initiate fires, otherwise stays in the current state
//   run    -> finish unconditionally
//   finish -> idle once io.writeback.ready is asserted, otherwise holds
//   other  -> idle
// NOTE(review): `state match` is an ordinary Scala match that is resolved while
// the generator runs, not a hardware multiplexer. If `state` is a hardware
// register, presumably only the `case _` arm can ever be selected -- confirm,
// and consider a switch/is block or a Mux chain instead.
def nextState(state: TensorState.Type) = state match {
|
||||
case TensorState.idle => Mux(io.initiate.fire, TensorState.run, state)
|
||||
// run lasts for exactly one transition in this version
case TensorState.run => TensorState.finish
|
||||
case TensorState.finish => {
|
||||
// hold until writeback is cleared
|
||||
Mux(io.writeback.ready, TensorState.idle, state)
|
||||
}
|
||||
// fallback arm: any state not listed above maps to idle
case _ => TensorState.idle
|
||||
}
|
||||
state := nextState(state)
|
||||
|
||||
// state table for every warp id
|
||||
// sets: k iteration
|
||||
val numSets = (tilingParams.k / tilingParams.kc)
|
||||
val setBits = log2Ceil(numSets)
|
||||
// steps: i-j iteration
|
||||
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||
val stepBits = log2Ceil(numSteps)
|
||||
val setReg = RegInit(0.U(setBits.W))
|
||||
// Step (i-j iteration) counter register. Its width comes from stepBits, which
// is derived from numSteps; sizing it with setBits would truncate the step
// index whenever numSteps needs more bits than numSets.
val stepReg = RegInit(0.U(stepBits.W))
|
||||
// val tableRow = Valid(new Bundle {
|
||||
// val set = UInt(setBits.W)
|
||||
// val step = UInt(stepBits.W)
|
||||
// })
|
||||
val set = RegInit(0.U(setBits.W))
|
||||
val step = RegInit(0.U(stepBits.W))
|
||||
|
||||
when(io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
busy := true.B
|
||||
warpReg := wid
|
||||
setReg := 0.U
|
||||
stepReg := 0.U
|
||||
set := 0.U
|
||||
step := 0.U
|
||||
when(io.writeback.fire) {
|
||||
assert(io.writeback.bits.wid =/= wid,
|
||||
"unsupported concurrent initiate and writeback to the same warp")
|
||||
assert(
|
||||
io.writeback.bits.wid =/= wid,
|
||||
"unsupported concurrent initiate and writeback to the same warp"
|
||||
)
|
||||
}
|
||||
}
|
||||
when(io.writeback.fire) {
|
||||
busy := false.B
|
||||
}
|
||||
|
||||
// set/step sequencing logic
|
||||
val nextStep = true.B // TODO
|
||||
// Terminal counter values, taken as the all-ones value of each counter width.
// Since setBits = log2Ceil(numSets), (1 << setBits) - 1 equals numSets - 1 only
// when numSets is a power of two; for any other count the counter keeps going
// past numSets - 1 before it wraps. The same holds for lastStep and numSteps.
// NOTE(review): confirm the tiling parameters are restricted to powers of two,
// or compare against numSets - 1 / numSteps - 1 and wrap explicitly.
val lastSet = ((1 << setBits) - 1)
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
// true while the set counter sits at its terminal value
val setDone = (set === lastSet.U)
|
||||
// true while the step counter sits at its terminal value
val stepDone = (step === lastStep.U)
|
||||
when (nextStep) {
|
||||
step := (step + 1.U) & lastStep.U
|
||||
when (stepDone) {
|
||||
set := (set + 1.U) & lastSet.U
|
||||
}
|
||||
}
|
||||
|
||||
// state transition logic
|
||||
switch(state) {
|
||||
is(TensorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := TensorState.run
|
||||
}
|
||||
}
|
||||
// run -> finish: leave the run state on the cycle where the last step of the
// last set is being advanced (setDone, stepDone and nextStep all true).
is(TensorState.run) {
|
||||
when (setDone && stepDone && nextStep) {
|
||||
// NOTE(review): this arm is already selected by is(TensorState.run), so the
// extra state check below looks redundant -- confirm before removing it.
when (state === TensorState.run) {
|
||||
state := TensorState.finish
|
||||
}
|
||||
}
|
||||
}
|
||||
is(TensorState.finish) {
|
||||
when(io.writeback.fire) {
|
||||
state := TensorState.idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
io.initiate.ready := !busy
|
||||
io.writeback.valid := (state === TensorState.finish)
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.last := false.B // TODO
|
||||
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
@@ -114,13 +137,6 @@ class TensorCoreDecoupled(
|
||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
|
||||
// Output logic
|
||||
// ------------
|
||||
|
||||
io.writeback.valid := (state === TensorState.finish)
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.last := false.B // TODO
|
||||
|
||||
// FIXME
|
||||
io.respA.ready := true.B
|
||||
io.respB.ready := true.B
|
||||
|
||||
Reference in New Issue
Block a user