tensor: Translate TL response source to set/step tag

This commit is contained in:
Hansung Kim
2024-10-15 16:48:39 -07:00
parent 2ca2ee37b0
commit de393115cd

View File

@@ -32,8 +32,8 @@ class TensorCoreDecoupled(
) extends Module {
val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
val sourceWidth = log2Ceil(numSourceIds)
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
@@ -51,6 +51,27 @@ class TensorCoreDecoupled(
})
dontTouch(io)
class TensorMemReq(
sourceWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val address = UInt(32.W)
}
class TensorMemResp(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val data = UInt(dataWidth.W)
}
// mem response after translation from TL source to set/step tag
class TensorMemRespWithTag(
dataWidth: Int
) extends Bundle {
val tag = new TensorMemTag
val data = UInt(dataWidth.W)
}
// FSM
// ---
// This drives the overall pipeline of memory requests, dot-product unit
@@ -101,18 +122,39 @@ class TensorCoreDecoupled(
//
val genReq = (state === TensorState.run)
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
case (req, resp) => {
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
}
// use concatenation of set/step as the memory request source. This will get
// translated to the actual TL sourcewidth in sourceGen.
val tag = Wire(new TensorMemTag)
tag.set := set
tag.step := step
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
case (req, (resp, respTagged)) => {
val sourceGen = Module(new SourceGenerator(
log2Ceil(numSourceIds),
metadata = Some(tag)
))
sourceGen.io.gen := req.fire
sourceGen.io.meta := DontCare
sourceGen.io.meta := tag
req.valid := genReq
req.bits.address := 0.U // FIXME
req.bits.source := sourceGen.io.id.bits
sourceGen.io.reclaim.valid := resp.fire
sourceGen.io.reclaim.bits := resp.bits.source
// translate source
respTagged.valid := resp.valid
respTagged.bits.tag := sourceGen.io.peek
respTagged.bits.data := resp.bits.data
resp.ready := respTagged.ready
}
}
@@ -130,16 +172,13 @@ class TensorCoreDecoupled(
firedABReg := Seq(false.B, false.B)
}
io.respA.ready := true.B // FIXME
io.respB.ready := true.B // FIXME
// Execute stage
// -------------
// Backend of the decoupled access/execute pipeline.
//
val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(io.respA, respQueueDepth)
val respQueueB = Queue(io.respB, respQueueDepth)
val respQueueA = Queue(respATagged, respQueueDepth)
val respQueueB = Queue(respBTagged, respQueueDepth)
respQueueA.ready := io.writeback.ready // FIXME
respQueueB.ready := io.writeback.ready // FIXME
@@ -149,9 +188,11 @@ class TensorCoreDecoupled(
// FIXME: debug dummy: pipe A directly to writeback
io.writeback.valid := respQueueA.valid
val groupedRespA = respQueueA.bits.data.asBools.grouped(wordSize * 8/*bits*/)
val groupedRespA = respQueueA.bits.data
.asBools.grouped(wordSize * 8/*bits*/)
.map(VecInit(_).asUInt)
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
wb := VecInit(data).asUInt
wb := data
}
// State transition
@@ -204,20 +245,6 @@ class TensorCoreDecoupled(
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
}
class TensorMemReq(
sourceWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val address = UInt(32.W)
}
class TensorMemResp(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val data = UInt(dataWidth.W)
}
// synthesizable unit tests
// wraps TensorCoreDecoupled with a TileLink client node for use in a Diplomacy