tensor: Support FP16 in TensorCoreDecoupled
This commit is contained in:
@@ -15,28 +15,46 @@ import radiance.memory.SourceGenerator
|
|||||||
|
|
||||||
case class TensorTilingParams(
|
case class TensorTilingParams(
|
||||||
// Dimension of the SMEM tile
|
// Dimension of the SMEM tile
|
||||||
m: Int = 16,
|
m: Int,
|
||||||
n: Int = 16,
|
n: Int,
|
||||||
k: Int = 16,
|
k: Int,
|
||||||
// Dimension of the compute tile. This is determined by the number of MAC
|
// Dimension of the compute tile. This is determined by the number of MAC
|
||||||
// units
|
// units
|
||||||
mc: Int = 4,
|
mc: Int,
|
||||||
nc: Int = 4,
|
nc: Int,
|
||||||
kc: Int = 4
|
kc: Int,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
object TensorTilingParams {
|
||||||
|
def fp16: TensorTilingParams = {
|
||||||
|
TensorTilingParams (
|
||||||
|
m = 16, n = 16, k = 32,
|
||||||
|
mc = 4, nc = 4, kc = 8
|
||||||
|
)
|
||||||
|
}
|
||||||
|
def fp32: TensorTilingParams = {
|
||||||
|
TensorTilingParams (
|
||||||
|
m = 16, n = 16, k = 16,
|
||||||
|
mc = 4, nc = 4, kc = 4
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class TensorCoreDecoupled(
|
class TensorCoreDecoupled(
|
||||||
val numWarps: Int,
|
val numWarps: Int,
|
||||||
val numLanes: Int,
|
val numLanes: Int,
|
||||||
val numSourceIds: Int,
|
val half: Boolean, // input datatype is FP16 if true, FP32 if false
|
||||||
val tilingParams: TensorTilingParams,
|
val numSourceIds: Int = 16,
|
||||||
val numFPRegs: Int = 32
|
val numFPRegs: Int = 32
|
||||||
) extends Module {
|
) extends Module {
|
||||||
|
val tilingParams =
|
||||||
|
if (half) TensorTilingParams.fp16 else TensorTilingParams.fp32
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
val numWarpBits = log2Ceil(numWarps)
|
||||||
val wordSize = 4 // TODO FP16
|
val wordSize = if (half) 2 else 4
|
||||||
val wordSizeInBits = wordSize * 8 // TODO FP16
|
val wordSizeInBits = wordSize * 8/*bits*/
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
val sourceWidth = log2Ceil(numSourceIds)
|
||||||
val dataWidth = numLanes * wordSizeInBits // TODO FP16
|
val laneWidth = 4/*bytes*/ * 8/*bits*/
|
||||||
|
val memWidth = numLanes * laneWidth
|
||||||
val numFPRegBits = log2Ceil(numFPRegs)
|
val numFPRegBits = log2Ceil(numFPRegs)
|
||||||
|
|
||||||
val io = IO(new Bundle {
|
val io = IO(new Bundle {
|
||||||
@@ -47,11 +65,11 @@ class TensorCoreDecoupled(
|
|||||||
val last = Bool()
|
val last = Bool()
|
||||||
val wid = UInt(numWarpBits.W)
|
val wid = UInt(numWarpBits.W)
|
||||||
val rd = UInt(numFPRegBits.W)
|
val rd = UInt(numFPRegBits.W)
|
||||||
val data = Vec(numLanes, UInt((wordSizeInBits).W))
|
val data = Vec(numLanes, UInt(laneWidth.W))
|
||||||
})
|
})
|
||||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
||||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
||||||
val respC = Input(UInt(dataWidth.W))
|
val respC = Input(UInt(memWidth.W))
|
||||||
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
||||||
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
||||||
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
||||||
@@ -185,7 +203,8 @@ class TensorCoreDecoupled(
|
|||||||
val blockRow = set
|
val blockRow = set
|
||||||
val blockCol = index
|
val blockCol = index
|
||||||
val blockIndex = (blockRow << indexBits) + blockCol
|
val blockIndex = (blockRow << indexBits) + blockCol
|
||||||
val blockSize = numLanes * wordSize
|
val blockSize = numLanes * laneWidth
|
||||||
|
require(blockSize == memWidth)
|
||||||
val blockSizeBits = log2Ceil(blockSize)
|
val blockSizeBits = log2Ceil(blockSize)
|
||||||
val byteOffset = blockIndex << blockSizeBits
|
val byteOffset = blockIndex << blockSizeBits
|
||||||
base + byteOffset
|
base + byteOffset
|
||||||
@@ -222,8 +241,8 @@ class TensorCoreDecoupled(
|
|||||||
tagB.set := stateB.set
|
tagB.set := stateB.set
|
||||||
tagB.index := stateB.index
|
tagB.index := stateB.index
|
||||||
|
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
|
||||||
Seq((io.reqA, (io.respA, respATagged)),
|
Seq((io.reqA, (io.respA, respATagged)),
|
||||||
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
|
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
|
||||||
case ((req, (resp, respTagged)), i) => {
|
case ((req, (resp, respTagged)), i) => {
|
||||||
@@ -543,24 +562,32 @@ class TensorCoreDecoupled(
|
|||||||
require(tilingParams.mc * ncSubstep == numLanes,
|
require(tilingParams.mc * ncSubstep == numLanes,
|
||||||
"substep tile size doesn't match writeback throughput")
|
"substep tile size doesn't match writeback throughput")
|
||||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||||
Module(new TensorDotProductUnit(dim = 4, half = false))
|
Module(new TensorDotProductUnit(
|
||||||
|
dim = tilingParams.kc,
|
||||||
|
half = half,
|
||||||
|
))
|
||||||
))
|
))
|
||||||
|
|
||||||
// reshape operands for easier routing to DPU
|
// reshape UInt into a two-dimensional array where the innermost dimension
|
||||||
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
|
// has `numWords` elements
|
||||||
|
def reshapeByWords(x: UInt, wordSizeInBits: Int, numWords: Int)
|
||||||
|
: Seq[Seq[UInt]] = {
|
||||||
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
.grouped(4/*k-dim*/).toSeq
|
.grouped(numWords).toSeq
|
||||||
}
|
}
|
||||||
val operandADimensional = reshapeByFourWords(operandA)
|
val operandADimensional =
|
||||||
|
reshapeByWords(operandA, wordSizeInBits, tilingParams.kc)
|
||||||
require(operandADimensional.length == tilingParams.mc &&
|
require(operandADimensional.length == tilingParams.mc &&
|
||||||
operandADimensional(0).length == tilingParams.kc,
|
operandADimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
val operandBDimensional = reshapeByFourWords(operandB)
|
val operandBDimensional =
|
||||||
|
reshapeByWords(operandB, wordSizeInBits, tilingParams.kc)
|
||||||
require(operandBDimensional.length == ncSubstep &&
|
require(operandBDimensional.length == ncSubstep &&
|
||||||
operandBDimensional(0).length == tilingParams.kc,
|
operandBDimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
// note operand C is M-major
|
// note operand C is M-major, and always FP32
|
||||||
val operandCDimensional = reshapeByFourWords(operandC)
|
val operandCDimensional =
|
||||||
|
reshapeByWords(operandC, 4/*fp32*/ * 8/*bits*/, tilingParams.mc)
|
||||||
require(operandCDimensional.length == ncSubstep &&
|
require(operandCDimensional.length == ncSubstep &&
|
||||||
operandCDimensional(0).length == tilingParams.mc,
|
operandCDimensional(0).length == tilingParams.mc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
@@ -609,7 +636,7 @@ class TensorCoreDecoupled(
|
|||||||
val substep = UInt(1.W)
|
val substep = UInt(1.W)
|
||||||
}
|
}
|
||||||
|
|
||||||
val queueDepth = 5 // needs to be at least the DPU latency
|
val queueDepth = (if (half) 6 else 5) // needs to be at least the DPU latency
|
||||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||||
tagQueue.io.enq.valid := dpuFire
|
tagQueue.io.enq.valid := dpuFire
|
||||||
tagQueue.io.enq.bits.warp := operandATag.warp
|
tagQueue.io.enq.bits.warp := operandATag.warp
|
||||||
@@ -617,6 +644,8 @@ class TensorCoreDecoupled(
|
|||||||
tagQueue.io.enq.bits.step := stepCompute
|
tagQueue.io.enq.bits.step := stepCompute
|
||||||
tagQueue.io.enq.bits.substep := substepCompute
|
tagQueue.io.enq.bits.substep := substepCompute
|
||||||
tagQueue.io.deq.ready := io.writeback.fire
|
tagQueue.io.deq.ready := io.writeback.fire
|
||||||
|
// this is not necessary for correctness, and might trigger when there's a
|
||||||
|
// lot of writeback contention
|
||||||
assert(tagQueue.io.enq.ready === true.B,
|
assert(tagQueue.io.enq.ready === true.B,
|
||||||
"tag queue full, DPU operation might be throttled")
|
"tag queue full, DPU operation might be throttled")
|
||||||
assert(!dpuValid || tagQueue.io.deq.valid,
|
assert(!dpuValid || tagQueue.io.deq.valid,
|
||||||
@@ -727,7 +756,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
require(outer.node.out.length == 2/*A and B*/)
|
require(outer.node.out.length == 2/*A and B*/)
|
||||||
|
|
||||||
val tensor = Module(new TensorCoreDecoupled(
|
val tensor = Module(new TensorCoreDecoupled(
|
||||||
8, 8, outer.numSourceIds , TensorTilingParams()))
|
8, 8, half = true, outer.numSourceIds))
|
||||||
val wordSize = 4 // @cleanup: hardcoded
|
val wordSize = 4 // @cleanup: hardcoded
|
||||||
|
|
||||||
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||||
|
|||||||
Reference in New Issue
Block a user