tensor: Sequence through set/steps

This commit is contained in:
Hansung Kim
2024-10-14 20:20:30 -07:00
parent 3165108c8b
commit 01f53a8be1

View File

@@ -42,6 +42,7 @@ class TensorCoreDecoupled(
val reqA = Decoupled(new TensorMemReq)
val reqB = Decoupled(new TensorMemReq)
})
dontTouch(io)
// FSM
// ---
@@ -62,48 +63,70 @@ class TensorCoreDecoupled(
// support one outstanding warp request
val warpReg = RegInit(0.U(numWarpBits.W))
// TODO: just transition every cycle for now
def nextState(state: TensorState.Type) = state match {
case TensorState.idle => Mux(io.initiate.fire, TensorState.run, state)
case TensorState.run => TensorState.finish
case TensorState.finish => {
// hold until writeback is cleared
Mux(io.writeback.ready, TensorState.idle, state)
}
case _ => TensorState.idle
}
state := nextState(state)
// state table for every warp id
// sets: k iteration
val numSets = (tilingParams.k / tilingParams.kc)
val setBits = log2Ceil(numSets)
// steps: i-j iteration
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
val stepBits = log2Ceil(numSteps)
val setReg = RegInit(0.U(setBits.W))
val stepReg = RegInit(0.U(setBits.W))
// val tableRow = Valid(new Bundle {
// val set = UInt(setBits.W)
// val step = UInt(stepBits.W)
// })
val set = RegInit(0.U(setBits.W))
val step = RegInit(0.U(stepBits.W))
when(io.initiate.fire) {
val wid = io.initiate.bits.wid
busy := true.B
warpReg := wid
setReg := 0.U
stepReg := 0.U
set := 0.U
step := 0.U
when(io.writeback.fire) {
assert(io.writeback.bits.wid =/= wid,
"unsupported concurrent initiate and writeback to the same warp")
assert(
io.writeback.bits.wid =/= wid,
"unsupported concurrent initiate and writeback to the same warp"
)
}
}
when(io.writeback.fire) {
busy := false.B
}
// set/step sequencing logic
val nextStep = true.B // TODO
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
val setDone = (set === lastSet.U)
val stepDone = (step === lastStep.U)
when (nextStep) {
step := (step + 1.U) & lastStep.U
when (stepDone) {
set := (set + 1.U) & lastSet.U
}
}
// state transition logic
switch(state) {
is(TensorState.idle) {
when(io.initiate.fire) {
state := TensorState.run
}
}
is(TensorState.run) {
when (setDone && stepDone && nextStep) {
when (state === TensorState.run) {
state := TensorState.finish
}
}
}
is(TensorState.finish) {
when(io.writeback.fire) {
state := TensorState.idle
}
}
}
io.initiate.ready := !busy
io.writeback.valid := (state === TensorState.finish)
io.writeback.bits.wid := warpReg
io.writeback.bits.last := false.B // TODO
// Writeback queues
// ----------------
@@ -114,13 +137,6 @@ class TensorCoreDecoupled(
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
// Output logic
// ------------
io.writeback.valid := (state === TensorState.finish)
io.writeback.bits.wid := warpReg
io.writeback.bits.last := false.B // TODO
// FIXME
io.respA.ready := true.B
io.respB.ready := true.B