tensor: Add destination reg to IO
This commit is contained in:
@@ -28,12 +28,14 @@ class TensorCoreDecoupled(
|
|||||||
val numWarps: Int,
|
val numWarps: Int,
|
||||||
val numLanes: Int,
|
val numLanes: Int,
|
||||||
val numSourceIds: Int,
|
val numSourceIds: Int,
|
||||||
val tilingParams: TensorTilingParams
|
val tilingParams: TensorTilingParams,
|
||||||
|
val numFPRegs: Int = 32
|
||||||
) extends Module {
|
) extends Module {
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val wordSize = 4 // TODO FP16
|
val wordSize = 4 // TODO FP16
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
val sourceWidth = log2Ceil(numSourceIds)
|
||||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||||
|
val numFPRegBits = log2Ceil(numFPRegs)
|
||||||
|
|
||||||
val io = IO(new Bundle {
|
val io = IO(new Bundle {
|
||||||
val initiate = Flipped(Decoupled(new Bundle {
|
val initiate = Flipped(Decoupled(new Bundle {
|
||||||
@@ -42,6 +44,7 @@ class TensorCoreDecoupled(
|
|||||||
val writeback = Decoupled(new Bundle {
|
val writeback = Decoupled(new Bundle {
|
||||||
val last = Bool()
|
val last = Bool()
|
||||||
val wid = UInt(numWarpBits.W)
|
val wid = UInt(numWarpBits.W)
|
||||||
|
val rd = UInt(numFPRegBits.W)
|
||||||
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
|
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
|
||||||
})
|
})
|
||||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||||
@@ -218,8 +221,17 @@ class TensorCoreDecoupled(
|
|||||||
// FIXME: this need to change to dpu_fire
|
// FIXME: this need to change to dpu_fire
|
||||||
val nextStepExecute = io.writeback.fire
|
val nextStepExecute = io.writeback.fire
|
||||||
|
|
||||||
|
def rdGen(set: UInt, step: UInt): UInt = {
|
||||||
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
|
// thread
|
||||||
|
require(numLanes == 8, "currently assumes 8-wide warps")
|
||||||
|
(Cat(set, step) >> 1/*2 regs/thread*/)
|
||||||
|
// FIXME: add substep here
|
||||||
|
}
|
||||||
|
|
||||||
io.writeback.valid := bothQueueValid
|
io.writeback.valid := bothQueueValid
|
||||||
io.writeback.bits.wid := warpReg
|
io.writeback.bits.wid := warpReg
|
||||||
|
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||||
|
|
||||||
// FIXME: debug dummy: pipe A directly to writeback
|
// FIXME: debug dummy: pipe A directly to writeback
|
||||||
|
|||||||
Reference in New Issue
Block a user