tensor: Translate TL response source to set/step tag
This commit is contained in:
@@ -32,8 +32,8 @@ class TensorCoreDecoupled(
|
||||
) extends Module {
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val initiate = Flipped(Decoupled(new Bundle {
|
||||
@@ -51,6 +51,27 @@ class TensorCoreDecoupled(
|
||||
})
|
||||
dontTouch(io)
|
||||
|
||||
class TensorMemReq(
|
||||
sourceWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val address = UInt(32.W)
|
||||
}
|
||||
class TensorMemResp(
|
||||
sourceWidth: Int,
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
// mem response after translation from TL source to set/step tag
|
||||
class TensorMemRespWithTag(
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val tag = new TensorMemTag
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
|
||||
// FSM
|
||||
// ---
|
||||
// This drives the overall pipeline of memory requests, dot-product unit
|
||||
@@ -101,18 +122,39 @@ class TensorCoreDecoupled(
|
||||
//
|
||||
val genReq = (state === TensorState.run)
|
||||
|
||||
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
||||
case (req, resp) => {
|
||||
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
}
|
||||
// use concatenation of set/step as the memory request source. This will get
|
||||
// translated to the actual TL sourcewidth in sourceGen.
|
||||
val tag = Wire(new TensorMemTag)
|
||||
tag.set := set
|
||||
tag.step := step
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
|
||||
case (req, (resp, respTagged)) => {
|
||||
val sourceGen = Module(new SourceGenerator(
|
||||
log2Ceil(numSourceIds),
|
||||
metadata = Some(tag)
|
||||
))
|
||||
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := DontCare
|
||||
sourceGen.io.meta := tag
|
||||
req.valid := genReq
|
||||
req.bits.address := 0.U // FIXME
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
|
||||
sourceGen.io.reclaim.valid := resp.fire
|
||||
sourceGen.io.reclaim.bits := resp.bits.source
|
||||
|
||||
// translate source
|
||||
respTagged.valid := resp.valid
|
||||
respTagged.bits.tag := sourceGen.io.peek
|
||||
respTagged.bits.data := resp.bits.data
|
||||
resp.ready := respTagged.ready
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,16 +172,13 @@ class TensorCoreDecoupled(
|
||||
firedABReg := Seq(false.B, false.B)
|
||||
}
|
||||
|
||||
io.respA.ready := true.B // FIXME
|
||||
io.respB.ready := true.B // FIXME
|
||||
|
||||
// Execute stage
|
||||
// -------------
|
||||
// Backend of the decoupled access/execute pipeline.
|
||||
//
|
||||
val respQueueDepth = 4 // FIXME: parameterize
|
||||
val respQueueA = Queue(io.respA, respQueueDepth)
|
||||
val respQueueB = Queue(io.respB, respQueueDepth)
|
||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||
val respQueueB = Queue(respBTagged, respQueueDepth)
|
||||
respQueueA.ready := io.writeback.ready // FIXME
|
||||
respQueueB.ready := io.writeback.ready // FIXME
|
||||
|
||||
@@ -149,9 +188,11 @@ class TensorCoreDecoupled(
|
||||
|
||||
// FIXME: debug dummy: pipe A directly to writeback
|
||||
io.writeback.valid := respQueueA.valid
|
||||
val groupedRespA = respQueueA.bits.data.asBools.grouped(wordSize * 8/*bits*/)
|
||||
val groupedRespA = respQueueA.bits.data
|
||||
.asBools.grouped(wordSize * 8/*bits*/)
|
||||
.map(VecInit(_).asUInt)
|
||||
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
||||
wb := VecInit(data).asUInt
|
||||
wb := data
|
||||
}
|
||||
|
||||
// State transition
|
||||
@@ -204,20 +245,6 @@ class TensorCoreDecoupled(
|
||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
}
|
||||
|
||||
class TensorMemReq(
|
||||
sourceWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val address = UInt(32.W)
|
||||
}
|
||||
class TensorMemResp(
|
||||
sourceWidth: Int,
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
|
||||
// synthesizable unit tests
|
||||
|
||||
// wraps TensorCoreDecoupled with a TileLink client node for use in a Diplomacy
|
||||
|
||||
Reference in New Issue
Block a user