tensor: Instantiate actual DPU
This commit is contained in:
@@ -33,8 +33,9 @@ class TensorCoreDecoupled(
|
|||||||
) extends Module {
|
) extends Module {
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val wordSize = 4 // TODO FP16
|
val wordSize = 4 // TODO FP16
|
||||||
|
val wordSizeInBits = wordSize * 8 // TODO FP16
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
val sourceWidth = log2Ceil(numSourceIds)
|
||||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
val dataWidth = numLanes * wordSizeInBits // TODO FP16
|
||||||
val numFPRegBits = log2Ceil(numFPRegs)
|
val numFPRegBits = log2Ceil(numFPRegs)
|
||||||
|
|
||||||
val io = IO(new Bundle {
|
val io = IO(new Bundle {
|
||||||
@@ -45,7 +46,7 @@ class TensorCoreDecoupled(
|
|||||||
val last = Bool()
|
val last = Bool()
|
||||||
val wid = UInt(numWarpBits.W)
|
val wid = UInt(numWarpBits.W)
|
||||||
val rd = UInt(numFPRegBits.W)
|
val rd = UInt(numFPRegBits.W)
|
||||||
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
|
val data = Vec(numLanes, UInt((wordSizeInBits).W))
|
||||||
})
|
})
|
||||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||||
@@ -223,9 +224,6 @@ class TensorCoreDecoupled(
|
|||||||
io.writeback.bits.data.widthOption.get,
|
io.writeback.bits.data.widthOption.get,
|
||||||
"response data width does not match the writeback data width")
|
"response data width does not match the writeback data width")
|
||||||
|
|
||||||
// FIXME: this need to change to dpu_ready
|
|
||||||
val dpuReady = io.writeback.ready // FIXME: this need be actual dpu
|
|
||||||
|
|
||||||
val substepExecute = RegInit(0.U(1.W))
|
val substepExecute = RegInit(0.U(1.W))
|
||||||
when (respQueueA.fire) {
|
when (respQueueA.fire) {
|
||||||
substepExecute := substepExecute + 1.U
|
substepExecute := substepExecute + 1.U
|
||||||
@@ -267,7 +265,10 @@ class TensorCoreDecoupled(
|
|||||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||||
|
|
||||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid
|
||||||
|
val operandA = fullAQueue.io.deq.bits.data
|
||||||
|
val operandB = respQueueB.bits.data
|
||||||
|
val dpuReady = Wire(Bool())
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
val substepCompute = RegInit(0.U(1.W))
|
val substepCompute = RegInit(0.U(1.W))
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
@@ -301,6 +302,66 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
assertAligned
|
assertAligned
|
||||||
|
|
||||||
|
// Dot-product unit
|
||||||
|
//
|
||||||
|
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||||
|
val dpus = Seq.fill(4)(Seq.fill(2)(
|
||||||
|
Module(new TensorDotProductUnit(half = false))
|
||||||
|
))
|
||||||
|
// operandA is 4x4 in K-major
|
||||||
|
val operandADimensional =
|
||||||
|
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
|
.grouped(4).toSeq
|
||||||
|
println(s"operandA: ${fullAQueue.io.deq.bits.data.widthOption.get} bits")
|
||||||
|
println(s"A: ${operandADimensional.length}, ${operandADimensional(0).length}")
|
||||||
|
assert(operandADimensional.length == tilingParams.mc &&
|
||||||
|
operandADimensional(0).length == tilingParams.kc,
|
||||||
|
"operand width doesn't agree with tiling parameter")
|
||||||
|
// operandB is 2x4, i.e. 4x2 in N-major
|
||||||
|
val operandBDimensional =
|
||||||
|
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
|
.grouped(4).toSeq
|
||||||
|
println(s"B: ${operandBDimensional.length}, ${operandBDimensional(0).length}")
|
||||||
|
val ncSubstep = tilingParams.nc / 2
|
||||||
|
assert(tilingParams.mc * ncSubstep == numLanes,
|
||||||
|
"substep tile size doesn't match writeback throughput")
|
||||||
|
assert(operandBDimensional.length == ncSubstep &&
|
||||||
|
operandBDimensional(0).length == tilingParams.kc,
|
||||||
|
"operand width doesn't agree with tiling parameter")
|
||||||
|
|
||||||
|
for (m <- 0 until tilingParams.mc) {
|
||||||
|
for (n <- 0 until ncSubstep) {
|
||||||
|
dpus(m)(n).io.in.valid := dpuFire
|
||||||
|
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
||||||
|
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
||||||
|
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
|
||||||
|
// dpu ready couples with writeback backpressure
|
||||||
|
dpus(m)(n).io.stall := !io.writeback.ready
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dpuReady := !dpus(0)(0).io.stall
|
||||||
|
dontTouch(dpuFire)
|
||||||
|
dontTouch(dpuReady)
|
||||||
|
|
||||||
|
val dpuValids = dpus.flatMap(_.map(_.io.out.valid))
|
||||||
|
val dpuValid = dpuValids.reduce(_ && _)
|
||||||
|
def assertDPU = {
|
||||||
|
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
||||||
|
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
||||||
|
"stall signals of DPUs went unaligned")
|
||||||
|
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
||||||
|
"valid signals of DPUs went unaligned")
|
||||||
|
}
|
||||||
|
assertDPU
|
||||||
|
|
||||||
|
// flatten DPU output into 1D array in M-major order
|
||||||
|
val flattenedDPUOut = (0 until ncSubstep).flatMap { n =>
|
||||||
|
(0 until tilingParams.mc).map { m =>
|
||||||
|
dpus(m)(n).io.out.bits.data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
io.writeback.bits.data := flattenedDPUOut
|
||||||
|
|
||||||
def rdGen(set: UInt, step: UInt): UInt = {
|
def rdGen(set: UInt, step: UInt): UInt = {
|
||||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
// thread
|
// thread
|
||||||
@@ -309,19 +370,11 @@ class TensorCoreDecoupled(
|
|||||||
// FIXME: add substep here
|
// FIXME: add substep here
|
||||||
}
|
}
|
||||||
|
|
||||||
io.writeback.valid := operandsValid // FIXME: bypass logic
|
io.writeback.valid := dpuValid
|
||||||
io.writeback.bits.wid := warpReg
|
io.writeback.bits.wid := warpReg
|
||||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||||
|
|
||||||
// FIXME: debug dummy: pipe A directly to writeback
|
|
||||||
val groupedRespA = respQueueA.bits.data
|
|
||||||
.asBools.grouped(wordSize * 8/*bits*/)
|
|
||||||
.map(VecInit(_).asUInt)
|
|
||||||
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
|
||||||
wb := data
|
|
||||||
}
|
|
||||||
|
|
||||||
// State transition
|
// State transition
|
||||||
// ----------------
|
// ----------------
|
||||||
//
|
//
|
||||||
@@ -400,7 +453,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
|
|
||||||
val tensor = Module(new TensorCoreDecoupled(
|
val tensor = Module(new TensorCoreDecoupled(
|
||||||
8, 8, outer.numSrcIds , TensorTilingParams()))
|
8, 8, outer.numSrcIds , TensorTilingParams()))
|
||||||
val wordSize = 4 // FIXME: hardcoded
|
val wordSize = 4 // @cleanup: hardcoded
|
||||||
|
|
||||||
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||||
(outer.node.out(1), tensor.io.reqB))
|
(outer.node.out(1), tensor.io.reqB))
|
||||||
@@ -431,7 +484,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
tlOutB.d.ready := tensor.io.respB.ready
|
tlOutB.d.ready := tensor.io.respB.ready
|
||||||
|
|
||||||
tensor.io.initiate.valid := io.start
|
tensor.io.initiate.valid := io.start
|
||||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
tensor.io.initiate.bits.wid := 0.U // TODO
|
||||||
tensor.io.writeback.ready := true.B
|
tensor.io.writeback.ready := true.B
|
||||||
|
|
||||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||||
@@ -443,7 +496,7 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule {
|
|||||||
val xbar = LazyModule(new TLXbar)
|
val xbar = LazyModule(new TLXbar)
|
||||||
val ram = LazyModule(new TLRAM(
|
val ram = LazyModule(new TLRAM(
|
||||||
address = AddressSet(0x0000, 0xffffff),
|
address = AddressSet(0x0000, 0xffffff),
|
||||||
beatBytes = 32 // FIXME: hardcoded
|
beatBytes = 32 // @cleanup: hardcoded
|
||||||
))
|
))
|
||||||
|
|
||||||
ram.node :=* xbar.node :=* tensor.node
|
ram.node :=* xbar.node :=* tensor.node
|
||||||
@@ -461,11 +514,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
|||||||
val xbar = LazyModule(new TLXbar)
|
val xbar = LazyModule(new TLXbar)
|
||||||
val ramA = LazyModule(new TLRAM(
|
val ramA = LazyModule(new TLRAM(
|
||||||
address = AddressSet(0x000, 0xfffeff),
|
address = AddressSet(0x000, 0xfffeff),
|
||||||
beatBytes = 32 // FIXME: hardcoded
|
beatBytes = 32 // @cleanup: hardcoded
|
||||||
))
|
))
|
||||||
val ramB = LazyModule(new TLRAM(
|
val ramB = LazyModule(new TLRAM(
|
||||||
address = AddressSet(0x100, 0xfffeff),
|
address = AddressSet(0x100, 0xfffeff),
|
||||||
beatBytes = 32 // FIXME: hardcoded
|
beatBytes = 32 // @cleanup: hardcoded
|
||||||
))
|
))
|
||||||
|
|
||||||
xbar.node :=* tensor.node
|
xbar.node :=* tensor.node
|
||||||
|
|||||||
Reference in New Issue
Block a user