Add state regs and init/writeback transition
This commit is contained in:
@@ -6,50 +6,126 @@ package radiance.core
|
|||||||
import chisel3._
|
import chisel3._
|
||||||
import chisel3.util._
|
import chisel3.util._
|
||||||
|
|
||||||
class TensorCoreDecoupled(val numWarps: Int, val numLanes: Int) extends Module {
|
case class TensorTilingParams(
|
||||||
|
// Dimension of the SMEM tile
|
||||||
|
m: Int = 16,
|
||||||
|
n: Int = 16,
|
||||||
|
k: Int = 16,
|
||||||
|
// Dimension of the compute tile. This is determined by the number of MAC
|
||||||
|
// units
|
||||||
|
mc: Int = 4,
|
||||||
|
nc: Int = 4,
|
||||||
|
kc: Int = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
class TensorCoreDecoupled(
|
||||||
|
val numWarps: Int,
|
||||||
|
val numLanes: Int,
|
||||||
|
val tilingParams: TensorTilingParams
|
||||||
|
) extends Module {
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val wordSize = 4 // TODO FP16
|
val wordSize = 4 // TODO FP16
|
||||||
val dataWidth = numLanes * wordSize // TODO FP16
|
val dataWidth = numLanes * wordSize // TODO FP16
|
||||||
|
|
||||||
val io = IO(new Bundle{
|
val io = IO(new Bundle {
|
||||||
val initiate = Flipped(Decoupled(new Bundle{
|
val initiate = Flipped(Decoupled(new Bundle {
|
||||||
val wid = UInt(numWarpBits.W)
|
val wid = UInt(numWarpBits.W)
|
||||||
}))
|
}))
|
||||||
val dataA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
val writeback = Decoupled(new Bundle {
|
||||||
val dataB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
|
||||||
val addressA = Decoupled(new TensorMemReq)
|
|
||||||
val addressB = Decoupled(new TensorMemReq)
|
|
||||||
val writeback = Decoupled(new Bundle{
|
|
||||||
val wid = UInt(numWarpBits.W)
|
val wid = UInt(numWarpBits.W)
|
||||||
val last = Bool()
|
val last = Bool()
|
||||||
})
|
})
|
||||||
|
val respA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
||||||
|
val respB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
||||||
|
val reqA = Decoupled(new TensorMemReq)
|
||||||
|
val reqB = Decoupled(new TensorMemReq)
|
||||||
})
|
})
|
||||||
|
|
||||||
// FSM
|
// FSM
|
||||||
//
|
// ---
|
||||||
|
// This drives the overall pipeline of memory requests, dot-product unit
|
||||||
|
// operations and regfile writeback.
|
||||||
|
|
||||||
|
object TensorState extends ChiselEnum {
|
||||||
|
val idle = Value(0.U)
|
||||||
|
val run = Value(1.U)
|
||||||
|
// All set/step sequencing is complete and the tensor core is holding the
|
||||||
|
// result data until downstream writeback is ready.
|
||||||
|
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||||
|
val finish = Value(2.U)
|
||||||
|
}
|
||||||
val state = RegInit(TensorState.idle)
|
val state = RegInit(TensorState.idle)
|
||||||
|
val busy = RegInit(false.B)
|
||||||
|
// Holds the warp id the core is currently working on. Note that we only
|
||||||
|
// support one outstanding warp request
|
||||||
|
val warpReg = RegInit(0.U(numWarpBits.W))
|
||||||
|
|
||||||
// TODO: just transition every cycle for now
|
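// TODO: just transition every cycle for now. NOTE(review): `state match` is a Scala match resolved at elaboration time, not a hardware mux; on a Reg it presumably always falls to the default case -- confirm and consider switch/is or MuxLookup on `state`.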
|
||||||
state := (state match {
|
def nextState(state: TensorState.Type) = state match {
|
||||||
case TensorState.idle => Mux(io.initiate.fire, TensorState.smemRead, state)
|
case TensorState.idle => Mux(io.initiate.fire, TensorState.run, state)
|
||||||
case TensorState.smemRead => TensorState.compute
|
case TensorState.run => TensorState.finish
|
||||||
case TensorState.compute => TensorState.writeback
|
case TensorState.finish => {
|
||||||
case TensorState.writeback => {
|
|
||||||
// hold until writeback is cleared
|
// hold until writeback is cleared
|
||||||
Mux(io.writeback.ready, TensorState.idle, state)
|
Mux(io.writeback.ready, TensorState.idle, state)
|
||||||
}
|
}
|
||||||
case _ => TensorState.idle
|
case _ => TensorState.idle
|
||||||
})
|
}
|
||||||
|
state := nextState(state)
|
||||||
|
|
||||||
// TODO
|
// state table for every warp id
|
||||||
io.dataA.ready := true.B
|
// sets: k iteration
|
||||||
io.dataB.ready := true.B
|
val numSets = (tilingParams.k / tilingParams.kc)
|
||||||
io.addressA.valid := false.B
|
val setBits = log2Ceil(numSets)
|
||||||
io.addressB.valid := false.B
|
// steps: i-j iteration
|
||||||
io.addressA.bits := DontCare
|
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||||
io.addressB.bits := DontCare
|
val stepBits = log2Ceil(numSteps)
|
||||||
io.initiate.ready := true.B
|
val setReg = RegInit(0.U(setBits.W))
|
||||||
io.writeback.valid := true.B
|
val stepReg = RegInit(0.U(stepBits.W)) // step counter spans numSteps values, so size it with stepBits (not setBits)
|
||||||
io.writeback.bits := DontCare
|
// val tableRow = Valid(new Bundle {
|
||||||
|
// val set = UInt(setBits.W)
|
||||||
|
// val step = UInt(stepBits.W)
|
||||||
|
// })
|
||||||
|
|
||||||
|
when(io.initiate.fire) {
|
||||||
|
val wid = io.initiate.bits.wid
|
||||||
|
busy := true.B
|
||||||
|
warpReg := wid
|
||||||
|
setReg := 0.U
|
||||||
|
stepReg := 0.U
|
||||||
|
when(io.writeback.fire) {
|
||||||
|
assert(io.writeback.bits.wid =/= wid,
|
||||||
|
"unsupported concurrent initiate and writeback to the same warp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
when (io.writeback.fire) {
|
||||||
|
busy := false.B
|
||||||
|
}
|
||||||
|
|
||||||
|
io.initiate.ready := !busy
|
||||||
|
|
||||||
|
// Writeback queues
|
||||||
|
// ----------------
|
||||||
|
// These queues hold the metadata necessary for register
|
||||||
|
// writeback.
|
||||||
|
|
||||||
|
// val queueDepth = 2
|
||||||
|
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
|
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
|
|
||||||
|
// Output logic
|
||||||
|
// ------------
|
||||||
|
|
||||||
|
io.writeback.valid := (state === TensorState.finish)
|
||||||
|
io.writeback.bits.wid := warpReg
|
||||||
|
io.writeback.bits.last := false.B // TODO
|
||||||
|
|
||||||
|
// FIXME
|
||||||
|
io.respA.ready := true.B
|
||||||
|
io.respB.ready := true.B
|
||||||
|
io.reqA.valid := false.B
|
||||||
|
io.reqB.valid := false.B
|
||||||
|
io.reqA.bits := DontCare
|
||||||
|
io.reqB.bits := DontCare
|
||||||
}
|
}
|
||||||
|
|
||||||
class TensorMemReq extends Bundle {
|
class TensorMemReq extends Bundle {
|
||||||
@@ -60,11 +136,3 @@ class TensorMemResp(val dataWidth: Int) extends Bundle {
|
|||||||
// TODO: tag
|
// TODO: tag
|
||||||
val data = UInt(32.W)
|
val data = UInt(32.W)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
object TensorState extends ChiselEnum {
|
|
||||||
val idle = Value(0.U)
|
|
||||||
val smemRead = Value(1.U)
|
|
||||||
val compute = Value(2.U)
|
|
||||||
val writeback = Value(3.U)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,13 +9,16 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
behavior of "TensorCoreDecoupled"
|
behavior of "TensorCoreDecoupled"
|
||||||
|
|
||||||
it should "do the right thing" in {
|
it should "do the right thing" in {
|
||||||
test(new TensorCoreDecoupled(8, 8))
|
test(new TensorCoreDecoupled(8, 8, tilingParams = TensorTilingParams()))
|
||||||
{ c =>
|
{ c =>
|
||||||
c.io.initiate.valid.poke(true.B)
|
c.io.initiate.valid.poke(true.B)
|
||||||
c.io.dataA.valid.poke(false.B)
|
c.io.initiate.bits.wid.poke(0.U)
|
||||||
c.io.dataA.bits.data.poke(0.U)
|
|
||||||
c.io.dataB.valid.poke(false.B)
|
c.io.respA.valid.poke(false.B)
|
||||||
c.io.dataB.bits.data.poke(0.U)
|
c.io.respA.bits.data.poke(0.U)
|
||||||
|
c.io.respB.valid.poke(false.B)
|
||||||
|
c.io.respB.bits.data.poke(0.U)
|
||||||
|
|
||||||
c.clock.step()
|
c.clock.step()
|
||||||
c.io.writeback.valid.expect(true.B)
|
c.io.writeback.valid.expect(true.B)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user