tensor: Translate TL response source to set/step tag
This commit is contained in:
@@ -32,8 +32,8 @@ class TensorCoreDecoupled(
|
|||||||
) extends Module {
|
) extends Module {
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val wordSize = 4 // TODO FP16
|
val wordSize = 4 // TODO FP16
|
||||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
val sourceWidth = log2Ceil(numSourceIds)
|
||||||
|
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||||
|
|
||||||
val io = IO(new Bundle {
|
val io = IO(new Bundle {
|
||||||
val initiate = Flipped(Decoupled(new Bundle {
|
val initiate = Flipped(Decoupled(new Bundle {
|
||||||
@@ -51,6 +51,27 @@ class TensorCoreDecoupled(
|
|||||||
})
|
})
|
||||||
dontTouch(io)
|
dontTouch(io)
|
||||||
|
|
||||||
|
class TensorMemReq(
|
||||||
|
sourceWidth: Int
|
||||||
|
) extends Bundle {
|
||||||
|
val source = UInt(sourceWidth.W)
|
||||||
|
val address = UInt(32.W)
|
||||||
|
}
|
||||||
|
class TensorMemResp(
|
||||||
|
sourceWidth: Int,
|
||||||
|
dataWidth: Int
|
||||||
|
) extends Bundle {
|
||||||
|
val source = UInt(sourceWidth.W)
|
||||||
|
val data = UInt(dataWidth.W)
|
||||||
|
}
|
||||||
|
// mem response after translation from TL source to set/step tag
|
||||||
|
class TensorMemRespWithTag(
|
||||||
|
dataWidth: Int
|
||||||
|
) extends Bundle {
|
||||||
|
val tag = new TensorMemTag
|
||||||
|
val data = UInt(dataWidth.W)
|
||||||
|
}
|
||||||
|
|
||||||
// FSM
|
// FSM
|
||||||
// ---
|
// ---
|
||||||
// This drives the overall pipeline of memory requests, dot-product unit
|
// This drives the overall pipeline of memory requests, dot-product unit
|
||||||
@@ -101,18 +122,39 @@ class TensorCoreDecoupled(
|
|||||||
//
|
//
|
||||||
val genReq = (state === TensorState.run)
|
val genReq = (state === TensorState.run)
|
||||||
|
|
||||||
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
class TensorMemTag extends Bundle {
|
||||||
case (req, resp) => {
|
val set = UInt(setBits.W)
|
||||||
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
val step = UInt(stepBits.W)
|
||||||
|
}
|
||||||
|
// use concatenation of set/step as the memory request source. This will get
|
||||||
|
// translated to the actual TL sourcewidth in sourceGen.
|
||||||
|
val tag = Wire(new TensorMemTag)
|
||||||
|
tag.set := set
|
||||||
|
tag.step := step
|
||||||
|
|
||||||
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
|
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
|
||||||
|
case (req, (resp, respTagged)) => {
|
||||||
|
val sourceGen = Module(new SourceGenerator(
|
||||||
|
log2Ceil(numSourceIds),
|
||||||
|
metadata = Some(tag)
|
||||||
|
))
|
||||||
|
|
||||||
sourceGen.io.gen := req.fire
|
sourceGen.io.gen := req.fire
|
||||||
sourceGen.io.meta := DontCare
|
sourceGen.io.meta := tag
|
||||||
req.valid := genReq
|
req.valid := genReq
|
||||||
req.bits.address := 0.U // FIXME
|
req.bits.address := 0.U // FIXME
|
||||||
req.bits.source := sourceGen.io.id.bits
|
req.bits.source := sourceGen.io.id.bits
|
||||||
|
|
||||||
sourceGen.io.reclaim.valid := resp.fire
|
sourceGen.io.reclaim.valid := resp.fire
|
||||||
sourceGen.io.reclaim.bits := resp.bits.source
|
sourceGen.io.reclaim.bits := resp.bits.source
|
||||||
|
|
||||||
|
// translate source
|
||||||
|
respTagged.valid := resp.valid
|
||||||
|
respTagged.bits.tag := sourceGen.io.peek
|
||||||
|
respTagged.bits.data := resp.bits.data
|
||||||
|
resp.ready := respTagged.ready
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,16 +172,13 @@ class TensorCoreDecoupled(
|
|||||||
firedABReg := Seq(false.B, false.B)
|
firedABReg := Seq(false.B, false.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
io.respA.ready := true.B // FIXME
|
|
||||||
io.respB.ready := true.B // FIXME
|
|
||||||
|
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// -------------
|
// -------------
|
||||||
// Backend of the decoupled access/execute pipeline.
|
// Backend of the decoupled access/execute pipeline.
|
||||||
//
|
//
|
||||||
val respQueueDepth = 4 // FIXME: parameterize
|
val respQueueDepth = 4 // FIXME: parameterize
|
||||||
val respQueueA = Queue(io.respA, respQueueDepth)
|
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||||
val respQueueB = Queue(io.respB, respQueueDepth)
|
val respQueueB = Queue(respBTagged, respQueueDepth)
|
||||||
respQueueA.ready := io.writeback.ready // FIXME
|
respQueueA.ready := io.writeback.ready // FIXME
|
||||||
respQueueB.ready := io.writeback.ready // FIXME
|
respQueueB.ready := io.writeback.ready // FIXME
|
||||||
|
|
||||||
@@ -149,9 +188,11 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
// FIXME: debug dummy: pipe A directly to writeback
|
// FIXME: debug dummy: pipe A directly to writeback
|
||||||
io.writeback.valid := respQueueA.valid
|
io.writeback.valid := respQueueA.valid
|
||||||
val groupedRespA = respQueueA.bits.data.asBools.grouped(wordSize * 8/*bits*/)
|
val groupedRespA = respQueueA.bits.data
|
||||||
|
.asBools.grouped(wordSize * 8/*bits*/)
|
||||||
|
.map(VecInit(_).asUInt)
|
||||||
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
||||||
wb := VecInit(data).asUInt
|
wb := data
|
||||||
}
|
}
|
||||||
|
|
||||||
// State transition
|
// State transition
|
||||||
@@ -204,20 +245,6 @@ class TensorCoreDecoupled(
|
|||||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
class TensorMemReq(
|
|
||||||
sourceWidth: Int
|
|
||||||
) extends Bundle {
|
|
||||||
val source = UInt(sourceWidth.W)
|
|
||||||
val address = UInt(32.W)
|
|
||||||
}
|
|
||||||
class TensorMemResp(
|
|
||||||
sourceWidth: Int,
|
|
||||||
dataWidth: Int
|
|
||||||
) extends Bundle {
|
|
||||||
val source = UInt(sourceWidth.W)
|
|
||||||
val data = UInt(dataWidth.W)
|
|
||||||
}
|
|
||||||
|
|
||||||
// synthesizable unit tests
|
// synthesizable unit tests
|
||||||
|
|
||||||
// wraps TensorCoreDecoupled with a TileLink client node for use in a Diplomacy
|
// wraps TensorCoreDecoupled with a TileLink client node for use in a Diplomacy
|
||||||
|
|||||||
Reference in New Issue
Block a user