diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 2f53269..903cc1d 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -15,28 +15,46 @@ import radiance.memory.SourceGenerator case class TensorTilingParams( // Dimension of the SMEM tile - m: Int = 16, - n: Int = 16, - k: Int = 16, + m: Int, + n: Int, + k: Int, // Dimension of the compute tile. This is determined by the number of MAC // units - mc: Int = 4, - nc: Int = 4, - kc: Int = 4 + mc: Int, + nc: Int, + kc: Int, ) +object TensorTilingParams { + def fp16: TensorTilingParams = { + TensorTilingParams ( + m = 16, n = 16, k = 32, + mc = 4, nc = 4, kc = 8 + ) + } + def fp32: TensorTilingParams = { + TensorTilingParams ( + m = 16, n = 16, k = 16, + mc = 4, nc = 4, kc = 4 + ) + } +} + class TensorCoreDecoupled( val numWarps: Int, val numLanes: Int, - val numSourceIds: Int, - val tilingParams: TensorTilingParams, + val half: Boolean, // input datatype is FP16 if true, FP32 if false + val numSourceIds: Int = 16, val numFPRegs: Int = 32 ) extends Module { + val tilingParams = + if (half) TensorTilingParams.fp16 else TensorTilingParams.fp32 val numWarpBits = log2Ceil(numWarps) - val wordSize = 4 // TODO FP16 - val wordSizeInBits = wordSize * 8 // TODO FP16 + val wordSize = if (half) 2 else 4 + val wordSizeInBits = wordSize * 8/*bits*/ val sourceWidth = log2Ceil(numSourceIds) - val dataWidth = numLanes * wordSizeInBits // TODO FP16 + val laneWidth = 4/*bytes*/ * 8/*bits*/ + val memWidth = numLanes * laneWidth val numFPRegBits = log2Ceil(numFPRegs) val io = IO(new Bundle { @@ -47,11 +65,11 @@ class TensorCoreDecoupled( val last = Bool() val wid = UInt(numWarpBits.W) val rd = UInt(numFPRegBits.W) - val data = Vec(numLanes, UInt((wordSizeInBits).W)) + val data = Vec(numLanes, UInt(laneWidth.W)) }) - val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) - val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) - val respC = Input(UInt(dataWidth.W)) + val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth))) + val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth))) + val respC = Input(UInt(memWidth.W)) val reqA = Decoupled(new TensorMemReq(sourceWidth)) val reqB = Decoupled(new TensorMemReq(sourceWidth)) val reqC = Output(Valid(UInt(numFPRegBits.W))) @@ -185,7 +203,8 @@ class TensorCoreDecoupled( val blockRow = set val blockCol = index val blockIndex = (blockRow << indexBits) + blockCol - val blockSize = numLanes * wordSize + val blockSize = numLanes * laneWidth + require(blockSize == memWidth) val blockSizeBits = log2Ceil(blockSize) val byteOffset = blockIndex << blockSizeBits base + byteOffset @@ -222,8 +241,8 @@ class TensorCoreDecoupled( tagB.set := stateB.set tagB.index := stateB.index - val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) - val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) + val respATagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth))) + val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth))) Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach { case ((req, (resp, respTagged)), i) => { @@ -543,24 +562,32 @@ class TensorCoreDecoupled( require(tilingParams.mc * ncSubstep == numLanes, "substep tile size doesn't match writeback throughput") val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( - Module(new TensorDotProductUnit(dim = 4, half = false)) + Module(new TensorDotProductUnit( + dim = tilingParams.kc, + half = half, + )) )) - // reshape operands for easier routing to DPU - def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = { + // reshape UInt into a two-dimensional array where the innermost dimension + // has `numWords` elements + def reshapeByWords(x: UInt, wordSizeInBits: Int, numWords: Int) + : Seq[Seq[UInt]] = { x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - .grouped(4/*k-dim*/).toSeq + .grouped(numWords).toSeq } - val operandADimensional = reshapeByFourWords(operandA) + val operandADimensional = + reshapeByWords(operandA, wordSizeInBits, tilingParams.kc) require(operandADimensional.length == tilingParams.mc && operandADimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") - val operandBDimensional = reshapeByFourWords(operandB) + val operandBDimensional = + reshapeByWords(operandB, wordSizeInBits, tilingParams.kc) require(operandBDimensional.length == ncSubstep && operandBDimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") - // note operand C is M-major - val operandCDimensional = reshapeByFourWords(operandC) + // note operand C is M-major, and always FP32 + val operandCDimensional = + reshapeByWords(operandC, 4/*fp32*/ * 8/*bits*/, tilingParams.mc) require(operandCDimensional.length == ncSubstep && operandCDimensional(0).length == tilingParams.mc, "operand width doesn't agree with tiling parameter") @@ -609,7 +636,7 @@ class TensorCoreDecoupled( val substep = UInt(1.W) } - val queueDepth = 5 // needs to be at least the DPU latency + val queueDepth = (if (half) 6 else 5) // needs to be at least the DPU latency val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth)) tagQueue.io.enq.valid := dpuFire tagQueue.io.enq.bits.warp := operandATag.warp @@ -617,6 +644,8 @@ class TensorCoreDecoupled( tagQueue.io.enq.bits.step := stepCompute tagQueue.io.enq.bits.substep := substepCompute tagQueue.io.deq.ready := io.writeback.fire + // this is not necessary for correctness, and might trigger when there's a + // lot of writeback contention assert(tagQueue.io.enq.ready === true.B, "tag queue full, DPU operation might be throttled") assert(!dpuValid || tagQueue.io.deq.valid, @@ -727,7 +756,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) require(outer.node.out.length == 2/*A and B*/) val tensor = Module(new TensorCoreDecoupled( - 8, 8, outer.numSourceIds , TensorTilingParams())) + 8, 8, half = true, outer.numSourceIds)) val wordSize = 4 // @cleanup: hardcoded val zip = Seq((outer.node.out(0), tensor.io.reqA),