diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 652608b..b9695ad 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -33,8 +33,9 @@ class TensorCoreDecoupled( ) extends Module { val numWarpBits = log2Ceil(numWarps) val wordSize = 4 // TODO FP16 + val wordSizeInBits = wordSize * 8 // TODO FP16 val sourceWidth = log2Ceil(numSourceIds) - val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16 + val dataWidth = numLanes * wordSizeInBits // TODO FP16 val numFPRegBits = log2Ceil(numFPRegs) val io = IO(new Bundle { @@ -45,7 +46,7 @@ class TensorCoreDecoupled( val last = Bool() val wid = UInt(numWarpBits.W) val rd = UInt(numFPRegBits.W) - val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W)) + val data = Vec(numLanes, UInt((wordSizeInBits).W)) }) val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) @@ -223,9 +224,6 @@ class TensorCoreDecoupled( io.writeback.bits.data.widthOption.get, "response data width does not match the writeback data width") - // FIXME: this need to change to dpu_ready - val dpuReady = io.writeback.ready // FIXME: this need be actual dpu - val substepExecute = RegInit(0.U(1.W)) when (respQueueA.fire) { substepExecute := substepExecute + 1.U @@ -267,7 +265,10 @@ class TensorCoreDecoupled( fullAQueue.io.enq.bits.data := fullAEnqData fullAQueue.io.enq.bits.tag := fullAEnqTag - val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME? + val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid + val operandA = fullAQueue.io.deq.bits.data + val operandB = respQueueB.bits.data + val dpuReady = Wire(Bool()) val dpuFire = operandsValid && dpuReady val substepCompute = RegInit(0.U(1.W)) when (dpuFire) { @@ -301,6 +302,66 @@ class TensorCoreDecoupled( } assertAligned + // Dot-product unit + // + // 4x2 four-element DPUs summing up to 32 MACs in total + val dpus = Seq.fill(4)(Seq.fill(2)( + Module(new TensorDotProductUnit(half = false)) + )) + // operandA is 4x4 in K-major + val operandADimensional = + operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq + .grouped(4).toSeq + println(s"operandA: ${fullAQueue.io.deq.bits.data.widthOption.get} bits") + println(s"A: ${operandADimensional.length}, ${operandADimensional(0).length}") + assert(operandADimensional.length == tilingParams.mc && + operandADimensional(0).length == tilingParams.kc, + "operand width doesn't agree with tiling parameter") + // operandB is 2x4, i.e. 4x2 in N-major + val operandBDimensional = + operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq + .grouped(4).toSeq + println(s"B: ${operandBDimensional.length}, ${operandBDimensional(0).length}") + val ncSubstep = tilingParams.nc / 2 + assert(tilingParams.mc * ncSubstep == numLanes, + "substep tile size doesn't match writeback throughput") + assert(operandBDimensional.length == ncSubstep && + operandBDimensional(0).length == tilingParams.kc, + "operand width doesn't agree with tiling parameter") + + for (m <- 0 until tilingParams.mc) { + for (n <- 0 until ncSubstep) { + dpus(m)(n).io.in.valid := dpuFire + dpus(m)(n).io.in.bits.a := operandADimensional(m) + dpus(m)(n).io.in.bits.b := operandBDimensional(n) + dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data + // dpu ready couples with writeback backpressure + dpus(m)(n).io.stall := !io.writeback.ready + } + } + dpuReady := !dpus(0)(0).io.stall + dontTouch(dpuFire) + dontTouch(dpuReady) + + val dpuValids = dpus.flatMap(_.map(_.io.out.valid)) + val dpuValid = dpuValids.reduce(_ && _) + def assertDPU = { + val dpuStalls = dpus.flatMap(_.map(_.io.stall)) + assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _), + "stall signals of DPUs went unaligned") + assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _), + "valid signals of DPUs went unaligned") + } + assertDPU + + // flatten DPU output into 1D array in M-major order + val flattenedDPUOut = (0 until ncSubstep).flatMap { n => + (0 until tilingParams.mc).map { m => + dpus(m)(n).io.out.bits.data + } + } + io.writeback.bits.data := flattenedDPUOut + def rdGen(set: UInt, step: UInt): UInt = { // each step produces 4x4 output tile, written by 8 threads with 2 regs per // thread @@ -309,19 +370,11 @@ class TensorCoreDecoupled( // FIXME: add substep here } - io.writeback.valid := operandsValid // FIXME: bypass logic + io.writeback.valid := dpuValid io.writeback.bits.wid := warpReg io.writeback.bits.rd := rdGen(setExecute, stepExecute) io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute) - // FIXME: debug dummy: pipe A directly to writeback - val groupedRespA = respQueueA.bits.data - .asBools.grouped(wordSize * 8/*bits*/) - .map(VecInit(_).asUInt) - (io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) => - wb := data - } - // State transition // ---------------- // @@ -400,7 +453,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) val tensor = Module(new TensorCoreDecoupled( 8, 8, outer.numSrcIds , TensorTilingParams())) - val wordSize = 4 // FIXME: hardcoded + val wordSize = 4 // @cleanup: hardcoded val zip = Seq((outer.node.out(0), tensor.io.reqA), (outer.node.out(1), tensor.io.reqB)) @@ -431,7 +484,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) tlOutB.d.ready := tensor.io.respB.ready tensor.io.initiate.valid := io.start - tensor.io.initiate.bits.wid := 0.U // FIXME + tensor.io.initiate.bits.wid := 0.U // TODO tensor.io.writeback.ready := true.B io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last @@ -443,7 +496,7 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule { val xbar = LazyModule(new TLXbar) val ram = LazyModule(new TLRAM( address = AddressSet(0x0000, 0xffffff), - beatBytes = 32 // FIXME: hardcoded + beatBytes = 32 // @cleanup: hardcoded )) ram.node :=* xbar.node :=* tensor.node @@ -461,11 +514,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule { val xbar = LazyModule(new TLXbar) val ramA = LazyModule(new TLRAM( address = AddressSet(0x000, 0xfffeff), - beatBytes = 32 // FIXME: hardcoded + beatBytes = 32 // @cleanup: hardcoded )) val ramB = LazyModule(new TLRAM( address = AddressSet(0x100, 0xfffeff), - beatBytes = 32 // FIXME: hardcoded + beatBytes = 32 // @cleanup: hardcoded )) xbar.node :=* tensor.node