tensor: Support FP16 in TensorCoreDecoupled

This commit is contained in:
Hansung Kim
2024-10-25 22:26:04 -07:00
parent eed821eda6
commit 543eb2feb4

View File

@@ -15,28 +15,46 @@ import radiance.memory.SourceGenerator
case class TensorTilingParams( case class TensorTilingParams(
// Dimension of the SMEM tile // Dimension of the SMEM tile
m: Int = 16, m: Int,
n: Int = 16, n: Int,
k: Int = 16, k: Int,
// Dimension of the compute tile. This is determined by the number of MAC // Dimension of the compute tile. This is determined by the number of MAC
// units // units
mc: Int = 4, mc: Int,
nc: Int = 4, nc: Int,
kc: Int = 4 kc: Int,
) )
object TensorTilingParams {
def fp16: TensorTilingParams = {
TensorTilingParams (
m = 16, n = 16, k = 32,
mc = 4, nc = 4, kc = 8
)
}
def fp32: TensorTilingParams = {
TensorTilingParams (
m = 16, n = 16, k = 16,
mc = 4, nc = 4, kc = 4
)
}
}
class TensorCoreDecoupled( class TensorCoreDecoupled(
val numWarps: Int, val numWarps: Int,
val numLanes: Int, val numLanes: Int,
val numSourceIds: Int, val half: Boolean, // input datatype is FP16 if true, FP32 if false
val tilingParams: TensorTilingParams, val numSourceIds: Int = 16,
val numFPRegs: Int = 32 val numFPRegs: Int = 32
) extends Module { ) extends Module {
val tilingParams =
if (half) TensorTilingParams.fp16 else TensorTilingParams.fp32
val numWarpBits = log2Ceil(numWarps) val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16 val wordSize = if (half) 2 else 4
val wordSizeInBits = wordSize * 8 // TODO FP16 val wordSizeInBits = wordSize * 8/*bits*/
val sourceWidth = log2Ceil(numSourceIds) val sourceWidth = log2Ceil(numSourceIds)
val dataWidth = numLanes * wordSizeInBits // TODO FP16 val laneWidth = 4/*bytes*/ * 8/*bits*/
val memWidth = numLanes * laneWidth
val numFPRegBits = log2Ceil(numFPRegs) val numFPRegBits = log2Ceil(numFPRegs)
val io = IO(new Bundle { val io = IO(new Bundle {
@@ -47,11 +65,11 @@ class TensorCoreDecoupled(
val last = Bool() val last = Bool()
val wid = UInt(numWarpBits.W) val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W) val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt((wordSizeInBits).W)) val data = Vec(numLanes, UInt(laneWidth.W))
}) })
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respC = Input(UInt(dataWidth.W)) val respC = Input(UInt(memWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth)) val reqA = Decoupled(new TensorMemReq(sourceWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth)) val reqB = Decoupled(new TensorMemReq(sourceWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W))) val reqC = Output(Valid(UInt(numFPRegBits.W)))
@@ -185,7 +203,8 @@ class TensorCoreDecoupled(
val blockRow = set val blockRow = set
val blockCol = index val blockCol = index
val blockIndex = (blockRow << indexBits) + blockCol val blockIndex = (blockRow << indexBits) + blockCol
val blockSize = numLanes * wordSize val blockSize = numLanes * laneWidth
require(blockSize == memWidth)
val blockSizeBits = log2Ceil(blockSize) val blockSizeBits = log2Ceil(blockSize)
val byteOffset = blockIndex << blockSizeBits val byteOffset = blockIndex << blockSizeBits
base + byteOffset base + byteOffset
@@ -222,8 +241,8 @@ class TensorCoreDecoupled(
tagB.set := stateB.set tagB.set := stateB.set
tagB.index := stateB.index tagB.index := stateB.index
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
Seq((io.reqA, (io.respA, respATagged)), Seq((io.reqA, (io.respA, respATagged)),
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach { (io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
case ((req, (resp, respTagged)), i) => { case ((req, (resp, respTagged)), i) => {
@@ -543,24 +562,32 @@ class TensorCoreDecoupled(
require(tilingParams.mc * ncSubstep == numLanes, require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput") "substep tile size doesn't match writeback throughput")
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(dim = 4, half = false)) Module(new TensorDotProductUnit(
dim = tilingParams.kc,
half = half,
))
)) ))
// reshape operands for easier routing to DPU // reshape UInt into a two-dimensional array where the innermost dimension
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = { // has `numWords` elements
def reshapeByWords(x: UInt, wordSizeInBits: Int, numWords: Int)
: Seq[Seq[UInt]] = {
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq .grouped(numWords).toSeq
} }
val operandADimensional = reshapeByFourWords(operandA) val operandADimensional =
reshapeByWords(operandA, wordSizeInBits, tilingParams.kc)
require(operandADimensional.length == tilingParams.mc && require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc, operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
val operandBDimensional = reshapeByFourWords(operandB) val operandBDimensional =
reshapeByWords(operandB, wordSizeInBits, tilingParams.kc)
require(operandBDimensional.length == ncSubstep && require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc, operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
// note operand C is M-major // note operand C is M-major, and always FP32
val operandCDimensional = reshapeByFourWords(operandC) val operandCDimensional =
reshapeByWords(operandC, 4/*fp32*/ * 8/*bits*/, tilingParams.mc)
require(operandCDimensional.length == ncSubstep && require(operandCDimensional.length == ncSubstep &&
operandCDimensional(0).length == tilingParams.mc, operandCDimensional(0).length == tilingParams.mc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
@@ -609,7 +636,7 @@ class TensorCoreDecoupled(
val substep = UInt(1.W) val substep = UInt(1.W)
} }
val queueDepth = 5 // needs to be at least the DPU latency val queueDepth = (if (half) 6 else 5) // needs to be at least the DPU latency
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth)) val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire tagQueue.io.enq.valid := dpuFire
tagQueue.io.enq.bits.warp := operandATag.warp tagQueue.io.enq.bits.warp := operandATag.warp
@@ -617,6 +644,8 @@ class TensorCoreDecoupled(
tagQueue.io.enq.bits.step := stepCompute tagQueue.io.enq.bits.step := stepCompute
tagQueue.io.enq.bits.substep := substepCompute tagQueue.io.enq.bits.substep := substepCompute
tagQueue.io.deq.ready := io.writeback.fire tagQueue.io.deq.ready := io.writeback.fire
// this is not necessary for correctness, and might trigger when there's a
// lot of writeback contention
assert(tagQueue.io.enq.ready === true.B, assert(tagQueue.io.enq.ready === true.B,
"tag queue full, DPU operation might be throttled") "tag queue full, DPU operation might be throttled")
assert(!dpuValid || tagQueue.io.deq.valid, assert(!dpuValid || tagQueue.io.deq.valid,
@@ -727,7 +756,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
require(outer.node.out.length == 2/*A and B*/) require(outer.node.out.length == 2/*A and B*/)
val tensor = Module(new TensorCoreDecoupled( val tensor = Module(new TensorCoreDecoupled(
8, 8, outer.numSourceIds , TensorTilingParams())) 8, 8, half = true, outer.numSourceIds))
val wordSize = 4 // @cleanup: hardcoded val wordSize = 4 // @cleanup: hardcoded
val zip = Seq((outer.node.out(0), tensor.io.reqA), val zip = Seq((outer.node.out(0), tensor.io.reqA),